Skip to content

Commit e205c85

Browse files
ChrisRackauckas-Claude, ChrisRackauckas, and claude
authored
Add BLISLUFactorization, CudaOffloadLUFactorization, and MetalLUFactorization to default solver choices (#733)
* Add BLISLUFactorization, CudaOffloadLUFactorization, and MetalLUFactorization to default solver choices - Added new algorithm choices to DefaultAlgorithmChoice enum - Implemented conditional availability checking for new solvers - Added throwerror parameter to constructors for compatibility with default solver - Added fallback init_cacheval implementations for when extensions aren't loaded - Updated preferences system to recognize new algorithm names - Added availability checking functions (useblis, usecuda, usemetal) - Updated DefaultLinearSolverInit struct to include new algorithms - Added handling in solve! function for new algorithms with proper extension checks These solvers will only be selected by the default algorithm if: 1. They are available (extensions loaded) 2. They are specified in preferences from autotuning Modeled implementation after RFLUFactorization pattern with conditional availability. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * Fix CudaOffloadLUFactorization to use @get_cacheval macro When algorithms are part of the default solver system, they must use the @get_cacheval macro to properly retrieve cached values from the unified cache structure. Updated CudaOffloadLUFactorization to follow this pattern. BLISLUFactorization and MetalLUFactorization were already using the correct pattern. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * Make Metal availability check static and platform-aware - Updated usemetal() to be a static check that returns false on non-Apple platforms - Modified MetalLUFactorization constructor to check platform with @static - Updated test files to skip Metal tests on non-Apple platforms - This fixes CI failures on Linux where Metal is not available Following the same pattern as AppleAccelerateLUFactorization for consistency. 
🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> --------- Co-authored-by: ChrisRackauckas <[email protected]> Co-authored-by: Claude <[email protected]>
1 parent 2024295 commit e205c85

File tree

8 files changed

+271
-15
lines changed

8 files changed

+271
-15
lines changed

ext/LinearSolveCUDAExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ end
3838
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFactorization;
3939
kwargs...)
4040
if cache.isfresh
41+
cacheval = LinearSolve.@get_cacheval(cache, :CudaOffloadLUFactorization)
4142
fact = lu(CUDA.CuArray(cache.A))
4243
cache.cacheval = fact
4344
cache.isfresh = false
4445
end
45-
y = Array(ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b)))
46+
fact = LinearSolve.@get_cacheval(cache, :CudaOffloadLUFactorization)
47+
y = Array(ldiv!(CUDA.CuArray(cache.u), fact, CUDA.CuArray(cache.b)))
4648
cache.u .= y
4749
SciMLBase.build_linear_solution(alg, y, nothing, cache)
4850
end

src/LinearSolve.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,9 @@ EnumX.@enumx DefaultAlgorithmChoice begin
275275
QRFactorizationPivoted
276276
KrylovJL_CRAIGMR
277277
KrylovJL_LSMR
278+
BLISLUFactorization
279+
CudaOffloadLUFactorization
280+
MetalLUFactorization
278281
end
279282

280283
# Autotune preference constants - loaded once at package import time
@@ -296,12 +299,17 @@ function is_algorithm_available(alg::DefaultAlgorithmChoice.T)
296299
return appleaccelerate_isavailable() # Available on macOS with Accelerate
297300
elseif alg === DefaultAlgorithmChoice.RFLUFactorization
298301
return userecursivefactorization(nothing) # Requires RecursiveFactorization extension
302+
elseif alg === DefaultAlgorithmChoice.BLISLUFactorization
303+
return useblis() # Available if BLIS extension is loaded
304+
elseif alg === DefaultAlgorithmChoice.CudaOffloadLUFactorization
305+
return usecuda() # Available if CUDA extension is loaded
306+
elseif alg === DefaultAlgorithmChoice.MetalLUFactorization
307+
return usemetal() # Available if Metal extension is loaded
299308
else
300309
# For extension-dependent algorithms not explicitly handled above,
301310
# we cannot easily check availability without trying to use them.
302311
# For now, assume they're not available in the default selection.
303-
# This includes FastLU, BLIS, CUDA, Metal, etc. which would require
304-
# specific extension checks.
312+
# This includes other extensions that might be added in the future.
305313
return false
306314
end
307315
end
@@ -399,6 +407,17 @@ isopenblas() = IS_OPENBLAS[]
399407
const HAS_APPLE_ACCELERATE = Ref(false)
400408
appleaccelerate_isavailable() = HAS_APPLE_ACCELERATE[]
401409

