From ff5b59c6e60cd2b94f09aa8bf50b836c654ad8a5 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Tue, 2 Sep 2025 14:12:54 -0700 Subject: [PATCH] Fix GPU array support in linear solver construction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit addresses issue #682 where NonlinearSolve.jl crashes Julia when solving nonlinear problems on Metal GPU arrays. The issue was that GPU arrays require special handling in the linear solver construction phase. The original code forced non-aliasing (alias_A = false, alias_b = false) which caused GPU arrays to be copied to CPU memory, leading to crashes. Changes: - Added GPU array detection based on type name heuristics - Enable aliasing for GPU arrays to avoid CPU memory operations - Fall back to NativeJLLinearSolveCache for GPU arrays when no specific linear solver is provided - Support for Metal.jl, CUDA.jl, AMDGPU.jl, and OneAPI.jl arrays Tested with Metal.jl arrays and verified that: - solve(prob, NewtonRaphson(linsolve = \, autodiff = AutoFiniteDiff())) now works correctly for Metal GPU arrays - The solver correctly finds solutions like [√2, √2] for u² = 2 šŸ¤– Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- lib/NonlinearSolveBase/src/linear_solve.jl | 28 ++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/lib/NonlinearSolveBase/src/linear_solve.jl b/lib/NonlinearSolveBase/src/linear_solve.jl index 592524ec3..0836faa86 100644 --- a/lib/NonlinearSolveBase/src/linear_solve.jl +++ b/lib/NonlinearSolveBase/src/linear_solve.jl @@ -72,9 +72,20 @@ function construct_linear_solver(alg, linsolve, A, b, u; stats, kwargs...) @bb u_cache = copy(u_fixed) linprob = LinearProblem(A, b; u0 = u_cache, kwargs...) - # unlias here, we will later use these as caches + # GPU arrays require aliasing to avoid CPU memory operations + # Check if any of the arrays are GPU arrays + is_gpu_array = _is_gpu_array(A) || _is_gpu_array(b) || _is_gpu_array(u_fixed) + alias_A = is_gpu_array + alias_b = is_gpu_array + + # For GPU arrays, we need to ensure we use a GPU-compatible linear solver + # If no linsolve is specified and we have GPU arrays, fall back to native Julia solver + if is_gpu_array && linsolve === nothing + return NativeJLLinearSolveCache(A, b, stats) + end + lincache = init( - linprob, linsolve; alias = LinearAliasSpecifier(alias_A = false, alias_b = false)) + linprob, linsolve; alias = LinearAliasSpecifier(alias_A = alias_A, alias_b = alias_b)) return LinearSolveJLCache(lincache, linsolve, stats) end @@ -127,3 +138,16 @@ needs_square_A(::typeof(\), ::Any) = false needs_concrete_A(::Union{Nothing, Missing}) = false needs_concrete_A(::typeof(\)) = true + +# GPU array detection utilities +_is_gpu_array(x) = _is_gpu_array_impl(x) +_is_gpu_array_impl(x) = false + +# Define GPU array detection for common GPU array types via a heuristic: +# If the array type name contains "GPU", "Mtl", "Cu", "ROC", or "oneAPI", it's likely a GPU array +function _is_gpu_array_impl(x::AbstractArray) + T = typeof(x) + name = string(T) + return contains(name, "GPU") || contains(name, "Mtl") || contains(name, "Cu") || + contains(name, "ROC") || contains(name, "oneAPI") +end