Skip to content

Commit 8fd2724

Browse files
Add BLISLUFactorization, CudaOffloadLUFactorization, and MetalLUFactorization to default solver choices
- Added new algorithm choices to DefaultAlgorithmChoice enum - Implemented conditional availability checking for new solvers - Added throwerror parameter to constructors for compatibility with default solver - Added fallback init_cacheval implementations for when extensions aren't loaded - Updated preferences system to recognize new algorithm names - Added availability checking functions (useblis, usecuda, usemetal) - Updated DefaultLinearSolverInit struct to include new algorithms - Added handling in solve! function for new algorithms with proper extension checks These solvers will only be selected by the default algorithm if: 1. They are available (extensions loaded) 2. They are specified in preferences from autotuning Modeled implementation after RFLUFactorization pattern with conditional availability. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 2024295 commit 8fd2724

File tree

6 files changed

+238
-13
lines changed

6 files changed

+238
-13
lines changed

src/LinearSolve.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,9 @@ EnumX.@enumx DefaultAlgorithmChoice begin
275275
QRFactorizationPivoted
276276
KrylovJL_CRAIGMR
277277
KrylovJL_LSMR
278+
BLISLUFactorization
279+
CudaOffloadLUFactorization
280+
MetalLUFactorization
278281
end
279282

280283
# Autotune preference constants - loaded once at package import time
@@ -296,12 +299,17 @@ function is_algorithm_available(alg::DefaultAlgorithmChoice.T)
296299
return appleaccelerate_isavailable() # Available on macOS with Accelerate
297300
elseif alg === DefaultAlgorithmChoice.RFLUFactorization
298301
return userecursivefactorization(nothing) # Requires RecursiveFactorization extension
302+
elseif alg === DefaultAlgorithmChoice.BLISLUFactorization
303+
return useblis() # Available if BLIS extension is loaded
304+
elseif alg === DefaultAlgorithmChoice.CudaOffloadLUFactorization
305+
return usecuda() # Available if CUDA extension is loaded
306+
elseif alg === DefaultAlgorithmChoice.MetalLUFactorization
307+
return usemetal() # Available if Metal extension is loaded
299308
else
300309
# For extension-dependent algorithms not explicitly handled above,
301310
# we cannot easily check availability without trying to use them.
302311
# For now, assume they're not available in the default selection.
303-
# This includes FastLU, BLIS, CUDA, Metal, etc. which would require
304-
# specific extension checks.
312+
# This includes other extensions that might be added in the future.
305313
return false
306314
end
307315
end
@@ -399,6 +407,11 @@ isopenblas() = IS_OPENBLAS[]
399407
const HAS_APPLE_ACCELERATE = Ref(false)
400408
appleaccelerate_isavailable() = HAS_APPLE_ACCELERATE[]
401409

410+
# Extension availability checking functions
411+
useblis() = Base.get_extension(@__MODULE__, :LinearSolveBLISExt) !== nothing
412+
usecuda() = Base.get_extension(@__MODULE__, :LinearSolveCUDAExt) !== nothing
413+
usemetal() = Base.get_extension(@__MODULE__, :LinearSolveMetalExt) !== nothing
414+
402415
PrecompileTools.@compile_workload begin
403416
A = rand(4, 4)
404417
b = rand(4)

src/default.jl

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
needs_concrete_A(alg::DefaultLinearSolver) = true
22
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
3-
T13, T14, T15, T16, T17, T18, T19, T20, T21}
3+
T13, T14, T15, T16, T17, T18, T19, T20, T21, T22, T23, T24}
44
LUFactorization::T1
55
QRFactorization::T2
66
DiagonalFactorization::T3
@@ -22,6 +22,9 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
2222
QRFactorizationPivoted::T19
2323
KrylovJL_CRAIGMR::T20
2424
KrylovJL_LSMR::T21
25+
BLISLUFactorization::T22
26+
CudaOffloadLUFactorization::T23
27+
MetalLUFactorization::T24
2528
end
2629

