Skip to content

Commit 42ef6f2

Browse files
ChrisRackauckas-ClaudeclaudeChrisRackauckas
authored
Add mixed precision LU factorization methods (#746)
* Add mixed precision LU factorization methods This commit introduces four new mixed precision LU factorization algorithms that perform computations in Float32 while maintaining Float64 interfaces, providing significant performance improvements for memory-bandwidth limited problems. New factorization methods: - CUDAOffload32MixedLUFactorization: GPU-accelerated mixed precision for NVIDIA GPUs - MetalOffload32MixedLUFactorization: GPU-accelerated mixed precision for Apple Metal - MKL32MixedLUFactorization: CPU-based mixed precision using Intel MKL - AppleAccelerate32MixedLUFactorization: CPU-based mixed precision using Apple Accelerate Key features: - Transparent Float64 to Float32 conversion for factorization - Support for both real and complex matrices - Up to 2x speedup for large, well-conditioned matrices - Maintains reasonable accuracy while reducing memory bandwidth requirements The implementations handle precision conversion internally, making them easy to use as drop-in replacements for standard LU factorization when reduced precision is acceptable. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * Add tests and documentation for mixed precision methods - Added mixed precision tests to the Core test group in runtests.jl - Added documentation for all four mixed precision methods in docs - Added section explaining when to use mixed precision methods - Documentation includes performance characteristics and use cases The tests now run as part of the standard test suite, and the documentation provides clear guidance on when these methods are beneficial (large well-conditioned problems with memory bandwidth bottlenecks). 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * Update docs/src/solvers/solvers.md --------- Co-authored-by: Claude <[email protected]> Co-authored-by: Christopher Rackauckas <[email protected]>
1 parent 99c54ec commit 42ef6f2

File tree

9 files changed

+462
-1
lines changed

9 files changed

+462
-1
lines changed

docs/src/solvers/solvers.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,24 @@ this is only recommended for Float32 matrices. Choose `CudaOffloadLUFactorizatio
3232
performance on well-conditioned problems, or `CudaOffloadQRFactorization` for better numerical
3333
stability on ill-conditioned problems.
3434

35+
#### Mixed Precision Methods
36+
37+
For large well-conditioned problems where memory bandwidth is the bottleneck, mixed precision
38+
methods can provide significant speedups (up to 2x) by performing the factorization in Float32
39+
while maintaining Float64 interfaces. These methods are particularly effective for:
40+
- Large dense matrices (> 1000x1000)
41+
- Well-conditioned problems (condition number < 10^4)
42+
- Hardware with good Float32 performance
43+
44+
Available mixed precision solvers:
45+
- `MKL32MixedLUFactorization` - CPUs with MKL
46+
- `AppleAccelerate32MixedLUFactorization` - Apple CPUs with Accelerate
47+
- `CUDAOffload32MixedLUFactorization` - NVIDIA GPUs with CUDA
48+
- `MetalOffload32MixedLUFactorization` - Apple GPUs with Metal
49+
50+
These methods automatically handle the precision conversion, making them easy drop-in replacements
51+
when reduced precision is acceptable for the factorization step.
52+
3553
!!! note
3654

3755
Performance details for dense LU-factorizations can be highly dependent on the hardware configuration.
@@ -205,6 +223,7 @@ KrylovJL
205223

206224
```@docs
207225
MKLLUFactorization
226+
MKL32MixedLUFactorization
208227
```
209228

210229
### AppleAccelerate.jl
@@ -215,6 +234,7 @@ MKLLUFactorization
215234

216235
```@docs
217236
AppleAccelerateLUFactorization
237+
AppleAccelerate32MixedLUFactorization
218238
```
219239

220240
### Metal.jl
@@ -225,6 +245,7 @@ AppleAccelerateLUFactorization
225245

226246
```@docs
227247
MetalLUFactorization
248+
MetalOffload32MixedLUFactorization
228249
```
229250

230251
### Pardiso.jl
@@ -251,6 +272,7 @@ The following are non-standard GPU factorization routines.
251272
```@docs
252273
CudaOffloadLUFactorization
253274
CudaOffloadQRFactorization
275+
CUDAOffload32MixedLUFactorization
254276
```
255277

256278
### AMDGPU.jl

ext/LinearSolveCUDAExt.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using LinearSolve: LinearSolve, is_cusparse, defaultalg, cudss_loaded, DefaultLi
77
needs_concrete_A,
88
error_no_cudss_lu, init_cacheval, OperatorAssumptions,
99
CudaOffloadFactorization, CudaOffloadLUFactorization, CudaOffloadQRFactorization,
10+
CUDAOffload32MixedLUFactorization,
1011
SparspakFactorization, KLUFactorization, UMFPACKFactorization,
1112
LinearVerbosity
1213
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
@@ -118,4 +119,40 @@ function LinearSolve.init_cacheval(
118119
nothing
119120
end
120121

122+
# Mixed precision CUDA LU implementation
123+
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
124+
kwargs...)
125+
if cache.isfresh
126+
cacheval = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
127+
# Convert to Float32 for factorization
128+
A_f32 = Float32.(cache.A)
129+
fact = lu(CUDA.CuArray(A_f32))
130+
cache.cacheval = fact
131+
cache.isfresh = false
132+
end
133+
fact = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
134+
# Convert b to Float32, solve, then convert back to original precision
135+
b_f32 = Float32.(cache.b)
136+
u_f32 = CUDA.CuArray(b_f32)
137+
y_f32 = ldiv!(u_f32, fact, CUDA.CuArray(b_f32))
138+
# Convert back to original precision
139+
y = Array(y_f32)
140+
T = eltype(cache.u)
141+
cache.u .= T.(y)
142+
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
143+
end
144+
145+
function LinearSolve.init_cacheval(alg::CUDAOffload32MixedLUFactorization, A, b, u, Pl, Pr,
146+
maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
147+
assumptions::OperatorAssumptions)
148+
# Pre-allocate with Float32 arrays
149+
A_f32 = Float32.(A)
150+
T = eltype(A_f32)
151+
noUnitT = typeof(zero(T))
152+
luT = LinearAlgebra.lutype(noUnitT)
153+
ipiv = CuVector{Int32}(undef, 0)
154+
info = zero(LinearAlgebra.BlasInt)
155+
return LU{luT}(CuMatrix{Float32}(undef, 0, 0), ipiv, info)
156+
end
157+
121158
end

ext/LinearSolveMetalExt.jl

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ module LinearSolveMetalExt
33
using Metal, LinearSolve
44
using LinearAlgebra, SciMLBase
55
using SciMLBase: AbstractSciMLOperator
6-
using LinearSolve: ArrayInterface, MKLLUFactorization, @get_cacheval, LinearCache, SciMLBase
6+
using LinearSolve: ArrayInterface, MKLLUFactorization, MetalOffload32MixedLUFactorization,
7+
@get_cacheval, LinearCache, SciMLBase, OperatorAssumptions, LinearVerbosity
78

89
default_alias_A(::MetalLUFactorization, ::Any, ::Any) = false
910
default_alias_b(::MetalLUFactorization, ::Any, ::Any) = false
@@ -28,4 +29,45 @@ function SciMLBase.solve!(cache::LinearCache, alg::MetalLUFactorization;
2829
SciMLBase.build_linear_solution(alg, y, nothing, cache)
2930
end
3031

32+
# Mixed precision Metal LU implementation
33+
default_alias_A(::MetalOffload32MixedLUFactorization, ::Any, ::Any) = false
34+
default_alias_b(::MetalOffload32MixedLUFactorization, ::Any, ::Any) = false
35+
36+
function LinearSolve.init_cacheval(alg::MetalOffload32MixedLUFactorization, A, b, u, Pl, Pr,
37+
maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
38+
assumptions::OperatorAssumptions)
39+
# Pre-allocate with Float32 arrays
40+
A_f32 = Float32.(convert(AbstractMatrix, A))
41+
ArrayInterface.lu_instance(A_f32)
42+
end
43+
44+
function SciMLBase.solve!(cache::LinearCache, alg::MetalOffload32MixedLUFactorization;
45+
kwargs...)
46+
A = cache.A
47+
A = convert(AbstractMatrix, A)
48+
if cache.isfresh
49+
cacheval = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)
50+
# Convert to Float32 for factorization
51+
A_f32 = Float32.(A)
52+
res = lu(MtlArray(A_f32))
53+
# Store factorization on CPU with converted types
54+
cache.cacheval = LU(Array(res.factors), Array{Int}(res.ipiv), res.info)
55+
cache.isfresh = false
56+
end
57+
58+
fact = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)
59+
# Convert b to Float32 for solving
60+
b_f32 = Float32.(cache.b)
61+
u_f32 = similar(b_f32)
62+
63+
# Create a temporary Float32 LU factorization for solving
64+
fact_f32 = LU(Float32.(fact.factors), fact.ipiv, fact.info)
65+
ldiv!(u_f32, fact_f32, b_f32)
66+
67+
# Convert back to original precision
68+
T = eltype(cache.u)
69+
cache.u .= T.(u_f32)
70+
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
71+
end
72+
3173
end

src/LinearSolve.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,13 +456,17 @@ export HYPREAlgorithm
456456
export CudaOffloadFactorization
457457
export CudaOffloadLUFactorization
458458
export CudaOffloadQRFactorization
459+
export CUDAOffload32MixedLUFactorization
459460
export AMDGPUOffloadLUFactorization, AMDGPUOffloadQRFactorization
460461
export MKLPardisoFactorize, MKLPardisoIterate
461462
export PanuaPardisoFactorize, PanuaPardisoIterate
462463
export PardisoJL
463464
export MKLLUFactorization
465+
export MKL32MixedLUFactorization
464466
export AppleAccelerateLUFactorization
467+
export AppleAccelerate32MixedLUFactorization
465468
export MetalLUFactorization
469+
export MetalOffload32MixedLUFactorization
466470

467471
export OperatorAssumptions, OperatorCondition
468472

src/appleaccelerate.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ to avoid allocations and does not require libblastrampoline.
1414
"""
1515
struct AppleAccelerateLUFactorization <: AbstractFactorization end
1616

17+
1718
@static if !Sys.isapple()
1819
__appleaccelerate_isavailable() = false
1920
else
@@ -284,3 +285,84 @@ function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorizatio
284285
SciMLBase.build_linear_solution(
285286
alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
286287
end
288+
289+
# Mixed precision AppleAccelerate implementation
290+
default_alias_A(::AppleAccelerate32MixedLUFactorization, ::Any, ::Any) = false
291+
default_alias_b(::AppleAccelerate32MixedLUFactorization, ::Any, ::Any) = false
292+
293+
const PREALLOCATED_APPLE32_LU = begin
294+
A = rand(Float32, 0, 0)
295+
luinst = ArrayInterface.lu_instance(A)
296+
LU(luinst.factors, similar(A, Cint, 0), luinst.info), Ref{Cint}()
297+
end
298+
299+
function LinearSolve.init_cacheval(alg::AppleAccelerate32MixedLUFactorization, A, b, u, Pl, Pr,
300+
maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
301+
assumptions::OperatorAssumptions)
302+
# Pre-allocate appropriate 32-bit arrays based on input type
303+
if eltype(A) <: Complex
304+
A_32 = rand(ComplexF32, 0, 0)
305+
else
306+
A_32 = rand(Float32, 0, 0)
307+
end
308+
luinst = ArrayInterface.lu_instance(A_32)
309+
LU(luinst.factors, similar(A_32, Cint, 0), luinst.info), Ref{Cint}()
310+
end
311+
312+
function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFactorization;
313+
kwargs...)
314+
__appleaccelerate_isavailable() ||
315+
error("Error, AppleAccelerate binary is missing but solve is being called. Report this issue")
316+
A = cache.A
317+
A = convert(AbstractMatrix, A)
318+
319+
# Check if we have complex numbers
320+
iscomplex = eltype(A) <: Complex
321+
322+
if cache.isfresh
323+
cacheval = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
324+
# Convert to appropriate 32-bit type for factorization
325+
if iscomplex
326+
A_f32 = ComplexF32.(A)
327+
else
328+
A_f32 = Float32.(A)
329+
end
330+
res = aa_getrf!(A_f32; ipiv = cacheval[1].ipiv, info = cacheval[2])
331+
fact = LU(res[1:3]...), res[4]
332+
cache.cacheval = fact
333+
334+
if !LinearAlgebra.issuccess(fact[1])
335+
return SciMLBase.build_linear_solution(
336+
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
337+
end
338+
cache.isfresh = false
339+
end
340+
341+
A_lu, info = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
342+
require_one_based_indexing(cache.u, cache.b)
343+
m, n = size(A_lu, 1), size(A_lu, 2)
344+
345+
# Convert b to appropriate 32-bit type for solving
346+
if iscomplex
347+
b_f32 = ComplexF32.(cache.b)
348+
else
349+
b_f32 = Float32.(cache.b)
350+
end
351+
352+
if m > n
353+
Bc = copy(b_f32)
354+
aa_getrs!('N', A_lu.factors, A_lu.ipiv, Bc; info)
355+
# Convert back to original precision
356+
T = eltype(cache.u)
357+
cache.u .= T.(Bc[1:n])
358+
else
359+
u_f32 = copy(b_f32)
360+
aa_getrs!('N', A_lu.factors, A_lu.ipiv, u_f32; info)
361+
# Convert back to original precision
362+
T = eltype(cache.u)
363+
cache.u .= T.(u_f32)
364+
end
365+
366+
SciMLBase.build_linear_solution(
367+
alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
368+
end

0 commit comments

Comments
 (0)