Commit a07ee0b

Authored by ChrisRackauckas-Claude
Add 32-bit mixed precision solvers for OpenBLAS and RecursiveFactorization (#753)
* Add 32-bit mixed precision solvers for OpenBLAS and RecursiveFactorization

  Adds two new mixed precision LU factorization algorithms that perform the factorization in Float32 precision while maintaining a Float64 interface, for improved performance:

  - OpenBLAS32MixedLUFactorization: mixed precision solver using OpenBLAS
  - RF32MixedLUFactorization: mixed precision solver using RecursiveFactorization.jl

  These solvers follow the same pattern as the existing MKL32MixedLUFactorization and AppleAccelerate32MixedLUFactorization implementations, providing:

  - ~2x speedup for memory-bandwidth-limited problems
  - Support for both real and complex matrices
  - Automatic precision conversion and management
  - Comprehensive test coverage

  The RF32MixedLUFactorization also supports pivoting options for trading stability against performance.

  🤖 Generated with [Claude Code](https://claude.ai/code)

* Fix resolve.jl tests for mixed precision solvers

  - Add a higher tolerance for mixed precision algorithms (atol=1e-5, rtol=1e-5)
  - Skip tests for algorithms that require unavailable packages
  - Add proper checks for RF32MixedLUFactorization and OpenBLAS32MixedLUFactorization

  The mixed precision algorithms naturally have lower accuracy than full precision, so they need relaxed tolerances in the tests.

* Add RecursiveFactorization to Project.toml for tests

* Apply formatting and fix additional test compatibility

  - Format code with JuliaFormatter SciMLStyle
  - Update resolve.jl tests to properly handle mixed precision algorithms
  - Add appropriate tolerance checks for Float32 precision solvers

* Move RecursiveFactorization back to weakdeps

  RecursiveFactorization should remain a weak dependency since it is optional and loaded via an extension.

* Increase tolerance for mixed precision tests in resolve.jl

  Mixed precision algorithms need a higher tolerance due to reduced-precision arithmetic. Increased from atol=1e-5, rtol=1e-5 to atol=1e-4, rtol=1e-4.

* Fix mixed precision detection in resolve.jl tests

  Use string matching to detect mixed precision algorithms instead of symbol comparison. This ensures the tolerance branch is properly taken for algorithms like RF32MixedLUFactorization.

* Fix RF32MixedLUFactorization segfault issue

  - Simplified cache initialization to store only the LU factorization object
  - RecursiveFactorization.lu! returns an LU object that contains its own pivot vector
  - Fixed improper pivot vector handling that was causing segfaults

* Delete test/Project.toml

* Match RF32MixedLUFactorization pivoting with RFLUFactorization

  - Store the (fact, ipiv) tuple in the cache exactly like RFLUFactorization
  - Pass ipiv to RecursiveFactorization.lu! and store both fact and ipiv
  - Retrieve the factorization using the @get_cacheval()[1] pattern
  - This ensures consistent behavior between the two implementations

* Fix rebase

* Don't test no-pivot RFLU

---------

Co-authored-by: Claude <[email protected]>
Co-authored-by: ChrisRackauckas <[email protected]>
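The core idea behind these mixed precision solvers can be sketched outside Julia. The following NumPy snippet is illustrative only, not the package's code (`solve_mixed_precision` is a hypothetical helper): the expensive O(n³) work is done in float32 while the caller keeps a float64 interface, which is also why the tests above relax their tolerances.

```python
import numpy as np

def solve_mixed_precision(A, b):
    # Hypothetical helper sketching the mixed precision idea:
    # do the O(n^3) work in float32, keep a float64 interface.
    A32 = A.astype(np.float32)        # demote once (halves memory traffic)
    b32 = b.astype(np.float32)
    u32 = np.linalg.solve(A32, b32)   # factor + solve entirely in 32-bit
    return u32.astype(A.dtype)        # promote result back to float64

rng = np.random.default_rng(0)
n = 200
A = rng.standard_normal((n, n)) + n * np.eye(n)  # well-conditioned test matrix
b = rng.standard_normal(n)

u = solve_mixed_precision(A, b)
# Accuracy is bounded by float32 epsilon (~6e-8) times cond(A), which is
# consistent with the relaxed test tolerances atol=1e-4, rtol=1e-4.
print("relative residual:", np.linalg.norm(A @ u - b) / np.linalg.norm(b))
```

For a well-conditioned matrix the residual sits near float32 round-off, orders of magnitude below the 1e-4 test tolerance; for ill-conditioned systems the accuracy loss grows with cond(A).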
1 parent a2af566 commit a07ee0b

File tree

8 files changed: +457 −104 lines changed

ext/LinearSolveRecursiveFactorizationExt.jl

Lines changed: 83 additions & 1 deletion
```diff
@@ -1,7 +1,8 @@
 module LinearSolveRecursiveFactorizationExt

 using LinearSolve: LinearSolve, userecursivefactorization, LinearCache, @get_cacheval,
-                   RFLUFactorization
+                   RFLUFactorization, RF32MixedLUFactorization, default_alias_A,
+                   default_alias_b
 using LinearSolve.LinearAlgebra, LinearSolve.ArrayInterface, RecursiveFactorization
 using SciMLBase: SciMLBase, ReturnCode
```
```diff
@@ -30,4 +31,85 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::RFLUFactorization
     SciMLBase.build_linear_solution(alg, y, nothing, cache; retcode = ReturnCode.Success)
 end

+# Mixed precision RecursiveFactorization implementation
+LinearSolve.default_alias_A(::RF32MixedLUFactorization, ::Any, ::Any) = false
+LinearSolve.default_alias_b(::RF32MixedLUFactorization, ::Any, ::Any) = false
+
+const PREALLOCATED_RF32_LU = begin
+    A = rand(Float32, 0, 0)
+    luinst = ArrayInterface.lu_instance(A)
+    (luinst, Vector{LinearAlgebra.BlasInt}(undef, 0))
+end
+
+function LinearSolve.init_cacheval(alg::RF32MixedLUFactorization{P, T}, A, b, u, Pl, Pr,
+        maxiters::Int, abstol, reltol, verbose::Bool,
+        assumptions::LinearSolve.OperatorAssumptions) where {P, T}
+    # Pre-allocate appropriate 32-bit arrays based on input type
+    if eltype(A) <: Complex
+        A_32 = rand(ComplexF32, 0, 0)
+    else
+        A_32 = rand(Float32, 0, 0)
+    end
+    luinst = ArrayInterface.lu_instance(A_32)
+    ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
+    (luinst, ipiv)
+end
+
+function SciMLBase.solve!(
+        cache::LinearSolve.LinearCache, alg::RF32MixedLUFactorization{P, T};
+        kwargs...) where {P, T}
+    A = cache.A
+    A = convert(AbstractMatrix, A)
+
+    # Check if we have complex numbers
+    iscomplex = eltype(A) <: Complex
+
+    fact, ipiv = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
+    if cache.isfresh
+        # Convert to appropriate 32-bit type for factorization
+        if iscomplex
+            A_f32 = ComplexF32.(A)
+        else
+            A_f32 = Float32.(A)
+        end
+
+        # Ensure ipiv is the right size
+        if length(ipiv) != min(size(A_f32)...)
+            ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A_f32)...))
+        end
+
+        fact = RecursiveFactorization.lu!(A_f32, ipiv, Val(P), Val(T), check = false)
+        cache.cacheval = (fact, ipiv)
+
+        if !LinearAlgebra.issuccess(fact)
+            return SciMLBase.build_linear_solution(
+                alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
+        end
+
+        cache.isfresh = false
+    end
+
+    # Get the factorization from the cache
+    fact_cached = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)[1]
+
+    # Convert b to appropriate 32-bit type for solving
+    if iscomplex
+        b_f32 = ComplexF32.(cache.b)
+        u_f32 = similar(b_f32)
+    else
+        b_f32 = Float32.(cache.b)
+        u_f32 = similar(b_f32)
+    end
+
+    # Solve in 32-bit precision
+    ldiv!(u_f32, fact_cached, b_f32)
+
+    # Convert back to original precision
+    T_orig = eltype(cache.u)
+    cache.u .= T_orig.(u_f32)
+
+    SciMLBase.build_linear_solution(
+        alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
+end
+
 end
```
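The `solve!` method above factors only when `cache.isfresh` and then reuses the cached 32-bit factors for subsequent solves. That pattern can be sketched self-contained in NumPy (a sketch under stated assumptions: the hand-rolled partial-pivot LU stands in for `RecursiveFactorization.lu!`, and `lu_factor32`/`lu_solve32` are invented names, not the package's API):

```python
import numpy as np

def lu_factor32(A):
    # Partial-pivot LU computed entirely in float32; stands in for the
    # RecursiveFactorization.lu! call (hypothetical helper, not package code).
    LU = A.astype(np.float32)                       # demoted working copy
    n = LU.shape[0]
    piv = np.arange(n)
    for k in range(n - 1):
        p = k + int(np.argmax(np.abs(LU[k:, k])))   # pivot row
        if p != k:
            LU[[k, p]] = LU[[p, k]]
            piv[[k, p]] = piv[[p, k]]
        LU[k + 1:, k] /= LU[k, k]                   # multipliers (column of L)
        LU[k + 1:, k + 1:] -= np.outer(LU[k + 1:, k], LU[k, k + 1:])
    return LU, piv

def lu_solve32(LU, piv, b, out_dtype=np.float64):
    # Reuse the cached 32-bit factors; promote the answer like solve! does.
    x = b.astype(np.float32)[piv]                   # apply row permutation
    n = LU.shape[0]
    for i in range(1, n):                           # forward solve, unit-lower L
        x[i] -= LU[i, :i] @ x[:i]
    for i in range(n - 1, -1, -1):                  # back solve with U
        x[i] = (x[i] - LU[i, i + 1:] @ x[i + 1:]) / LU[i, i]
    return x.astype(out_dtype)

rng = np.random.default_rng(1)
n = 100
A = rng.standard_normal((n, n)) + n * np.eye(n)
LU, piv = lu_factor32(A)              # the "isfresh" work, done once
for _ in range(2):                    # later solves reuse the cached factors
    b = rng.standard_normal(n)
    x = lu_solve32(LU, piv, b)
```

The payoff mirrors the Julia cache: the O(n³) factorization happens once, and each additional right-hand side costs only two O(n²) triangular solves in 32-bit.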

src/LinearSolve.jl

Lines changed: 74 additions & 56 deletions
````diff
@@ -66,14 +66,13 @@ else
 const useopenblas = false
 end

-
 @reexport using SciMLBase

 """
     SciMLLinearSolveAlgorithm <: SciMLBase.AbstractLinearAlgorithm

 The root abstract type for all linear solver algorithms in LinearSolve.jl.
-All concrete linear solver implementations should inherit from one of the
+All concrete linear solver implementations should inherit from one of the
 specialized subtypes rather than directly from this type.

 This type integrates with the SciMLBase ecosystem, providing a consistent
@@ -91,39 +90,44 @@ matrices (e.g., `A = LU`, `A = QR`, `A = LDL'`) and then solve the system
 using forward/backward substitution.

 ## Characteristics
-- Requires concrete matrix representation (`needs_concrete_A() = true`)
-- Typically efficient for multiple solves with the same matrix
-- Generally provides high accuracy for well-conditioned problems
-- Memory requirements depend on the specific factorization type
+
+  - Requires concrete matrix representation (`needs_concrete_A() = true`)
+  - Typically efficient for multiple solves with the same matrix
+  - Generally provides high accuracy for well-conditioned problems
+  - Memory requirements depend on the specific factorization type

 ## Subtypes
-- `AbstractDenseFactorization`: For dense matrix factorizations
-- `AbstractSparseFactorization`: For sparse matrix factorizations
+
+  - `AbstractDenseFactorization`: For dense matrix factorizations
+  - `AbstractSparseFactorization`: For sparse matrix factorizations

 ## Examples of concrete subtypes
-- `LUFactorization`, `QRFactorization`, `CholeskyFactorization`
-- `UMFPACKFactorization`, `KLUFactorization`
+
+  - `LUFactorization`, `QRFactorization`, `CholeskyFactorization`
+  - `UMFPACKFactorization`, `KLUFactorization`
 """
 abstract type AbstractFactorization <: SciMLLinearSolveAlgorithm end

 """
     AbstractSparseFactorization <: AbstractFactorization

 Abstract type for factorization-based linear solvers optimized for sparse matrices.
-These algorithms take advantage of sparsity patterns to reduce memory usage and
+These algorithms take advantage of sparsity patterns to reduce memory usage and
 computational cost compared to dense factorizations.

-## Characteristics
-- Optimized for matrices with many zero entries
-- Often use specialized pivoting strategies to preserve sparsity
-- May reorder rows/columns to minimize fill-in during factorization
-- Typically more memory-efficient than dense methods for sparse problems
+## Characteristics
+
+  - Optimized for matrices with many zero entries
+  - Often use specialized pivoting strategies to preserve sparsity
+  - May reorder rows/columns to minimize fill-in during factorization
+  - Typically more memory-efficient than dense methods for sparse problems

 ## Examples of concrete subtypes
-- `UMFPACKFactorization`: General sparse LU with partial pivoting
-- `KLUFactorization`: Sparse LU optimized for circuit simulation
-- `CHOLMODFactorization`: Sparse Cholesky for positive definite systems
-- `SparspakFactorization`: Envelope/profile method for sparse systems
+
+  - `UMFPACKFactorization`: General sparse LU with partial pivoting
+  - `KLUFactorization`: Sparse LU optimized for circuit simulation
+  - `CHOLMODFactorization`: Sparse Cholesky for positive definite systems
+  - `SparspakFactorization`: Envelope/profile method for sparse systems
 """
 abstract type AbstractSparseFactorization <: AbstractFactorization end

@@ -135,16 +139,18 @@ These algorithms assume the matrix has no particular sparsity structure and use
 dense linear algebra routines (typically from BLAS/LAPACK) for optimal performance.

 ## Characteristics
-- Optimized for matrices with few zeros or no sparsity structure
-- Leverage highly optimized BLAS/LAPACK routines when available
-- Generally provide excellent performance for moderately-sized dense problems
-- Memory usage scales as O(n²) with matrix size
-
-## Examples of concrete subtypes
-- `LUFactorization`: Dense LU with partial pivoting (via LAPACK)
-- `QRFactorization`: Dense QR factorization for overdetermined systems
-- `CholeskyFactorization`: Dense Cholesky for symmetric positive definite matrices
-- `BunchKaufmanFactorization`: For symmetric indefinite matrices
+
+  - Optimized for matrices with few zeros or no sparsity structure
+  - Leverage highly optimized BLAS/LAPACK routines when available
+  - Generally provide excellent performance for moderately-sized dense problems
+  - Memory usage scales as O(n²) with matrix size
+
+## Examples of concrete subtypes
+
+  - `LUFactorization`: Dense LU with partial pivoting (via LAPACK)
+  - `QRFactorization`: Dense QR factorization for overdetermined systems
+  - `CholeskyFactorization`: Dense Cholesky for symmetric positive definite matrices
+  - `BunchKaufmanFactorization`: For symmetric indefinite matrices
 """
 abstract type AbstractDenseFactorization <: AbstractFactorization end

@@ -156,23 +162,26 @@ These algorithms solve linear systems by iteratively building an approximation
 from a sequence of Krylov subspaces, without requiring explicit matrix factorization.

 ## Characteristics
-- Does not require concrete matrix representation (`needs_concrete_A() = false`)
-- Only needs matrix-vector products `A*v` (can work with operators/functions)
-- Memory usage typically O(n) or O(kn) where k << n
-- Convergence depends on matrix properties (condition number, eigenvalue distribution)
-- Often benefits significantly from preconditioning
+
+  - Does not require concrete matrix representation (`needs_concrete_A() = false`)
+  - Only needs matrix-vector products `A*v` (can work with operators/functions)
+  - Memory usage typically O(n) or O(kn) where k << n
+  - Convergence depends on matrix properties (condition number, eigenvalue distribution)
+  - Often benefits significantly from preconditioning

 ## Advantages
-- Low memory requirements for large sparse systems
-- Can handle matrix-free operators (functions that compute `A*v`)
-- Often the only feasible approach for very large systems
-- Can exploit matrix structure through specialized operators
+
+  - Low memory requirements for large sparse systems
+  - Can handle matrix-free operators (functions that compute `A*v`)
+  - Often the only feasible approach for very large systems
+  - Can exploit matrix structure through specialized operators

 ## Examples of concrete subtypes
-- `GMRESIteration`: Generalized Minimal Residual method
-- `CGIteration`: Conjugate Gradient (for symmetric positive definite systems)
-- `BiCGStabLIteration`: Bi-Conjugate Gradient Stabilized
-- Wrapped external iterative solvers (KrylovKit.jl, IterativeSolvers.jl)
+
+  - `GMRESIteration`: Generalized Minimal Residual method
+  - `CGIteration`: Conjugate Gradient (for symmetric positive definite systems)
+  - `BiCGStabLIteration`: Bi-Conjugate Gradient Stabilized
+  - Wrapped external iterative solvers (KrylovKit.jl, IterativeSolvers.jl)
 """
 abstract type AbstractKrylovSubspaceMethod <: SciMLLinearSolveAlgorithm end

@@ -183,15 +192,17 @@ Abstract type for linear solvers that wrap custom solving functions or
 provide direct interfaces to specific solve methods. These provide flexibility
 for integrating custom algorithms or simple solve strategies.

-## Characteristics
-- Does not require concrete matrix representation (`needs_concrete_A() = false`)
-- Provides maximum flexibility for custom solving strategies
-- Can wrap external solver libraries or implement specialized algorithms
-- Performance and stability depend entirely on the wrapped implementation
+## Characteristics
+
+  - Does not require concrete matrix representation (`needs_concrete_A() = false`)
+  - Provides maximum flexibility for custom solving strategies
+  - Can wrap external solver libraries or implement specialized algorithms
+  - Performance and stability depend entirely on the wrapped implementation

 ## Examples of concrete subtypes
-- `LinearSolveFunction`: Wraps arbitrary user-defined solve functions
-- `DirectLdiv!`: Direct application of the `\\` operator
+
+  - `LinearSolveFunction`: Wraps arbitrary user-defined solve functions
+  - `DirectLdiv!`: Direct application of the `\\` operator
 """
 abstract type AbstractSolveFunction <: SciMLLinearSolveAlgorithm end

@@ -204,22 +215,27 @@ Trait function that determines whether a linear solver algorithm requires
 a concrete matrix representation or can work with abstract operators.

 ## Arguments
-- `alg`: A linear solver algorithm instance
+
+  - `alg`: A linear solver algorithm instance

 ## Returns
-- `true`: Algorithm requires a concrete matrix (e.g., for factorization)
-- `false`: Algorithm can work with abstract operators (e.g., matrix-free methods)
+
+  - `true`: Algorithm requires a concrete matrix (e.g., for factorization)
+  - `false`: Algorithm can work with abstract operators (e.g., matrix-free methods)

 ## Usage
+
 This trait is used internally by LinearSolve.jl to optimize algorithm dispatch
 and determine when matrix operators need to be converted to concrete arrays.

 ## Algorithm-Specific Behavior
-- `AbstractFactorization`: `true` (needs explicit matrix entries for factorization)
-- `AbstractKrylovSubspaceMethod`: `false` (only needs matrix-vector products)
-- `AbstractSolveFunction`: `false` (depends on the wrapped function's requirements)
+
+  - `AbstractFactorization`: `true` (needs explicit matrix entries for factorization)
+  - `AbstractKrylovSubspaceMethod`: `false` (only needs matrix-vector products)
+  - `AbstractSolveFunction`: `false` (depends on the wrapped function's requirements)

 ## Example
+
 ```julia
 needs_concrete_A(LUFactorization())  # true
 needs_concrete_A(GMRESIteration())   # false
@@ -470,9 +486,11 @@ export PanuaPardisoFactorize, PanuaPardisoIterate
 export PardisoJL
 export MKLLUFactorization
 export OpenBLASLUFactorization
+export OpenBLAS32MixedLUFactorization
 export MKL32MixedLUFactorization
 export AppleAccelerateLUFactorization
 export AppleAccelerate32MixedLUFactorization
+export RF32MixedLUFactorization
 export MetalLUFactorization
 export MetalOffload32MixedLUFactorization
````
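The Krylov docstrings in this file note that methods with `needs_concrete_A() = false` only require matrix-vector products. A minimal matrix-free conjugate gradient in NumPy makes that concrete (a sketch, not LinearSolve's implementation; `cg` and `laplacian` are invented names): the "matrix" is just a callable.

```python
import numpy as np

def cg(matvec, b, tol=1e-10, maxiter=1000):
    # Conjugate gradient for SPD systems; only ever calls matvec(v),
    # so the matrix never has to exist as a concrete array.
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(maxiter):
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol * np.linalg.norm(b):
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

# "Matrix-free" operator: a 1-D Laplacian applied without ever forming it
n = 50
def laplacian(v):
    out = 2.0 * v
    out[:-1] -= v[1:]
    out[1:] -= v[:-1]
    return out

b = np.ones(n)
x = cg(laplacian, b)
```

This is the trade-off the abstract-type hierarchy encodes: factorizations need the concrete entries up front but amortize across solves, while Krylov methods need only the action of the operator.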

0 commit comments