2730
@generated function __setfield!(cache::DefaultLinearSolverInit, alg::DefaultLinearSolver, v)
@@ -422,6 +425,12 @@ function algchoice_to_alg(alg::Symbol)
422425
KrylovJL_CRAIGMR()
423426
elseif alg === :KrylovJL_LSMR
424427
KrylovJL_LSMR()
428+
elseif alg === :BLISLUFactorization
429+
BLISLUFactorization(throwerror = false)
430+
elseif alg === :CudaOffloadLUFactorization
431+
CudaOffloadLUFactorization(throwerror = false)
432+
elseif alg === :MetalLUFactorization
433+
MetalLUFactorization(throwerror = false)
425434
else
426435
error("Algorithm choice symbol $alg not allowed in the default")
427436
end
@@ -526,6 +535,66 @@ end
526535
error("Default algorithm calling solve on RecursiveFactorization without the package being loaded. This shouldn't happen.")
527536
end
528537

538+
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
539+
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
540+
## TODO: Add verbosity logging here about using the fallback
541+
sol = SciMLBase.solve!(
542+
cache, QRFactorization(ColumnNorm()), args...; kwargs...)
543+
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
544+
retcode = sol.retcode,
545+
iters = sol.iters, stats = sol.stats)
546+
else
547+
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
548+
retcode = sol.retcode,
549+
iters = sol.iters, stats = sol.stats)
550+
end
551+
end
552+
elseif alg == Symbol(DefaultAlgorithmChoice.BLISLUFactorization)
553+
newex = quote
554+
if !useblis()
555+
error("Default algorithm calling solve on BLISLUFactorization without the extension being loaded. This shouldn't happen.")
556+
end
557+
558+
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
559+
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
560+
## TODO: Add verbosity logging here about using the fallback
561+
sol = SciMLBase.solve!(
562+
cache, QRFactorization(ColumnNorm()), args...; kwargs...)
563+
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
564+
retcode = sol.retcode,
565+
iters = sol.iters, stats = sol.stats)
566+
else
567+
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
568+
retcode = sol.retcode,
569+
iters = sol.iters, stats = sol.stats)
570+
end
571+
end
572+
elseif alg == Symbol(DefaultAlgorithmChoice.CudaOffloadLUFactorization)
573+
newex = quote
574+
if !usecuda()
575+
error("Default algorithm calling solve on CudaOffloadLUFactorization without CUDA.jl being loaded. This shouldn't happen.")
576+
end
577+
578+
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
579+
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
580+
## TODO: Add verbosity logging here about using the fallback
581+
sol = SciMLBase.solve!(
582+
cache, QRFactorization(ColumnNorm()), args...; kwargs...)
583+
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
584+
retcode = sol.retcode,
585+
iters = sol.iters, stats = sol.stats)
586+
else
587+
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
588+
retcode = sol.retcode,
589+
iters = sol.iters, stats = sol.stats)
590+
end
591+
end
592+
elseif alg == Symbol(DefaultAlgorithmChoice.MetalLUFactorization)
593+
newex = quote
594+
if !usemetal()
595+
error("Default algorithm calling solve on MetalLUFactorization without Metal.jl being loaded. This shouldn't happen.")
596+
end
597+
529598
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
530599
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
531600
## TODO: Add verbosity logging here about using the fallback

