Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ext/LinearSolveCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ end
# Solve the cached linear system with LU factorization offloaded to the GPU:
# when the cache is fresh, `A` is copied to the device, factorized there, and
# the factorization is stored in `cache.cacheval`; the solve then runs `ldiv!`
# on device copies of `u` and `b` and copies the result back to the host.
# NOTE(review): this span was captured from a rendered diff and appears to
# contain BOTH the pre-change and post-change solve lines — the first
# `y = Array(ldiv!(..., cache.cacheval, ...))` looks like the deleted old line,
# superseded by the `@get_cacheval` pair below it, and the `cacheval` binding
# inside the `isfresh` branch is never used. Confirm against the merged file
# before treating this as runnable source.
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFactorization;
kwargs...)
if cache.isfresh
cacheval = LinearSolve.@get_cacheval(cache, :CudaOffloadLUFactorization)
fact = lu(CUDA.CuArray(cache.A))
cache.cacheval = fact
cache.isfresh = false
end
y = Array(ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b)))
fact = LinearSolve.@get_cacheval(cache, :CudaOffloadLUFactorization)
y = Array(ldiv!(CUDA.CuArray(cache.u), fact, CUDA.CuArray(cache.b)))
cache.u .= y
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end
Expand Down
23 changes: 21 additions & 2 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ EnumX.@enumx DefaultAlgorithmChoice begin
QRFactorizationPivoted
KrylovJL_CRAIGMR
KrylovJL_LSMR
BLISLUFactorization
CudaOffloadLUFactorization
MetalLUFactorization
end

# Autotune preference constants - loaded once at package import time
Expand All @@ -296,12 +299,17 @@ function is_algorithm_available(alg::DefaultAlgorithmChoice.T)
return appleaccelerate_isavailable() # Available on macOS with Accelerate
elseif alg === DefaultAlgorithmChoice.RFLUFactorization
return userecursivefactorization(nothing) # Requires RecursiveFactorization extension
elseif alg === DefaultAlgorithmChoice.BLISLUFactorization
return useblis() # Available if BLIS extension is loaded
elseif alg === DefaultAlgorithmChoice.CudaOffloadLUFactorization
return usecuda() # Available if CUDA extension is loaded
elseif alg === DefaultAlgorithmChoice.MetalLUFactorization
return usemetal() # Available if Metal extension is loaded
else
# For extension-dependent algorithms not explicitly handled above,
# we cannot easily check availability without trying to use them.
# For now, assume they're not available in the default selection.
# This includes FastLU, BLIS, CUDA, Metal, etc. which would require
# specific extension checks.
# This includes other extensions that might be added in the future.
return false
end
end
Expand Down Expand Up @@ -399,6 +407,17 @@ isopenblas() = IS_OPENBLAS[]
const HAS_APPLE_ACCELERATE = Ref(false)
appleaccelerate_isavailable() = HAS_APPLE_ACCELERATE[]

# Predicates reporting whether the corresponding package extension has been
# loaded into this module (i.e. the user has run `using` on the trigger package).
function useblis()
    return !isnothing(Base.get_extension(@__MODULE__, :LinearSolveBLISExt))
end

function usecuda()
    return !isnothing(Base.get_extension(@__MODULE__, :LinearSolveCUDAExt))
end

# Metal.jl can only provide a working extension on Apple hardware, so the
# platform check is resolved at compile time via `@static`.
@static if Sys.isapple()
    usemetal() = !isnothing(Base.get_extension(@__MODULE__, :LinearSolveMetalExt))
else
    usemetal() = false
end

PrecompileTools.@compile_workload begin
A = rand(4, 4)
b = rand(4)
Expand Down
71 changes: 70 additions & 1 deletion src/default.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
needs_concrete_A(alg::DefaultLinearSolver) = true
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
T13, T14, T15, T16, T17, T18, T19, T20, T21}
T13, T14, T15, T16, T17, T18, T19, T20, T21, T22, T23, T24}
LUFactorization::T1
QRFactorization::T2
DiagonalFactorization::T3
Expand All @@ -22,6 +22,9 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
QRFactorizationPivoted::T19
KrylovJL_CRAIGMR::T20
KrylovJL_LSMR::T21
BLISLUFactorization::T22
CudaOffloadLUFactorization::T23
MetalLUFactorization::T24
end

