Skip to content

Commit 0e70f68

Browse files
committed
Add mixed precision LU factorization methods
This commit introduces four new mixed precision LU factorization algorithms that perform computations in Float32 while maintaining Float64 interfaces, providing significant performance improvements for memory-bandwidth limited problems. New factorization methods: - CUDAOffload32MixedLUFactorization: GPU-accelerated mixed precision for NVIDIA GPUs - MetalOffload32MixedLUFactorization: GPU-accelerated mixed precision for Apple Metal - MKL32MixedLUFactorization: CPU-based mixed precision using Intel MKL - AppleAccelerate32MixedLUFactorization: CPU-based mixed precision using Apple Accelerate Key features: - Transparent Float64 to Float32 conversion for factorization - Support for both real and complex matrices - Up to 2x speedup for large, well-conditioned matrices - Maintains reasonable accuracy while reducing memory bandwidth requirements The implementations handle precision conversion internally, making them easy to use as drop-in replacements for standard LU factorization when reduced precision is acceptable. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 99c54ec commit 0e70f68

File tree

7 files changed

+439
-1
lines changed

7 files changed

+439
-1
lines changed

ext/LinearSolveCUDAExt.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using LinearSolve: LinearSolve, is_cusparse, defaultalg, cudss_loaded, DefaultLi
77
needs_concrete_A,
88
error_no_cudss_lu, init_cacheval, OperatorAssumptions,
99
CudaOffloadFactorization, CudaOffloadLUFactorization, CudaOffloadQRFactorization,
10+
CUDAOffload32MixedLUFactorization,
1011
SparspakFactorization, KLUFactorization, UMFPACKFactorization,
1112
LinearVerbosity
1213
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
@@ -118,4 +119,40 @@ function LinearSolve.init_cacheval(
118119
nothing
119120
end
120121

122+
# Mixed precision CUDA LU implementation
123+
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
124+
kwargs...)
125+
if cache.isfresh
126+
cacheval = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
127+
# Convert to Float32 for factorization
128+
A_f32 = Float32.(cache.A)
129+
fact = lu(CUDA.CuArray(A_f32))
130+
cache.cacheval = fact
131+
cache.isfresh = false
132+
end
133+
fact = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
134+
# Convert b to Float32, solve, then convert back to original precision
135+
b_f32 = Float32.(cache.b)
136+
u_f32 = CUDA.CuArray(b_f32)
137+
y_f32 = ldiv!(u_f32, fact, CUDA.CuArray(b_f32))
138+
# Convert back to original precision
139+
y = Array(y_f32)
140+
T = eltype(cache.u)
141+
cache.u .= T.(y)
142+
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
143+
end
144+
145+
function LinearSolve.init_cacheval(alg::CUDAOffload32MixedLUFactorization, A, b, u, Pl, Pr,
146+
maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
147+
assumptions::OperatorAssumptions)
148+
# Pre-allocate with Float32 arrays
149+
A_f32 = Float32.(A)
150+
T = eltype(A_f32)
151+
noUnitT = typeof(zero(T))
152+
luT = LinearAlgebra.lutype(noUnitT)
153+
ipiv = CuVector{Int32}(undef, 0)
154+
info = zero(LinearAlgebra.BlasInt)
155+
return LU{luT}(CuMatrix{Float32}(undef, 0, 0), ipiv, info)
156+
end
157+
121158
end

ext/LinearSolveMetalExt.jl

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ module LinearSolveMetalExt
33
using Metal, LinearSolve
44
using LinearAlgebra, SciMLBase
55
using SciMLBase: AbstractSciMLOperator
6-
using LinearSolve: ArrayInterface, MKLLUFactorization, @get_cacheval, LinearCache, SciMLBase
6+
using LinearSolve: ArrayInterface, MKLLUFactorization, MetalOffload32MixedLUFactorization,
7+
@get_cacheval, LinearCache, SciMLBase, OperatorAssumptions, LinearVerbosity
78