410+
# Extension availability checking functions
411+
useblis() = Base.get_extension(@__MODULE__, :LinearSolveBLISExt) !== nothing
412+
usecuda() = Base.get_extension(@__MODULE__, :LinearSolveCUDAExt) !== nothing
413+
414+
# Metal is only available on Apple platforms
415+
@static if !Sys.isapple()
416+
usemetal() = false
417+
else
418+
usemetal() = Base.get_extension(@__MODULE__, :LinearSolveMetalExt) !== nothing
419+
end
420+
402421
PrecompileTools.@compile_workload begin
403422
A = rand(4, 4)
404423
b = rand(4)

src/default.jl

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
needs_concrete_A(alg::DefaultLinearSolver) = true
22
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
3-
T13, T14, T15, T16, T17, T18, T19, T20, T21}
3+
T13, T14, T15, T16, T17, T18, T19, T20, T21, T22, T23, T24}
44
LUFactorization::T1
55
QRFactorization::T2
66
DiagonalFactorization::T3
@@ -22,6 +22,9 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
2222
QRFactorizationPivoted::T19
2323
KrylovJL_CRAIGMR::T20
2424
KrylovJL_LSMR::T21
25+
BLISLUFactorization::T22
26+
CudaOffloadLUFactorization::T23
27+
MetalLUFactorization::T24
2528
end
2629

2730
@generated function __setfield!(cache::DefaultLinearSolverInit, alg::DefaultLinearSolver, v)
@@ -422,6 +425,12 @@ function algchoice_to_alg(alg::Symbol)
422425
KrylovJL_CRAIGMR()
423426
elseif alg === :KrylovJL_LSMR
424427
KrylovJL_LSMR()
428+
elseif alg === :BLISLUFactorization
429+
BLISLUFactorization(throwerror = false)
430+
elseif alg === :CudaOffloadLUFactorization
431+
CudaOffloadLUFactorization(throwerror = false)
432+
elseif alg === :MetalLUFactorization
433+
MetalLUFactorization(throwerror = false)
425434
else
426435
error("Algorithm choice symbol $alg not allowed in the default")
427436
end
@@ -526,6 +535,66 @@ end
526535
error("Default algorithm calling solve on RecursiveFactorization without the package being loaded. This shouldn't happen.")
527536
end
528537

538+
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
539+
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
540+
## TODO: Add verbosity logging here about using the fallback
541+
sol = SciMLBase.solve!(
542+
cache, QRFactorization(ColumnNorm()), args...; kwargs...)
543+
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
544+
retcode = sol.retcode,
545+
iters = sol.iters, stats = sol.stats)
546+
else
547+
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
548+
retcode = sol.retcode,
549+
iters = sol.iters, stats = sol.stats)
550+
end
551+
end
552+
elseif alg == Symbol(DefaultAlgorithmChoice.BLISLUFactorization)
553+
newex = quote
554+
if !useblis()
555+
error("Default algorithm calling solve on BLISLUFactorization without the extension being loaded. This shouldn't happen.")
556+
end
557+
558+
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
559+
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
560+
## TODO: Add verbosity logging here about using the fallback
561+
sol = SciMLBase.solve!(
562+
cache, QRFactorization(ColumnNorm()), args...; kwargs...)
563+
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
564+
retcode = sol.retcode,
565+
iters = sol.iters, stats = sol.stats)
566+
else
567+
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
568+
retcode = sol.retcode,
569+
iters = sol.iters, stats = sol.stats)
570+
end
571+
end
572+
elseif alg == Symbol(DefaultAlgorithmChoice.CudaOffloadLUFactorization)
573+
newex = quote
574+
if !usecuda()
575+
error("Default algorithm calling solve on CudaOffloadLUFactorization without CUDA.jl being loaded. This shouldn't happen.")
576+
end
577+
578+
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
579+
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
580+
## TODO: Add verbosity logging here about using the fallback
581+
sol = SciMLBase.solve!(
582+
cache, QRFactorization(ColumnNorm()), args...; kwargs...)
583+
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
584+
retcode = sol.retcode,
585+
iters = sol.iters, stats = sol.stats)
586+
else
587+
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
588+
retcode = sol.retcode,
589+
iters = sol.iters, stats = sol.stats)
590+
end
591+
end
592+
elseif alg == Symbol(DefaultAlgorithmChoice.MetalLUFactorization)
593+
newex = quote
594+
if !usemetal()
595+
error("Default algorithm calling solve on MetalLUFactorization without Metal.jl being loaded. This shouldn't happen.")
596+
end
597+
529598
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
530599
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
531600
## TODO: Add verbosity logging here about using the fallback

