Commit ea53e7f

Merge branch 'main' into add-openblas-lu-factorization
2 parents d4da31b + 71315f6 commit ea53e7f

19 files changed: +477 −12 lines changed

docs/src/solvers/solvers.md

Lines changed: 22 additions & 0 deletions

````diff
@@ -34,6 +34,24 @@ this is only recommended for Float32 matrices. Choose `CudaOffloadLUFactorizatio
 performance on well-conditioned problems, or `CudaOffloadQRFactorization` for better numerical
 stability on ill-conditioned problems.
 
+#### Mixed Precision Methods
+
+For large well-conditioned problems where memory bandwidth is the bottleneck, mixed precision
+methods can provide significant speedups (up to 2x) by performing the factorization in Float32
+while maintaining Float64 interfaces. These methods are particularly effective for:
+
+- Large dense matrices (> 1000x1000)
+- Well-conditioned problems (condition number < 10^4)
+- Hardware with good Float32 performance
+
+Available mixed precision solvers:
+
+- `MKL32MixedLUFactorization` - CPUs with MKL
+- `AppleAccelerate32MixedLUFactorization` - Apple CPUs with Accelerate
+- `CUDAOffload32MixedLUFactorization` - NVIDIA GPUs with CUDA
+- `MetalOffload32MixedLUFactorization` - Apple GPUs with Metal
+
+These methods automatically handle the precision conversion, making them easy drop-in replacements
+when reduced precision is acceptable for the factorization step.
+
 !!! note
 
     Performance details for dense LU-factorizations can be highly dependent on the hardware configuration.
@@ -207,6 +225,7 @@ KrylovJL
 
 ```@docs
 MKLLUFactorization
+MKL32MixedLUFactorization
 ```
 
 ### OpenBLAS
@@ -223,6 +242,7 @@ OpenBLASLUFactorization
 
 ```@docs
 AppleAccelerateLUFactorization
+AppleAccelerate32MixedLUFactorization
 ```
 
 ### Metal.jl
@@ -233,6 +253,7 @@ AppleAccelerateLUFactorization
 
 ```@docs
 MetalLUFactorization
+MetalOffload32MixedLUFactorization
 ```
 
 ### Pardiso.jl
@@ -259,6 +280,7 @@ The following are non-standard GPU factorization routines.
 ```@docs
 CudaOffloadLUFactorization
 CudaOffloadQRFactorization
+CUDAOffload32MixedLUFactorization
 ```
 
 ### AMDGPU.jl
````
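For context, the drop-in usage described in the docs above looks like the following. This is an illustrative sketch, not part of the diff; it assumes MKL is available so that `MKL32MixedLUFactorization` applies, and uses default constructor arguments:

```julia
using LinearSolve, LinearAlgebra

# A large, diagonally dominant (hence well-conditioned) Float64 system
n = 2000
A = rand(n, n) + n * I
b = rand(n)

prob = LinearProblem(A, b)

# The factorization runs internally in Float32; the solution comes back
# in the original Float64 precision.
sol = solve(prob, MKL32MixedLUFactorization())
sol.u
```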

ext/LinearSolveAMDGPUExt.jl

Lines changed: 3 additions & 2 deletions

```diff
@@ -2,7 +2,8 @@ module LinearSolveAMDGPUExt
 
 using AMDGPU
 using LinearSolve: LinearSolve, LinearCache, AMDGPUOffloadLUFactorization,
-                   AMDGPUOffloadQRFactorization, init_cacheval, OperatorAssumptions
+                   AMDGPUOffloadQRFactorization, init_cacheval, OperatorAssumptions,
+                   LinearVerbosity
 using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase
 
 # LU Factorization
@@ -65,4 +66,4 @@ function LinearSolve.init_cacheval(alg::AMDGPUOffloadQRFactorization, A, b, u, P
     (A_gpu, tau)
 end
 
-end
+end
```

ext/LinearSolveBLISExt.jl

Lines changed: 1 addition & 1 deletion

```diff
@@ -9,7 +9,7 @@ using LinearSolve
 using LinearAlgebra: BlasInt, LU
 using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
                             @blasfunc, chkargsok
-using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval, LinearCache, SciMLBase
+using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval, LinearCache, SciMLBase, LinearVerbosity
 using SciMLBase: ReturnCode
 
 const global libblis = blis_jll.blis
```

ext/LinearSolveBandedMatricesExt.jl

Lines changed: 1 addition & 1 deletion

```diff
@@ -3,7 +3,7 @@ module LinearSolveBandedMatricesExt
 using BandedMatrices, LinearAlgebra, LinearSolve
 import LinearSolve: defaultalg,
                     do_factorization, init_cacheval, DefaultLinearSolver,
-                    DefaultAlgorithmChoice
+                    DefaultAlgorithmChoice, LinearVerbosity
 
 # Defaults for BandedMatrices
 function defaultalg(A::BandedMatrix, b, oa::OperatorAssumptions{Bool})
```

ext/LinearSolveCUDAExt.jl

Lines changed: 39 additions & 2 deletions

```diff
@@ -7,6 +7,7 @@ using LinearSolve: LinearSolve, is_cusparse, defaultalg, cudss_loaded, DefaultLi
                    needs_concrete_A,
                    error_no_cudss_lu, init_cacheval, OperatorAssumptions,
                    CudaOffloadFactorization, CudaOffloadLUFactorization, CudaOffloadQRFactorization,
+                   CUDAOffload32MixedLUFactorization,
                    SparspakFactorization, KLUFactorization, UMFPACKFactorization,
                    LinearVerbosity
 using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
@@ -52,7 +53,7 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFact
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
 
-function LinearSolve.init_cacheval(alg::CudaOffloadLUFactorization, A, b, u, Pl, Pr,
+function LinearSolve.init_cacheval(alg::CudaOffloadLUFactorization, A::AbstractArray, b, u, Pl, Pr,
         maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
         assumptions::OperatorAssumptions)
     T = eltype(A)
@@ -94,7 +95,7 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactor
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
 
-function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr,
+function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
         maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
         assumptions::OperatorAssumptions)
     qr(CUDA.CuArray(A))
@@ -118,4 +119,40 @@ function LinearSolve.init_cacheval(
     nothing
 end
 
+# Mixed precision CUDA LU implementation
+function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
+        kwargs...)
+    if cache.isfresh
+        cacheval = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
+        # Convert to Float32 for factorization
+        A_f32 = Float32.(cache.A)
+        fact = lu(CUDA.CuArray(A_f32))
+        cache.cacheval = fact
+        cache.isfresh = false
+    end
+    fact = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
+    # Convert b to Float32, solve, then convert back to original precision
+    b_f32 = Float32.(cache.b)
+    u_f32 = CUDA.CuArray(b_f32)
+    y_f32 = ldiv!(u_f32, fact, CUDA.CuArray(b_f32))
+    # Convert back to original precision
+    y = Array(y_f32)
+    T = eltype(cache.u)
+    cache.u .= T.(y)
+    SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
+end
+
+function LinearSolve.init_cacheval(alg::CUDAOffload32MixedLUFactorization, A, b, u, Pl, Pr,
+        maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
+        assumptions::OperatorAssumptions)
+    # Pre-allocate with Float32 arrays
+    A_f32 = Float32.(A)
+    T = eltype(A_f32)
+    noUnitT = typeof(zero(T))
+    luT = LinearAlgebra.lutype(noUnitT)
+    ipiv = CuVector{Int32}(undef, 0)
+    info = zero(LinearAlgebra.BlasInt)
+    return LU{luT}(CuMatrix{Float32}(undef, 0, 0), ipiv, info)
+end
+
 end
```
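The convert-factorize-solve-convert structure added above is the whole of the mixed precision trick. The same idea can be shown as a minimal CPU-only sketch (illustrative only, independent of CUDA and of the package internals; `mixed_precision_solve` is a hypothetical helper, not a LinearSolve API):

```julia
using LinearAlgebra

# Factor in Float32, solve in Float32, then cast back to Float64,
# mirroring the structure of the CUDAOffload32MixedLUFactorization solve! above.
function mixed_precision_solve(A::Matrix{Float64}, b::Vector{Float64})
    fact = lu(Float32.(A))   # LU factorization in reduced precision
    y = fact \ Float32.(b)   # forward/back substitution in Float32
    return Float64.(y)       # convert back to the interface precision
end

A = rand(500, 500) + 500I    # diagonally dominant, well-conditioned test matrix
b = rand(500)
x = mixed_precision_solve(A, b)

# For well-conditioned A the relative residual is on the order of
# eps(Float32) (about 1e-7), not eps(Float64).
norm(A * x - b) / norm(b)
```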

ext/LinearSolveCUSOLVERRFExt.jl

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,6 +1,6 @@
 module LinearSolveCUSOLVERRFExt
 
-using LinearSolve: LinearSolve, @get_cacheval, pattern_changed, OperatorAssumptions
+using LinearSolve: LinearSolve, @get_cacheval, pattern_changed, OperatorAssumptions, LinearVerbosity
 using CUSOLVERRF: CUSOLVERRF, RFLU, CUDA
 using SparseArrays: SparseArrays, SparseMatrixCSC, nnz
 using CUSOLVERRF.CUDA.CUSPARSE: CuSparseMatrixCSR
@@ -86,4 +86,4 @@ function LinearSolve.pattern_changed(rf::RFLU, A::CuSparseMatrixCSR)
 end
 
 
-end
+end
```

ext/LinearSolveCliqueTreesExt.jl

Lines changed: 1 addition & 0 deletions

```diff
@@ -2,6 +2,7 @@ module LinearSolveCliqueTreesExt
 
 using CliqueTrees: symbolic, cholinit, lininit, cholesky!, linsolve!
 using LinearSolve
+using LinearSolve: LinearVerbosity
 using SparseArrays
 
 function _symbolic(A::AbstractMatrix, alg::CliqueTreesFactorization)
```

ext/LinearSolveFastAlmostBandedMatricesExt.jl

Lines changed: 1 addition & 1 deletion

```diff
@@ -3,7 +3,7 @@ module LinearSolveFastAlmostBandedMatricesExt
 using FastAlmostBandedMatrices, LinearAlgebra, LinearSolve
 import LinearSolve: defaultalg,
                     do_factorization, init_cacheval, DefaultLinearSolver,
-                    DefaultAlgorithmChoice
+                    DefaultAlgorithmChoice, LinearVerbosity
 
 function defaultalg(A::AlmostBandedMatrix, b, oa::OperatorAssumptions{Bool})
     if oa.issq
```

ext/LinearSolveFastLapackInterfaceExt.jl

Lines changed: 1 addition & 0 deletions

```diff
@@ -1,6 +1,7 @@
 module LinearSolveFastLapackInterfaceExt
 
 using LinearSolve, LinearAlgebra
+using LinearSolve: LinearVerbosity
 using FastLapackInterface
 
 struct WorkspaceAndFactors{W, F}
```

ext/LinearSolveIterativeSolversExt.jl

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,7 +1,7 @@
 module LinearSolveIterativeSolversExt
 
 using LinearSolve, LinearAlgebra
-using LinearSolve: LinearCache, DEFAULT_PRECS
+using LinearSolve: LinearCache, DEFAULT_PRECS, LinearVerbosity
 import LinearSolve: IterativeSolversJL
 using SciMLLogging: @SciMLMessage, Verbosity
 
```