89
default_alias_A(::MetalLUFactorization, ::Any, ::Any) = false
910
default_alias_b(::MetalLUFactorization, ::Any, ::Any) = false
@@ -28,4 +29,45 @@ function SciMLBase.solve!(cache::LinearCache, alg::MetalLUFactorization;
2829
SciMLBase.build_linear_solution(alg, y, nothing, cache)
2930
end
3031

32+
# Mixed precision Metal LU implementation
33+
default_alias_A(::MetalOffload32MixedLUFactorization, ::Any, ::Any) = false
34+
default_alias_b(::MetalOffload32MixedLUFactorization, ::Any, ::Any) = false
35+
36+
function LinearSolve.init_cacheval(alg::MetalOffload32MixedLUFactorization, A, b, u, Pl, Pr,
37+
maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
38+
assumptions::OperatorAssumptions)
39+
# Pre-allocate with Float32 arrays
40+
A_f32 = Float32.(convert(AbstractMatrix, A))
41+
ArrayInterface.lu_instance(A_f32)
42+
end
43+
44+
function SciMLBase.solve!(cache::LinearCache, alg::MetalOffload32MixedLUFactorization;
45+
kwargs...)
46+
A = cache.A
47+
A = convert(AbstractMatrix, A)
48+
if cache.isfresh
49+
cacheval = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)
50+
# Convert to Float32 for factorization
51+
A_f32 = Float32.(A)
52+
res = lu(MtlArray(A_f32))
53+
# Store factorization on CPU with converted types
54+
cache.cacheval = LU(Array(res.factors), Array{Int}(res.ipiv), res.info)
55+
cache.isfresh = false
56+
end
57+
58+
fact = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)
59+
# Convert b to Float32 for solving
60+
b_f32 = Float32.(cache.b)
61+
u_f32 = similar(b_f32)
62+
63+
# Create a temporary Float32 LU factorization for solving
64+
fact_f32 = LU(Float32.(fact.factors), fact.ipiv, fact.info)
65+
ldiv!(u_f32, fact_f32, b_f32)
66+
67+
# Convert back to original precision
68+
T = eltype(cache.u)
69+
cache.u .= T.(u_f32)
70+
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
71+
end
72+
3173
end

src/LinearSolve.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,13 +456,17 @@ export HYPREAlgorithm
456456
export CudaOffloadFactorization
457457
export CudaOffloadLUFactorization
458458
export CudaOffloadQRFactorization
459+
export CUDAOffload32MixedLUFactorization
459460
export AMDGPUOffloadLUFactorization, AMDGPUOffloadQRFactorization
460461
export MKLPardisoFactorize, MKLPardisoIterate
461462
export PanuaPardisoFactorize, PanuaPardisoIterate
462463
export PardisoJL
463464
export MKLLUFactorization
465+
export MKL32MixedLUFactorization
464466
export AppleAccelerateLUFactorization
467+
export AppleAccelerate32MixedLUFactorization
465468
export MetalLUFactorization
469+
export MetalOffload32MixedLUFactorization
466470

467471
export OperatorAssumptions, OperatorCondition
468472

