Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions lib/NonlinearSolveBase/src/linear_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,20 @@ function construct_linear_solver(alg, linsolve, A, b, u; stats, kwargs...)
@bb u_cache = copy(u_fixed)
linprob = LinearProblem(A, b; u0 = u_cache, kwargs...)

# unlias here, we will later use these as caches
# GPU arrays require aliasing to avoid CPU memory operations
# Check if any of the arrays are GPU arrays
is_gpu_array = _is_gpu_array(A) || _is_gpu_array(b) || _is_gpu_array(u_fixed)
alias_A = is_gpu_array
alias_b = is_gpu_array

# For GPU arrays, we need to ensure we use a GPU-compatible linear solver
# If no linsolve is specified and we have GPU arrays, fall back to native Julia solver
if is_gpu_array && linsolve === nothing
return NativeJLLinearSolveCache(A, b, stats)
end

lincache = init(
linprob, linsolve; alias = LinearAliasSpecifier(alias_A = false, alias_b = false))
linprob, linsolve; alias = LinearAliasSpecifier(alias_A = alias_A, alias_b = alias_b))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
linprob, linsolve; alias = LinearAliasSpecifier(alias_A = alias_A, alias_b = alias_b))
linprob, linsolve; alias = LinearAliasSpecifier(
alias_A = alias_A, alias_b = alias_b))

return LinearSolveJLCache(lincache, linsolve, stats)
end

Expand Down Expand Up @@ -127,3 +138,16 @@ needs_square_A(::typeof(\), ::Any) = false

needs_concrete_A(::Union{Nothing, Missing}) = false
needs_concrete_A(::typeof(\)) = true

# GPU array detection utilities
_is_gpu_array(x) = _is_gpu_array_impl(x)
_is_gpu_array_impl(x) = false

# Define GPU array detection for common GPU array types via a heuristic:
# If the array type name contains "GPU", "Mtl", "Cu", "ROC", or "oneAPI", it's likely a GPU array
function _is_gpu_array_impl(x::AbstractArray)
T = typeof(x)
name = string(T)
return contains(name, "GPU") || contains(name, "Mtl") || contains(name, "Cu") ||
contains(name, "ROC") || contains(name, "oneAPI")
end
Loading