src/extension_algs.jl

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ Requires a sufficiently large `A` to overcome the data transfer costs.
7373
Using this solver requires adding the package CUDA.jl, i.e. `using CUDA`
7474
"""
7575
struct CudaOffloadLUFactorization <: AbstractFactorization
76-
function CudaOffloadLUFactorization()
76+
function CudaOffloadLUFactorization(; throwerror = true)
7777
ext = Base.get_extension(@__MODULE__, :LinearSolveCUDAExt)
78-
if ext === nothing
78+
if ext === nothing && throwerror
7979
error("CudaOffloadLUFactorization requires that CUDA is loaded, i.e. `using CUDA`")
8080
else
8181
return new()
@@ -610,16 +610,70 @@ A wrapper over the IterativeSolvers.jl MINRES.
610610
function IterativeSolversJL_MINRES end
611611

612612
"""
613+
MetalLUFactorization()
614+
615+
A wrapper over Apple's Metal GPU library for LU factorization. Direct calls to Metal
616+
in a way that pre-allocates workspace to avoid allocations and automatically offloads
617+
to the GPU. This solver is optimized for Metal-capable Apple Silicon Macs.
618+
619+
## Requirements
620+
Using this solver requires that Metal.jl is loaded: `using Metal`
621+
622+
## Performance Notes
623+
- Most efficient for large dense matrices where GPU acceleration benefits outweigh transfer costs
624+
- Automatically manages GPU memory and transfers
625+
- Particularly effective on Apple Silicon Macs with unified memory
626+
627+
## Example
613628
```julia
614-
MetalLUFactorization()
629+
using Metal
630+
alg = MetalLUFactorization()
631+
sol = solve(prob, alg)
615632
```
633+
"""
634+
struct MetalLUFactorization <: AbstractFactorization
635+
function MetalLUFactorization(; throwerror = true)
636+
ext = Base.get_extension(@__MODULE__, :LinearSolveMetalExt)
637+
if ext === nothing && throwerror
638+
error("MetalLUFactorization requires that Metal.jl is loaded, i.e. `using Metal`")
639+
else
640+
return new()
641+
end
642+
end
643+
end
616644

617-
A wrapper over Apple's Metal GPU library. Direct calls to Metal in a way that pre-allocates workspace
618-
to avoid allocations and automatically offloads to the GPU.
619645
"""
620-
struct MetalLUFactorization <: AbstractFactorization end
646+
BLISLUFactorization()
621647
622-
struct BLISLUFactorization <: AbstractFactorization end
648+
An LU factorization implementation using the BLIS (BLAS-like Library Instantiation Software)
649+
framework. BLIS provides high-performance dense linear algebra kernels optimized for various
650+
CPU architectures.
651+
652+
## Requirements
653+
Using this solver requires that blis_jll is available and the BLIS extension is loaded.
654+
The solver will be automatically available when conditions are met.
655+
656+
## Performance Notes
657+
- Optimized for modern CPU architectures with BLIS-specific optimizations
658+
- May provide better performance than standard BLAS on certain processors
659+
- Best suited for dense matrices with Float32, Float64, ComplexF32, or ComplexF64 elements
660+
661+
## Example
662+
```julia
663+
alg = BLISLUFactorization()
664+
sol = solve(prob, alg)
665+
```
666+
"""
667+
struct BLISLUFactorization <: AbstractFactorization
668+
function BLISLUFactorization(; throwerror = true)
669+
ext = Base.get_extension(@__MODULE__, :LinearSolveBLISExt)
670+
if ext === nothing && throwerror
671+
error("BLISLUFactorization requires that the BLIS extension is loaded and blis_jll is available")
672+
else
673+
return new()
674+
end
675+
end
676+
end
623677

624678
"""
625679
`CUSOLVERRFFactorization(; symbolic = :RF, reuse_symbolic = true)`

