Commit 7b29b4b
Add 32-bit mixed precision solvers for OpenBLAS and RecursiveFactorization
Adds two new mixed precision LU factorization algorithms that perform the factorization in Float32 precision while maintaining a Float64 interface for improved performance:

- OpenBLAS32MixedLUFactorization: mixed precision solver using OpenBLAS
- RF32MixedLUFactorization: mixed precision solver using RecursiveFactorization.jl

These solvers follow the same pattern as the existing MKL32MixedLUFactorization and AppleAccelerate32MixedLUFactorization implementations, providing:

- ~2x speedup for memory-bandwidth limited problems
- Support for both real and complex matrices
- Automatic precision conversion and management
- Comprehensive test coverage

RF32MixedLUFactorization also supports pivoting options for trading stability against performance.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent b77b167 commit 7b29b4b
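
For context, the docstrings added in this commit show the intended entry points. A minimal usage sketch based on those docstrings (assuming a standard `LinearProblem` setup; not part of the commit itself):

```julia
using LinearSolve, RecursiveFactorization

A = rand(1000, 1000)
b = rand(1000)
prob = LinearProblem(A, b)

# OpenBLAS-backed mixed precision LU (requires the OpenBLAS_jll binary)
sol1 = solve(prob, OpenBLAS32MixedLUFactorization())

# RecursiveFactorization-backed mixed precision LU; pivoting is on by default
sol2 = solve(prob, RF32MixedLUFactorization())
```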

File tree

5 files changed: +308 -11 lines changed

ext/LinearSolveRecursiveFactorizationExt.jl

Lines changed: 81 additions & 1 deletion
```diff
@@ -1,7 +1,8 @@
 module LinearSolveRecursiveFactorizationExt
 
 using LinearSolve: LinearSolve, userecursivefactorization, LinearCache, @get_cacheval,
-                   RFLUFactorization
+                   RFLUFactorization, RF32MixedLUFactorization, default_alias_A,
+                   default_alias_b
 using LinearSolve.LinearAlgebra, LinearSolve.ArrayInterface, RecursiveFactorization
 using SciMLBase: SciMLBase, ReturnCode
 
@@ -30,4 +31,83 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::RFLUFactorization
     SciMLBase.build_linear_solution(alg, y, nothing, cache; retcode = ReturnCode.Success)
 end
 
+# Mixed precision RecursiveFactorization implementation
+LinearSolve.default_alias_A(::RF32MixedLUFactorization, ::Any, ::Any) = false
+LinearSolve.default_alias_b(::RF32MixedLUFactorization, ::Any, ::Any) = false
+
+const PREALLOCATED_RF32_LU = begin
+    A = rand(Float32, 0, 0)
+    luinst = ArrayInterface.lu_instance(A)
+    (luinst, Vector{LinearAlgebra.BlasInt}(undef, 0))
+end
+
+function LinearSolve.init_cacheval(alg::RF32MixedLUFactorization{P, T}, A, b, u, Pl, Pr,
+        maxiters::Int, abstol, reltol, verbose::LinearSolve.LinearVerbosity,
+        assumptions::LinearSolve.OperatorAssumptions) where {P, T}
+    # Pre-allocate appropriate 32-bit arrays based on input type
+    if eltype(A) <: Complex
+        A_32 = rand(ComplexF32, 0, 0)
+    else
+        A_32 = rand(Float32, 0, 0)
+    end
+    luinst = ArrayInterface.lu_instance(A_32)
+    (luinst, Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...)))
+end
+
+function SciMLBase.solve!(
+        cache::LinearSolve.LinearCache, alg::RF32MixedLUFactorization{P, T};
+        kwargs...) where {P, T}
+    A = cache.A
+    A = convert(AbstractMatrix, A)
+
+    # Check if we have complex numbers
+    iscomplex = eltype(A) <: Complex
+
+    if cache.isfresh
+        fact, ipiv = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
+
+        # Convert to appropriate 32-bit type for factorization
+        if iscomplex
+            A_f32 = ComplexF32.(A)
+        else
+            A_f32 = Float32.(A)
+        end
+
+        # Ensure ipiv is the right size
+        if length(ipiv) != min(size(A_f32)...)
+            ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A_f32)...))
+        end
+
+        fact = RecursiveFactorization.lu!(A_f32, ipiv, Val(P), Val(T), check = false)
+        cache.cacheval = (fact, ipiv)
+
+        if !LinearAlgebra.issuccess(fact)
+            return SciMLBase.build_linear_solution(
+                alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
+        end
+
+        cache.isfresh = false
+    end
+
+    fact, ipiv = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
+
+    # Convert b to appropriate 32-bit type for solving
+    if iscomplex
+        b_f32 = ComplexF32.(cache.b)
+    else
+        b_f32 = Float32.(cache.b)
+    end
+
+    # Solve in 32-bit precision
+    u_f32 = similar(b_f32)
+    ldiv!(u_f32, fact, b_f32)
+
+    # Convert back to original precision
+    T_orig = eltype(cache.u)
+    cache.u .= T_orig.(u_f32)
+
+    SciMLBase.build_linear_solution(
+        alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
+end
+
 end
```
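
The core of the new `solve!` above is a down-convert / factorize / solve / up-convert round trip. A standalone sketch of that pattern, substituting the stdlib `lu!` for `RecursiveFactorization.lu!` (the function name here is illustrative, not part of the commit):

```julia
using LinearAlgebra

# Factor and solve in Float32, return the result in Float64.
function mixed_precision_solve(A::Matrix{Float64}, b::Vector{Float64})
    A_f32 = Float32.(A)        # one-time O(n^2) conversion
    b_f32 = Float32.(b)
    F = lu!(A_f32)             # the O(n^3) factorization runs at Float32 speed
    u_f32 = F \ b_f32          # triangular solves, still in Float32
    return Float64.(u_f32)     # upcast back to the interface precision
end

A = rand(100, 100) + 10I       # keep the example well conditioned
b = rand(100)
u = mixed_precision_solve(A, b)
norm(A * u - b) / norm(b)      # residual near eps(Float32), not eps(Float64)
```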

src/LinearSolve.jl

Lines changed: 2 additions & 0 deletions
```diff
@@ -472,9 +472,11 @@ export PanuaPardisoFactorize, PanuaPardisoIterate
 export PardisoJL
 export MKLLUFactorization
 export OpenBLASLUFactorization
+export OpenBLAS32MixedLUFactorization
 export MKL32MixedLUFactorization
 export AppleAccelerateLUFactorization
 export AppleAccelerate32MixedLUFactorization
+export RF32MixedLUFactorization
 export MetalLUFactorization
 export MetalOffload32MixedLUFactorization
 
```

src/extension_algs.jl

Lines changed: 74 additions & 0 deletions
````diff
@@ -834,3 +834,77 @@ sol = solve(prob, alg)
 ```
 """
 struct AppleAccelerate32MixedLUFactorization <: AbstractFactorization end
+
+"""
+    OpenBLAS32MixedLUFactorization()
+
+A mixed precision LU factorization using OpenBLAS that performs the factorization in
+Float32 precision while maintaining a Float64 interface. This can provide significant
+speedups for large matrices when reduced precision is acceptable.
+
+## Performance Notes
+- Converts Float64 matrices to Float32 for factorization
+- Uses optimized OpenBLAS routines for the factorization
+- Can be 2x faster than full precision for memory-bandwidth limited problems
+- May have reduced accuracy compared to full Float64 precision
+
+## Requirements
+This solver requires OpenBLAS to be available through OpenBLAS_jll.
+
+## Example
+```julia
+alg = OpenBLAS32MixedLUFactorization()
+sol = solve(prob, alg)
+```
+"""
+struct OpenBLAS32MixedLUFactorization <: AbstractFactorization end
+
+"""
+    RF32MixedLUFactorization{P, T}(; pivot = Val(true), thread = Val(true))
+
+A mixed precision LU factorization using RecursiveFactorization.jl that performs the
+factorization in Float32 precision while maintaining a Float64 interface. This combines
+the speed benefits of RecursiveFactorization.jl with reduced precision computation
+for additional performance gains.
+
+## Type Parameters
+- `P`: Pivoting strategy as `Val{Bool}`. `Val{true}` enables partial pivoting for stability.
+- `T`: Threading strategy as `Val{Bool}`. `Val{true}` enables multi-threading for performance.
+
+## Constructor Arguments
+- `pivot = Val(true)`: Enable partial pivoting. Set to `Val(false)` to disable for speed
+  at the cost of numerical stability.
+- `thread = Val(true)`: Enable multi-threading. Set to `Val(false)` for single-threaded
+  execution.
+
+## Performance Notes
+- Converts Float64 matrices to Float32 for factorization
+- Leverages RecursiveFactorization.jl's optimized blocking strategies
+- Can provide significant speedups for small to medium matrices (< 500×500)
+- May have reduced accuracy compared to full Float64 precision
+
+## Requirements
+Using this solver requires that RecursiveFactorization.jl is loaded: `using RecursiveFactorization`
+
+## Example
+```julia
+using RecursiveFactorization
+# Fast mixed precision with pivoting
+alg1 = RF32MixedLUFactorization()
+# Fastest mixed precision (no pivoting), less stable
+alg2 = RF32MixedLUFactorization(pivot = Val(false))
+```
+"""
+struct RF32MixedLUFactorization{P, T} <: AbstractDenseFactorization
+    function RF32MixedLUFactorization(::Val{P}, ::Val{T}; throwerror = true) where {P, T}
+        if !userecursivefactorization(nothing)
+            throwerror &&
+                error("RF32MixedLUFactorization requires that RecursiveFactorization.jl is loaded, i.e. `using RecursiveFactorization`")
+        end
+        new{P, T}()
+    end
+end
+
+function RF32MixedLUFactorization(; pivot = Val(true), thread = Val(true), throwerror = true)
+    RF32MixedLUFactorization(pivot, thread; throwerror)
+end
````
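
Because pivoting and threading are encoded as the type parameters `P` and `T`, each flag combination is its own concrete type and gets a specialized method. A short sketch of how the keyword constructor maps onto those parameters (assuming RecursiveFactorization.jl is loaded):

```julia
using LinearSolve, RecursiveFactorization

alg_default = RF32MixedLUFactorization()                    # pivoted, threaded
alg_nopivot = RF32MixedLUFactorization(pivot = Val(false))  # faster, less stable
alg_serial  = RF32MixedLUFactorization(thread = Val(false)) # single-threaded

typeof(alg_nopivot)  # RF32MixedLUFactorization{false, true}
```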

src/openblas.jl

Lines changed: 88 additions & 9 deletions
```diff
@@ -44,7 +44,7 @@ function openblas_getrf!(A::AbstractMatrix{<:ComplexF64};
         ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
         info = Ref{BlasInt}(),
         check = false)
-    __openblas_isavailable() || 
+    __openblas_isavailable() ||
         error("Error, OpenBLAS binary is missing but solve is being called. Report this issue")
     require_one_based_indexing(A)
     check && chkfinite(A)
@@ -66,7 +66,7 @@ function openblas_getrf!(A::AbstractMatrix{<:ComplexF32};
         ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
         info = Ref{BlasInt}(),
         check = false)
-    __openblas_isavailable() || 
+    __openblas_isavailable() ||
         error("Error, OpenBLAS binary is missing but solve is being called. Report this issue")
     require_one_based_indexing(A)
     check && chkfinite(A)
@@ -88,7 +88,7 @@ function openblas_getrf!(A::AbstractMatrix{<:Float64};
         ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
         info = Ref{BlasInt}(),
         check = false)
-    __openblas_isavailable() || 
+    __openblas_isavailable() ||
         error("Error, OpenBLAS binary is missing but solve is being called. Report this issue")
     require_one_based_indexing(A)
     check && chkfinite(A)
@@ -110,7 +110,7 @@ function openblas_getrf!(A::AbstractMatrix{<:Float32};
         ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
         info = Ref{BlasInt}(),
         check = false)
-    __openblas_isavailable() || 
+    __openblas_isavailable() ||
         error("Error, OpenBLAS binary is missing but solve is being called. Report this issue")
     require_one_based_indexing(A)
     check && chkfinite(A)
@@ -133,7 +133,7 @@ function openblas_getrs!(trans::AbstractChar,
         ipiv::AbstractVector{BlasInt},
         B::AbstractVecOrMat{<:ComplexF64};
         info = Ref{BlasInt}())
-    __openblas_isavailable() || 
+    __openblas_isavailable() ||
         error("Error, OpenBLAS binary is missing but solve is being called. Report this issue")
     require_one_based_indexing(A, ipiv, B)
     LinearAlgebra.LAPACK.chktrans(trans)
@@ -160,7 +160,7 @@ function openblas_getrs!(trans::AbstractChar,
         ipiv::AbstractVector{BlasInt},
         B::AbstractVecOrMat{<:ComplexF32};
         info = Ref{BlasInt}())
-    __openblas_isavailable() || 
+    __openblas_isavailable() ||
         error("Error, OpenBLAS binary is missing but solve is being called. Report this issue")
     require_one_based_indexing(A, ipiv, B)
     LinearAlgebra.LAPACK.chktrans(trans)
@@ -187,7 +187,7 @@ function openblas_getrs!(trans::AbstractChar,
         ipiv::AbstractVector{BlasInt},
         B::AbstractVecOrMat{<:Float64};
         info = Ref{BlasInt}())
-    __openblas_isavailable() || 
+    __openblas_isavailable() ||
         error("Error, OpenBLAS binary is missing but solve is being called. Report this issue")
     require_one_based_indexing(A, ipiv, B)
     LinearAlgebra.LAPACK.chktrans(trans)
@@ -214,7 +214,7 @@ function openblas_getrs!(trans::AbstractChar,
         ipiv::AbstractVector{BlasInt},
         B::AbstractVecOrMat{<:Float32};
         info = Ref{BlasInt}())
-    __openblas_isavailable() || 
+    __openblas_isavailable() ||
         error("Error, OpenBLAS binary is missing but solve is being called. Report this issue")
     require_one_based_indexing(A, ipiv, B)
     LinearAlgebra.LAPACK.chktrans(trans)
@@ -260,7 +260,7 @@ end
 
 function SciMLBase.solve!(cache::LinearCache, alg::OpenBLASLUFactorization;
         kwargs...)
-    __openblas_isavailable() || 
+    __openblas_isavailable() ||
         error("Error, OpenBLAS binary is missing but solve is being called. Report this issue")
     A = cache.A
     A = convert(AbstractMatrix, A)
@@ -292,3 +292,82 @@ function SciMLBase.solve!(cache::LinearCache, alg::OpenBLASLUFactorization;
     SciMLBase.build_linear_solution(
         alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
 end
+
+# Mixed precision OpenBLAS implementation
+default_alias_A(::OpenBLAS32MixedLUFactorization, ::Any, ::Any) = false
+default_alias_b(::OpenBLAS32MixedLUFactorization, ::Any, ::Any) = false
+
+const PREALLOCATED_OPENBLAS32_LU = begin
+    A = rand(Float32, 0, 0)
+    luinst = ArrayInterface.lu_instance(A), Ref{BlasInt}()
+end
+
+function LinearSolve.init_cacheval(alg::OpenBLAS32MixedLUFactorization, A, b, u, Pl, Pr,
+        maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
+        assumptions::OperatorAssumptions)
+    # Pre-allocate appropriate 32-bit arrays based on input type
+    if eltype(A) <: Complex
+        A_32 = rand(ComplexF32, 0, 0)
+    else
+        A_32 = rand(Float32, 0, 0)
+    end
+    ArrayInterface.lu_instance(A_32), Ref{BlasInt}()
+end
+
+function SciMLBase.solve!(cache::LinearCache, alg::OpenBLAS32MixedLUFactorization;
+        kwargs...)
+    __openblas_isavailable() ||
+        error("Error, OpenBLAS binary is missing but solve is being called. Report this issue")
+    A = cache.A
+    A = convert(AbstractMatrix, A)
+
+    # Check if we have complex numbers
+    iscomplex = eltype(A) <: Complex
+
+    if cache.isfresh
+        cacheval = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization)
+        # Convert to appropriate 32-bit type for factorization
+        if iscomplex
+            A_f32 = ComplexF32.(A)
+        else
+            A_f32 = Float32.(A)
+        end
+        res = openblas_getrf!(A_f32; ipiv = cacheval[1].ipiv, info = cacheval[2])
+        fact = LU(res[1:3]...), res[4]
+        cache.cacheval = fact
+
+        if !LinearAlgebra.issuccess(fact[1])
+            return SciMLBase.build_linear_solution(
+                alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
+        end
+        cache.isfresh = false
+    end
+
+    A_lu, info = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization)
+    require_one_based_indexing(cache.u, cache.b)
+    m, n = size(A_lu, 1), size(A_lu, 2)
+
+    # Convert b to appropriate 32-bit type for solving
+    if iscomplex
+        b_f32 = ComplexF32.(cache.b)
+    else
+        b_f32 = Float32.(cache.b)
+    end
+
+    if m > n
+        Bc = copy(b_f32)
+        openblas_getrs!('N', A_lu.factors, A_lu.ipiv, Bc; info)
+        # Convert back to original precision
+        T = eltype(cache.u)
+        cache.u .= T.(Bc[1:n])
+    else
+        u_f32 = copy(b_f32)
+        openblas_getrs!('N', A_lu.factors, A_lu.ipiv, u_f32; info)
+        # Convert back to original precision
+        T = eltype(cache.u)
+        cache.u .= T.(u_f32)
+    end
+
+    SciMLBase.build_linear_solution(
+        alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
+end
```
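
Both docstrings carry the caveat that accuracy may be reduced relative to full Float64. The usual rule of thumb is a forward error on the order of `cond(A) * eps(Float32)` rather than `cond(A) * eps(Float64)`. A quick illustration on matrices with a controlled condition number (illustrative only, not part of the commit):

```julia
using LinearAlgebra

n = 200
Q1 = Matrix(qr(randn(n, n)).Q)
Q2 = Matrix(qr(randn(n, n)).Q)
for condA in (1e2, 1e6)
    # Singular values span [1/condA, 1], so cond(A) == condA
    A = Q1 * Diagonal(range(1.0, 1 / condA; length = n)) * Q2
    x = ones(n)
    b = A * x
    x32 = Float64.(Float32.(A) \ Float32.(b))   # mixed precision path
    @show condA, norm(x32 - x) / norm(x)        # error grows with condA * eps(Float32)
end
```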