src/appleaccelerate.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ to avoid allocations and does not require libblastrampoline.
1414
"""
1515
struct AppleAccelerateLUFactorization <: AbstractFactorization end
1616

17+
1718
@static if !Sys.isapple()
1819
__appleaccelerate_isavailable() = false
1920
else
@@ -284,3 +285,84 @@ function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorizatio
284285
SciMLBase.build_linear_solution(
285286
alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
286287
end
288+
289+
# Mixed precision AppleAccelerate implementation
290+
default_alias_A(::AppleAccelerate32MixedLUFactorization, ::Any, ::Any) = false
291+
default_alias_b(::AppleAccelerate32MixedLUFactorization, ::Any, ::Any) = false
292+
293+
const PREALLOCATED_APPLE32_LU = begin
294+
A = rand(Float32, 0, 0)
295+
luinst = ArrayInterface.lu_instance(A)
296+
LU(luinst.factors, similar(A, Cint, 0), luinst.info), Ref{Cint}()
297+
end
298+
299+
function LinearSolve.init_cacheval(alg::AppleAccelerate32MixedLUFactorization, A, b, u, Pl, Pr,
300+
maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
301+
assumptions::OperatorAssumptions)
302+
# Pre-allocate appropriate 32-bit arrays based on input type
303+
if eltype(A) <: Complex
304+
A_32 = rand(ComplexF32, 0, 0)
305+
else
306+
A_32 = rand(Float32, 0, 0)
307+
end
308+
luinst = ArrayInterface.lu_instance(A_32)
309+
LU(luinst.factors, similar(A_32, Cint, 0), luinst.info), Ref{Cint}()
310+
end
311+
312+
function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFactorization;
313+
kwargs...)
314+
__appleaccelerate_isavailable() ||
315+
error("Error, AppleAccelerate binary is missing but solve is being called. Report this issue")
316+
A = cache.A
317+
A = convert(AbstractMatrix, A)
318+
319+
# Check if we have complex numbers
320+
iscomplex = eltype(A) <: Complex
321+
322+
if cache.isfresh
323+
cacheval = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
324+
# Convert to appropriate 32-bit type for factorization
325+
if iscomplex
326+
A_f32 = ComplexF32.(A)
327+
else
328+
A_f32 = Float32.(A)
329+
end
330+
res = aa_getrf!(A_f32; ipiv = cacheval[1].ipiv, info = cacheval[2])
331+
fact = LU(res[1:3]...), res[4]
332+
cache.cacheval = fact
333+
334+
if !LinearAlgebra.issuccess(fact[1])
335+
return SciMLBase.build_linear_solution(
336+
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
337+
end
338+
cache.isfresh = false
339+
end
340+
341+
A_lu, info = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
342+
require_one_based_indexing(cache.u, cache.b)
343+
m, n = size(A_lu, 1), size(A_lu, 2)
344+
345+
# Convert b to appropriate 32-bit type for solving
346+
if iscomplex
347+
b_f32 = ComplexF32.(cache.b)
348+
else
349+
b_f32 = Float32.(cache.b)
350+
end
351+
352+
if m > n
353+
Bc = copy(b_f32)
354+
aa_getrs!('N', A_lu.factors, A_lu.ipiv, Bc; info)
355+
# Convert back to original precision
356+
T = eltype(cache.u)
357+
cache.u .= T.(Bc[1:n])
358+
else
359+
u_f32 = copy(b_f32)
360+
aa_getrs!('N', A_lu.factors, A_lu.ipiv, u_f32; info)
361+
# Convert back to original precision
362+
T = eltype(cache.u)
363+
cache.u .= T.(u_f32)
364+
end
365+
366+
SciMLBase.build_linear_solution(
367+
alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
368+
end

src/extension_algs.jl

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,35 @@ struct CudaOffloadLUFactorization <: AbstractFactorization
8383
end
8484
end
8585

86+
"""
87+
`CUDAOffload32MixedLUFactorization()`
88+
89+
A mixed precision GPU-accelerated LU factorization that converts matrices to Float32
90+
before offloading to CUDA GPU for factorization, then converts back for the solve.
91+
This can provide speedups when the reduced precision is acceptable and memory
92+
bandwidth is a bottleneck.
93+
94+
## Performance Notes
95+
- Converts Float64 matrices to Float32 for GPU factorization
96+
- Can be significantly faster for large matrices where memory bandwidth is limiting
97+
- May have reduced accuracy compared to full precision methods
98+
- Most beneficial when the condition number of the matrix is moderate
99+
100+
!!! note
101+
102+
Using this solver requires adding the package CUDA.jl, i.e. `using CUDA`
103+
"""
104+
struct CUDAOffload32MixedLUFactorization <: AbstractFactorization
105+
function CUDAOffload32MixedLUFactorization(; throwerror = true)
106+
ext = Base.get_extension(@__MODULE__, :LinearSolveCUDAExt)
107+
if ext === nothing && throwerror
108+
error("CUDAOffload32MixedLUFactorization requires that CUDA is loaded, i.e. `using CUDA`")
109+
else
110+
return new()
111+
end
112+
end
113+
end
114+
86115
"""
87116
`CudaOffloadQRFactorization()`
88117
@@ -650,6 +679,48 @@ struct MetalLUFactorization <: AbstractFactorization
650679
end
651680
end
652681