src/extension_algs.jl

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ Requires a sufficiently large `A` to overcome the data transfer costs.
7373
Using this solver requires adding the package CUDA.jl, i.e. `using CUDA`
7474
"""
7575
struct CudaOffloadLUFactorization <: AbstractFactorization
76-
function CudaOffloadLUFactorization()
76+
function CudaOffloadLUFactorization(; throwerror = true)
7777
ext = Base.get_extension(@__MODULE__, :LinearSolveCUDAExt)
78-
if ext === nothing
78+
if ext === nothing && throwerror
7979
error("CudaOffloadLUFactorization requires that CUDA is loaded, i.e. `using CUDA`")
8080
else
8181
return new()
@@ -610,16 +610,78 @@ A wrapper over the IterativeSolvers.jl MINRES.
610610
function IterativeSolversJL_MINRES end
611611

612612
"""
613+
MetalLUFactorization()
614+
615+
A wrapper over Apple's Metal GPU library for LU factorization. Direct calls to Metal
616+
in a way that pre-allocates workspace to avoid allocations and automatically offloads
617+
to the GPU. This solver is optimized for Metal-capable Apple Silicon Macs.
618+
619+
## Requirements
620+
Using this solver requires that Metal.jl is loaded: `using Metal`
621+
622+
## Performance Notes
623+
- Most efficient for large dense matrices where GPU acceleration benefits outweigh transfer costs
624+
- Automatically manages GPU memory and transfers
625+
- Particularly effective on Apple Silicon Macs with unified memory
626+
627+
## Example
613628
```julia
614-
MetalLUFactorization()
629+
using Metal
630+
alg = MetalLUFactorization()
631+
sol = solve(prob, alg)
615632
```
633+
"""
634+
struct MetalLUFactorization <: AbstractFactorization
635+
function MetalLUFactorization(; throwerror = true)
636+
@static if !Sys.isapple()
637+
if throwerror
638+
error("MetalLUFactorization is only available on Apple platforms")
639+
else
640+
return new()
641+
end
642+
else
643+
ext = Base.get_extension(@__MODULE__, :LinearSolveMetalExt)
644+
if ext === nothing && throwerror
645+
error("MetalLUFactorization requires that Metal.jl is loaded, i.e. `using Metal`")
646+
else
647+
return new()
648+
end
649+
end
650+
end
651+
end
616652

617-
A wrapper over Apple's Metal GPU library. Direct calls to Metal in a way that pre-allocates workspace
618-
to avoid allocations and automatically offloads to the GPU.
619653
"""
620-
struct MetalLUFactorization <: AbstractFactorization end
654+
BLISLUFactorization()
655+
656+
An LU factorization implementation using the BLIS (BLAS-like Library Instantiation Software)
657+
framework. BLIS provides high-performance dense linear algebra kernels optimized for various
658+
CPU architectures.
621659
622-
struct BLISLUFactorization <: AbstractFactorization end
660+
## Requirements
661+
Using this solver requires that blis_jll is available and the BLIS extension is loaded.
662+
The solver will be automatically available when conditions are met.
663+
664+
## Performance Notes
665+
- Optimized for modern CPU architectures with BLIS-specific optimizations
666+
- May provide better performance than standard BLAS on certain processors
667+
- Best suited for dense matrices with Float32, Float64, ComplexF32, or ComplexF64 elements
668+
669+
## Example
670+
```julia
671+
alg = BLISLUFactorization()
672+
sol = solve(prob, alg)
673+
```
674+
"""
675+
struct BLISLUFactorization <: AbstractFactorization
676+
function BLISLUFactorization(; throwerror = true)
677+
ext = Base.get_extension(@__MODULE__, :LinearSolveBLISExt)
678+
if ext === nothing && throwerror
679+
error("BLISLUFactorization requires that the BLIS extension is loaded and blis_jll is available")
680+
else
681+
return new()
682+
end
683+
end
684+
end
623685

624686
"""
625687
`CUSOLVERRFFactorization(; symbolic = :RF, reuse_symbolic = true)`