src/factorization.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,23 @@ function init_cacheval(::CliqueTreesFactorization, ::StaticArray, b, u, Pl, Pr,
12031203
nothing
12041204
end
12051205

1206+
# Fallback init_cacheval for extension-based algorithms when extensions aren't loaded
1207+
# These return nothing since the actual implementations are in the extensions
1208+
function init_cacheval(::BLISLUFactorization, A, b, u, Pl, Pr,
1209+
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
1210+
nothing
1211+
end
1212+
1213+
function init_cacheval(::CudaOffloadLUFactorization, A, b, u, Pl, Pr,
1214+
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
1215+
nothing
1216+
end
1217+
1218+
function init_cacheval(::MetalLUFactorization, A, b, u, Pl, Pr,
1219+
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
1220+
nothing
1221+
end
1222+
12061223
for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
12071224
InteractiveUtils.subtypes(AbstractSparseFactorization))
12081225
@eval function init_cacheval(alg::$alg, A::MatrixOperator, b, u, Pl, Pr,

src/preferences.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ function _string_to_algorithm_choice(algorithm_name::Union{String, Nothing})
2222
elseif algorithm_name == "FastLUFactorization"
2323
return DefaultAlgorithmChoice.LUFactorization # Map to standard LU (FastLapack extension)
2424
elseif algorithm_name == "BLISLUFactorization"
25-
return DefaultAlgorithmChoice.LUFactorization # Map to standard LU (BLIS extension)
25+
return DefaultAlgorithmChoice.BLISLUFactorization # Now supported as a separate choice
2626
elseif algorithm_name == "CudaOffloadLUFactorization"
27-
return DefaultAlgorithmChoice.LUFactorization # Map to standard LU (CUDA extension)
27+
return DefaultAlgorithmChoice.CudaOffloadLUFactorization # Now supported as a separate choice
2828
elseif algorithm_name == "MetalLUFactorization"
29-
return DefaultAlgorithmChoice.LUFactorization # Map to standard LU (Metal extension)
29+
return DefaultAlgorithmChoice.MetalLUFactorization # Now supported as a separate choice
3030
elseif algorithm_name == "AMDGPUOffloadLUFactorization"
3131
return DefaultAlgorithmChoice.LUFactorization # Map to standard LU (AMDGPU extension)
3232
else

test_new_solvers.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
using Pkg
2+
Pkg.activate(".")
3+
using LinearSolve
4+
using Test
5+
using LinearAlgebra
6+
7+
# Test that the new algorithm choices are available in the enum
8+
@testset "New Algorithm Choices" begin
9+
choices = Symbol.(instances(LinearSolve.DefaultAlgorithmChoice.T))
10+
println("Available choices: ", choices)
11+
@test :BLISLUFactorization in choices
12+
@test :CudaOffloadLUFactorization in choices
13+
@test :MetalLUFactorization in choices
14+
end
15+
16+
# Test that availability checking functions exist
17+
@testset "Availability Functions" begin
18+
# These should return false since the extensions aren't loaded
19+
@test LinearSolve.useblis() == false
20+
@test LinearSolve.usecuda() == false
21+
@test LinearSolve.usemetal() == false
22+
23+
# Test that is_algorithm_available correctly reports availability
24+
@test LinearSolve.is_algorithm_available(LinearSolve.DefaultAlgorithmChoice.BLISLUFactorization) == false
25+
@test LinearSolve.is_algorithm_available(LinearSolve.DefaultAlgorithmChoice.CudaOffloadLUFactorization) == false
26+
@test LinearSolve.is_algorithm_available(LinearSolve.DefaultAlgorithmChoice.MetalLUFactorization) == false
27+
end
28+
29+
# Test that the algorithms can be instantiated without extensions (with throwerror=false)
30+
@testset "Algorithm Instantiation" begin
31+
# These should work with throwerror=false
32+
alg1 = LinearSolve.BLISLUFactorization(throwerror=false)
33+
@test alg1 isa LinearSolve.BLISLUFactorization
34+
35+
alg2 = LinearSolve.CudaOffloadLUFactorization(throwerror=false)
36+
@test alg2 isa LinearSolve.CudaOffloadLUFactorization
37+
38+
alg3 = LinearSolve.MetalLUFactorization(throwerror=false)
39+
@test alg3 isa LinearSolve.MetalLUFactorization
40+
41+
# These should throw errors with throwerror=true (default)
42+
@test_throws ErrorException LinearSolve.BLISLUFactorization()
43+
@test_throws ErrorException LinearSolve.CudaOffloadLUFactorization()
44+
@test_throws ErrorException LinearSolve.MetalLUFactorization()
45+
end
46+
47+
# Test that preferences system recognizes the new algorithms
48+
@testset "Preferences Support" begin
49+
# Test that the preference string mapping works
50+
alg = LinearSolve._string_to_algorithm_choice("BLISLUFactorization")
51+
@test alg === LinearSolve.DefaultAlgorithmChoice.BLISLUFactorization
52+
53+
alg = LinearSolve._string_to_algorithm_choice("CudaOffloadLUFactorization")
54+
@test alg === LinearSolve.DefaultAlgorithmChoice.CudaOffloadLUFactorization
55+
56+
alg = LinearSolve._string_to_algorithm_choice("MetalLUFactorization")
57+
@test alg === LinearSolve.DefaultAlgorithmChoice.MetalLUFactorization
58+
end
59+
60+
# Test basic solve still works with DefaultLinearSolver
61+
@testset "Default Solver Still Works" begin
62+
A = rand(10, 10)
63+
b = rand(10)
64+
prob = LinearProblem(A, b)
65+
66+
# Should use default solver and work fine
67+
sol = solve(prob)
68+
@test sol.retcode == ReturnCode.Success
69+
@test norm(A * sol.u - b) < 1e-10
70+
end
71+
72+
println("All tests passed!")

0 commit comments

Comments
 (0)