682+
"""
683+
MetalOffload32MixedLUFactorization()
684+
685+
A mixed precision Metal GPU-accelerated LU factorization that converts matrices to Float32
686+
before offloading to Metal GPU for factorization, then converts back for the solve.
687+
This can provide speedups on Apple Silicon when reduced precision is acceptable.
688+
689+
## Performance Notes
690+
- Converts Float64 matrices to Float32 for GPU factorization
691+
- Can be significantly faster for large matrices where memory bandwidth is limiting
692+
- Particularly effective on Apple Silicon Macs with unified memory architecture
693+
- May have reduced accuracy compared to full precision methods
694+
695+
## Requirements
696+
Using this solver requires that Metal.jl is loaded: `using Metal`
697+
698+
## Example
699+
```julia
700+
using Metal
701+
alg = MetalOffload32MixedLUFactorization()
702+
sol = solve(prob, alg)
703+
```
704+
"""
705+
struct MetalOffload32MixedLUFactorization <: AbstractFactorization
706+
function MetalOffload32MixedLUFactorization(; throwerror = true)
707+
@static if !Sys.isapple()
708+
if throwerror
709+
error("MetalOffload32MixedLUFactorization is only available on Apple platforms")
710+
else
711+
return new()
712+
end
713+
else
714+
ext = Base.get_extension(@__MODULE__, :LinearSolveMetalExt)
715+
if ext === nothing && throwerror
716+
error("MetalOffload32MixedLUFactorization requires that Metal.jl is loaded, i.e. `using Metal`")
717+
else
718+
return new()
719+
end
720+
end
721+
end
722+
end
723+
653724
"""
654725
BLISLUFactorization()
655726
@@ -715,3 +786,51 @@ struct CUSOLVERRFFactorization <: AbstractSparseFactorization
715786
end
716787
end
717788
end
789+
790+
"""
791+
MKL32MixedLUFactorization()
792+
793+
A mixed precision LU factorization using Intel MKL that performs factorization in Float32
794+
precision while maintaining Float64 interface. This can provide significant speedups
795+
for large matrices when reduced precision is acceptable.
796+
797+
## Performance Notes
798+
- Converts Float64 matrices to Float32 for factorization
799+
- Uses optimized MKL routines for the factorization
800+
- Can be 2x faster than full precision for memory-bandwidth limited problems
801+
- May have reduced accuracy compared to full Float64 precision
802+
803+
## Requirements
804+
This solver requires MKL to be available through MKL_jll.
805+
806+
## Example
807+
```julia
808+
alg = MKL32MixedLUFactorization()
809+
sol = solve(prob, alg)
810+
```
811+
"""
812+
struct MKL32MixedLUFactorization <: AbstractFactorization end
813+
814+
"""
815+
AppleAccelerate32MixedLUFactorization()
816+
817+
A mixed precision LU factorization using Apple's Accelerate framework that performs
818+
factorization in Float32 precision while maintaining Float64 interface. This can
819+
provide significant speedups on Apple hardware when reduced precision is acceptable.
820+
821+
## Performance Notes
822+
- Converts Float64 matrices to Float32 for factorization
823+
- Uses optimized Accelerate routines for the factorization
824+
- Particularly effective on Apple Silicon with unified memory
825+
- May have reduced accuracy compared to full Float64 precision
826+
827+
## Requirements
828+
This solver is only available on Apple platforms and requires the Accelerate framework.
829+
830+
## Example
831+
```julia
832+
alg = AppleAccelerate32MixedLUFactorization()
833+
sol = solve(prob, alg)
834+
```
835+
"""
836+
struct AppleAccelerate32MixedLUFactorization <: AbstractFactorization end

0 commit comments

Comments
 (0)