@generated function __setfield!(cache::DefaultLinearSolverInit, alg::DefaultLinearSolver, v)
Expand Down Expand Up @@ -422,6 +425,12 @@ function algchoice_to_alg(alg::Symbol)
KrylovJL_CRAIGMR()
elseif alg === :KrylovJL_LSMR
KrylovJL_LSMR()
elseif alg === :BLISLUFactorization
BLISLUFactorization(throwerror = false)
elseif alg === :CudaOffloadLUFactorization
CudaOffloadLUFactorization(throwerror = false)
elseif alg === :MetalLUFactorization
MetalLUFactorization(throwerror = false)
else
error("Algorithm choice symbol $alg not allowed in the default")
end
Expand Down Expand Up @@ -526,6 +535,66 @@ end
error("Default algorithm calling solve on RecursiveFactorization without the package being loaded. This shouldn't happen.")
end

sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
## TODO: Add verbosity logging here about using the fallback
sol = SciMLBase.solve!(
cache, QRFactorization(ColumnNorm()), args...; kwargs...)
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats)
else
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats)
end
end
elseif alg == Symbol(DefaultAlgorithmChoice.BLISLUFactorization)
newex = quote
if !useblis()
error("Default algorithm calling solve on BLISLUFactorization without the extension being loaded. This shouldn't happen.")
end

sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
## TODO: Add verbosity logging here about using the fallback
sol = SciMLBase.solve!(
cache, QRFactorization(ColumnNorm()), args...; kwargs...)
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats)
else
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats)
end
end
elseif alg == Symbol(DefaultAlgorithmChoice.CudaOffloadLUFactorization)
newex = quote
if !usecuda()
error("Default algorithm calling solve on CudaOffloadLUFactorization without CUDA.jl being loaded. This shouldn't happen.")
end

sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
## TODO: Add verbosity logging here about using the fallback
sol = SciMLBase.solve!(
cache, QRFactorization(ColumnNorm()), args...; kwargs...)
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats)
else
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats)
end
end
elseif alg == Symbol(DefaultAlgorithmChoice.MetalLUFactorization)
newex = quote
if !usemetal()
error("Default algorithm calling solve on MetalLUFactorization without Metal.jl being loaded. This shouldn't happen.")
end

sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
## TODO: Add verbosity logging here about using the fallback
Expand Down
76 changes: 69 additions & 7 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ Requires a sufficiently large `A` to overcome the data transfer costs.
Using this solver requires adding the package CUDA.jl, i.e. `using CUDA`
"""
struct CudaOffloadLUFactorization <: AbstractFactorization
function CudaOffloadLUFactorization()
function CudaOffloadLUFactorization(; throwerror = true)
ext = Base.get_extension(@__MODULE__, :LinearSolveCUDAExt)
if ext === nothing
if ext === nothing && throwerror
error("CudaOffloadLUFactorization requires that CUDA is loaded, i.e. `using CUDA`")
else
return new()
Expand Down Expand Up @@ -610,16 +610,78 @@ A wrapper over the IterativeSolvers.jl MINRES.
function IterativeSolversJL_MINRES end

"""
MetalLUFactorization()

A wrapper over Apple's Metal GPU library for LU factorization. Direct calls to Metal
in a way that pre-allocates workspace to avoid allocations and automatically offloads
to the GPU. This solver is optimized for Metal-capable Apple Silicon Macs.

## Requirements
Using this solver requires that Metal.jl is loaded: `using Metal`

## Performance Notes
- Most efficient for large dense matrices where GPU acceleration benefits outweigh transfer costs
- Automatically manages GPU memory and transfers
- Particularly effective on Apple Silicon Macs with unified memory

## Example
```julia
MetalLUFactorization()
using Metal
alg = MetalLUFactorization()
sol = solve(prob, alg)
```
"""
struct MetalLUFactorization <: AbstractFactorization
    # `throwerror = false` allows constructing a placeholder instance (used by
    # the default-algorithm machinery) even when the prerequisites are absent.
    function MetalLUFactorization(; throwerror = true)
        @static if Sys.isapple()
            # On Apple hardware the only requirement is that the Metal.jl
            # extension has actually been loaded.
            ext = Base.get_extension(@__MODULE__, :LinearSolveMetalExt)
            if throwerror && ext === nothing
                error("MetalLUFactorization requires that Metal.jl is loaded, i.e. `using Metal`")
            end
            return new()
        else
            # Non-Apple platforms can never use Metal; refuse construction
            # unless the caller explicitly opted out of strict checking.
            throwerror && error("MetalLUFactorization is only available on Apple platforms")
            return new()
        end
    end
end

A wrapper over Apple's Metal GPU library. Direct calls to Metal in a way that pre-allocates workspace
to avoid allocations and automatically offloads to the GPU.
"""
struct MetalLUFactorization <: AbstractFactorization end
BLISLUFactorization()

