diff --git a/lib/NonlinearSolveBase/src/linear_solve.jl b/lib/NonlinearSolveBase/src/linear_solve.jl index 592524ec3..0836faa86 100644 --- a/lib/NonlinearSolveBase/src/linear_solve.jl +++ b/lib/NonlinearSolveBase/src/linear_solve.jl @@ -72,9 +72,20 @@ function construct_linear_solver(alg, linsolve, A, b, u; stats, kwargs...) @bb u_cache = copy(u_fixed) linprob = LinearProblem(A, b; u0 = u_cache, kwargs...) - # unlias here, we will later use these as caches + # GPU arrays require aliasing to avoid CPU memory operations + # Check if any of the arrays are GPU arrays + is_gpu_array = _is_gpu_array(A) || _is_gpu_array(b) || _is_gpu_array(u_fixed) + alias_A = is_gpu_array + alias_b = is_gpu_array + + # For GPU arrays, we need to ensure we use a GPU-compatible linear solver + # If no linsolve is specified and we have GPU arrays, fall back to native Julia solver + if is_gpu_array && linsolve === nothing + return NativeJLLinearSolveCache(A, b, stats) + end + lincache = init( - linprob, linsolve; alias = LinearAliasSpecifier(alias_A = false, alias_b = false)) + linprob, linsolve; alias = LinearAliasSpecifier(alias_A = alias_A, alias_b = alias_b)) return LinearSolveJLCache(lincache, linsolve, stats) end @@ -127,3 +138,16 @@ needs_square_A(::typeof(\), ::Any) = false needs_concrete_A(::Union{Nothing, Missing}) = false needs_concrete_A(::typeof(\)) = true + +# GPU array detection utilities +_is_gpu_array(x) = _is_gpu_array_impl(x) +_is_gpu_array_impl(x) = false + +# Define GPU array detection for common GPU array types via a heuristic: +# If the array type name contains "GPU", "Mtl", "Cu", "ROC", or "oneAPI", it's likely a GPU array +function _is_gpu_array_impl(x::AbstractArray) + T = typeof(x) + name = string(T) + return contains(name, "GPU") || contains(name, "Mtl") || contains(name, "Cu") || + contains(name, "ROC") || contains(name, "oneAPI") +end