src/factorization.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,23 @@ function init_cacheval(::CliqueTreesFactorization, ::StaticArray, b, u, Pl, Pr,
12031203
nothing
12041204
end
12051205

1206+
# Fallback init_cacheval for extension-based algorithms when extensions aren't loaded
1207+
# These return nothing since the actual implementations are in the extensions
1208+
function init_cacheval(::BLISLUFactorization, A, b, u, Pl, Pr,
1209+
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
1210+
nothing
1211+
end
1212+
1213+
function init_cacheval(::CudaOffloadLUFactorization, A, b, u, Pl, Pr,
1214+
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
1215+
nothing
1216+
end
1217+
1218+
function init_cacheval(::MetalLUFactorization, A, b, u, Pl, Pr,
1219+
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
1220+
nothing
1221+
end
1222+
12061223
for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
12071224
InteractiveUtils.subtypes(AbstractSparseFactorization))
12081225
@eval function init_cacheval(alg::$alg, A::MatrixOperator, b, u, Pl, Pr,

src/preferences.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ function _string_to_algorithm_choice(algorithm_name::Union{String, Nothing})
2222
elseif algorithm_name == "FastLUFactorization"
2323
return DefaultAlgorithmChoice.LUFactorization # Map to standard LU (FastLapack extension)
2424
elseif algorithm_name == "BLISLUFactorization"
25-
return DefaultAlgorithmChoice.LUFactorization # Map to standard LU (BLIS extension)
25+
return DefaultAlgorithmChoice.BLISLUFactorization # Now supported as a separate choice
2626
elseif algorithm_name == "CudaOffloadLUFactorization"
27-
return DefaultAlgorithmChoice.LUFactorization # Map to standard LU (CUDA extension)
27+
return DefaultAlgorithmChoice.CudaOffloadLUFactorization # Now supported as a separate choice
2828
elseif algorithm_name == "MetalLUFactorization"
29-
return DefaultAlgorithmChoice.LUFactorization # Map to standard LU (Metal extension)
29+
return DefaultAlgorithmChoice.MetalLUFactorization # Now supported as a separate choice
3030
elseif algorithm_name == "AMDGPUOffloadLUFactorization"
3131
return DefaultAlgorithmChoice.LUFactorization # Map to standard LU (AMDGPU extension)
3232
else

test/nopre/jet.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ end
6363

6464
# CUDA/Metal factorizations (only test if CUDA/Metal are loaded)
6565
# CudaOffloadFactorization requires CUDA to be loaded, skip if not available
66-
if @isdefined(MetalLUFactorization)
66+
# Metal is only available on Apple platforms
67+
if Sys.isapple() && @isdefined(MetalLUFactorization)
6768
JET.@test_opt solve(prob, MetalLUFactorization()) broken=true
6869
end
6970
if @isdefined(BLISLUFactorization)

test_new_solvers.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
using Pkg
2+
Pkg.activate(".")
3+
using LinearSolve
4+
using Test
5+
using LinearAlgebra
6+
7+
# Test that the new algorithm choices are available in the enum
8+
@testset "New Algorithm Choices" begin
9+
choices = Symbol.(instances(LinearSolve.DefaultAlgorithmChoice.T))
10+
println("Available choices: ", choices)
11+
@test :BLISLUFactorization in choices
12+
@test :CudaOffloadLUFactorization in choices
13+
@test :MetalLUFactorization in choices
14+
end
15+
16+
# Test that availability checking functions exist
17+
@testset "Availability Functions" begin
18+
# These should return false since the extensions aren't loaded
19+
@test LinearSolve.useblis() == false
20+
@test LinearSolve.usecuda() == false
21+
@test LinearSolve.usemetal() == false
22+
23+
# Test that is_algorithm_available correctly reports availability
24+
@test LinearSolve.is_algorithm_available(LinearSolve.DefaultAlgorithmChoice.BLISLUFactorization) == false
25+
@test LinearSolve.is_algorithm_available(LinearSolve.DefaultAlgorithmChoice.CudaOffloadLUFactorization) == false
26+
@test LinearSolve.is_algorithm_available(LinearSolve.DefaultAlgorithmChoice.MetalLUFactorization) == false
27+
end
28+
29+
# Test that the algorithms can be instantiated without extensions (with throwerror=false)
30+
@testset "Algorithm Instantiation" begin
31+
# These should work with throwerror=false
32+
alg1 = LinearSolve.BLISLUFactorization(throwerror=false)
33+
@test alg1 isa LinearSolve.BLISLUFactorization
34+
35+
alg2 = LinearSolve.CudaOffloadLUFactorization(throwerror=false)
36+
@test alg2 isa LinearSolve.CudaOffloadLUFactorization
37+
38+
# Metal is only available on Apple platforms
39+
if Sys.isapple()
40+
alg3 = LinearSolve.MetalLUFactorization(throwerror=false)
41+
@test alg3 isa LinearSolve.MetalLUFactorization
42+
else
43+
# On non-Apple platforms, it should still not error with throwerror=false
44+
alg3 = LinearSolve.MetalLUFactorization(throwerror=false)
45+
@test alg3 isa LinearSolve.MetalLUFactorization
46+
end
47+
48+
# These should throw errors with throwerror=true (default)
49+
@test_throws ErrorException LinearSolve.BLISLUFactorization()
50+
@test_throws ErrorException LinearSolve.CudaOffloadLUFactorization()
51+
52+
# Metal error message depends on platform
53+
if Sys.isapple()
54+
@test_throws ErrorException LinearSolve.MetalLUFactorization()
55+
else
56+
# On non-Apple platforms, should error with platform message
57+
@test_throws ErrorException LinearSolve.MetalLUFactorization()
58+
end
59+
end
60+
61+
# Test that preferences system recognizes the new algorithms
62+
@testset "Preferences Support" begin
63+
# Test that the preference string mapping works
64+
alg = LinearSolve._string_to_algorithm_choice("BLISLUFactorization")
65+
@test alg === LinearSolve.DefaultAlgorithmChoice.BLISLUFactorization
66+
67+
alg = LinearSolve._string_to_algorithm_choice("CudaOffloadLUFactorization")
68+
@test alg === LinearSolve.DefaultAlgorithmChoice.CudaOffloadLUFactorization
69+
70+
alg = LinearSolve._string_to_algorithm_choice("MetalLUFactorization")
71+
@test alg === LinearSolve.DefaultAlgorithmChoice.MetalLUFactorization
72+
end
73+
74+
# Test basic solve still works with DefaultLinearSolver
75+
@testset "Default Solver Still Works" begin
76+
A = rand(10, 10)
77+
b = rand(10)
78+
prob = LinearProblem(A, b)
79+
80+
# Should use default solver and work fine
81+
sol = solve(prob)
82+
@test sol.retcode == ReturnCode.Success
83+
@test norm(A * sol.u - b) < 1e-10
84+
end
85+
86+
println("All tests passed!")

0 commit comments

Comments (0)