An LU factorization implementation using the BLIS (BLAS-like Library Instantiation Software)
framework. BLIS provides high-performance dense linear algebra kernels optimized for various
CPU architectures.

struct BLISLUFactorization <: AbstractFactorization end
## Requirements
Using this solver requires that blis_jll is available and the BLIS extension is loaded.
The solver will be automatically available when conditions are met.

## Performance Notes
- Optimized for modern CPU architectures with BLIS-specific optimizations
- May provide better performance than standard BLAS on certain processors
- Best suited for dense matrices with Float32, Float64, ComplexF32, or ComplexF64 elements

## Example
```julia
alg = BLISLUFactorization()
sol = solve(prob, alg)
```
"""
struct BLISLUFactorization <: AbstractFactorization
    # Construction is refused only when the BLIS extension is missing AND the
    # caller asked for strict checking (the default); `throwerror = false`
    # yields a placeholder instance for the default-algorithm machinery.
    function BLISLUFactorization(; throwerror = true)
        if throwerror &&
           Base.get_extension(@__MODULE__, :LinearSolveBLISExt) === nothing
            error("BLISLUFactorization requires that the BLIS extension is loaded and blis_jll is available")
        end
        return new()
    end
end

"""
`CUSOLVERRFFactorization(; symbolic = :RF, reuse_symbolic = true)`
Expand Down
17 changes: 17 additions & 0 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,23 @@ function init_cacheval(::CliqueTreesFactorization, ::StaticArray, b, u, Pl, Pr,
nothing
end

# Fallback init_cacheval for extension-based algorithms when extensions aren't loaded
# These return nothing since the actual implementations are in the extensions
# Fallback when the BLIS extension is not loaded: the extension supplies the
# real `init_cacheval`, so there is nothing to precompute here.
function init_cacheval(::BLISLUFactorization, A, b, u, Pl, Pr, maxiters::Int,
        abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
    return nothing
end

# Fallback when CUDA.jl is not loaded: the CUDA extension supplies the real
# `init_cacheval`, so there is nothing to precompute here.
function init_cacheval(::CudaOffloadLUFactorization, A, b, u, Pl, Pr, maxiters::Int,
        abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
    return nothing
end

# Fallback when Metal.jl is not loaded: the Metal extension supplies the real
# `init_cacheval`, so there is nothing to precompute here.
function init_cacheval(::MetalLUFactorization, A, b, u, Pl, Pr, maxiters::Int,
        abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
    return nothing
end

for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
InteractiveUtils.subtypes(AbstractSparseFactorization))
@eval function init_cacheval(alg::$alg, A::MatrixOperator, b, u, Pl, Pr,
Expand Down
6 changes: 3 additions & 3 deletions src/preferences.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ function _string_to_algorithm_choice(algorithm_name::Union{String, Nothing})
elseif algorithm_name == "FastLUFactorization"
return DefaultAlgorithmChoice.LUFactorization # Map to standard LU (FastLapack extension)
elseif algorithm_name == "BLISLUFactorization"
return DefaultAlgorithmChoice.LUFactorization # Map to standard LU (BLIS extension)
return DefaultAlgorithmChoice.BLISLUFactorization # Now supported as a separate choice
elseif algorithm_name == "CudaOffloadLUFactorization"
return DefaultAlgorithmChoice.LUFactorization # Map to standard LU (CUDA extension)
return DefaultAlgorithmChoice.CudaOffloadLUFactorization # Now supported as a separate choice
elseif algorithm_name == "MetalLUFactorization"
return DefaultAlgorithmChoice.LUFactorization # Map to standard LU (Metal extension)
return DefaultAlgorithmChoice.MetalLUFactorization # Now supported as a separate choice
elseif algorithm_name == "AMDGPUOffloadLUFactorization"
return DefaultAlgorithmChoice.LUFactorization # Map to standard LU (AMDGPU extension)
else
Expand Down
3 changes: 2 additions & 1 deletion test/nopre/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ end

# CUDA/Metal factorizations (only test if CUDA/Metal are loaded)
# CudaOffloadFactorization requires CUDA to be loaded, skip if not available
if @isdefined(MetalLUFactorization)
# Metal is only available on Apple platforms
if Sys.isapple() && @isdefined(MetalLUFactorization)
JET.@test_opt solve(prob, MetalLUFactorization()) broken=true
end
if @isdefined(BLISLUFactorization)
Expand Down
86 changes: 86 additions & 0 deletions test_new_solvers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
using Pkg
Pkg.activate(".")
using LinearSolve
using Test
using LinearAlgebra

# Verify the enum of default algorithm choices now includes the new
# BLIS/CUDA/Metal entries.
@testset "New Algorithm Choices" begin
    enum_symbols = map(Symbol, instances(LinearSolve.DefaultAlgorithmChoice.T))
    println("Available choices: ", enum_symbols)
    @test in(:BLISLUFactorization, enum_symbols)
    @test in(:CudaOffloadLUFactorization, enum_symbols)
    @test in(:MetalLUFactorization, enum_symbols)
end

# The extension-availability predicates must all report `false` here, because
# none of the BLIS/CUDA/Metal extensions are loaded in this test session, and
# `is_algorithm_available` must agree with them.
@testset "Availability Functions" begin
    @test !LinearSolve.useblis()
    @test !LinearSolve.usecuda()
    @test !LinearSolve.usemetal()

    @test !LinearSolve.is_algorithm_available(LinearSolve.DefaultAlgorithmChoice.BLISLUFactorization)
    @test !LinearSolve.is_algorithm_available(LinearSolve.DefaultAlgorithmChoice.CudaOffloadLUFactorization)
    @test !LinearSolve.is_algorithm_available(LinearSolve.DefaultAlgorithmChoice.MetalLUFactorization)
end

# Test that the algorithms can be instantiated without extensions when
# `throwerror = false`, and that strict construction (the default) errors.
@testset "Algorithm Instantiation" begin
    alg1 = LinearSolve.BLISLUFactorization(throwerror = false)
    @test alg1 isa LinearSolve.BLISLUFactorization

    alg2 = LinearSolve.CudaOffloadLUFactorization(throwerror = false)
    @test alg2 isa LinearSolve.CudaOffloadLUFactorization

    # MetalLUFactorization(throwerror = false) must succeed on every platform:
    # on Apple it skips the extension check, elsewhere the platform check.
    # (The original `if Sys.isapple()` conditionals here had byte-identical
    # branches, so the dead platform branching has been removed.)
    alg3 = LinearSolve.MetalLUFactorization(throwerror = false)
    @test alg3 isa LinearSolve.MetalLUFactorization

    # With throwerror = true (the default) construction must raise: either the
    # extension is not loaded (Apple) or the platform is unsupported (non-Apple).
    @test_throws ErrorException LinearSolve.BLISLUFactorization()
    @test_throws ErrorException LinearSolve.CudaOffloadLUFactorization()
    @test_throws ErrorException LinearSolve.MetalLUFactorization()
end

# The preferences string → enum mapping must resolve each new algorithm name
# to its own dedicated enum value (no longer collapsed onto plain LU).
@testset "Preferences Support" begin
    blis_choice = LinearSolve._string_to_algorithm_choice("BLISLUFactorization")
    @test blis_choice === LinearSolve.DefaultAlgorithmChoice.BLISLUFactorization

    cuda_choice = LinearSolve._string_to_algorithm_choice("CudaOffloadLUFactorization")
    @test cuda_choice === LinearSolve.DefaultAlgorithmChoice.CudaOffloadLUFactorization

    metal_choice = LinearSolve._string_to_algorithm_choice("MetalLUFactorization")
    @test metal_choice === LinearSolve.DefaultAlgorithmChoice.MetalLUFactorization
end

# A plain dense LinearProblem must still solve successfully through the
# default algorithm selection after the new choices were added.
@testset "Default Solver Still Works" begin
    n = 10
    A = rand(n, n)
    b = rand(n)
    prob = LinearProblem(A, b)

    solution = solve(prob)
    @test solution.retcode == ReturnCode.Success
    residual = norm(A * solution.u - b)
    @test residual < 1e-10
end

println("All tests passed!")
Loading