Skip to content

Commit 3fdb37a

Browse files
Match RF32MixedLUFactorization pivoting with RFLUFactorization
- Store (fact, ipiv) tuple in cache exactly like RFLUFactorization - Pass ipiv to RecursiveFactorization.lu! and store both fact and ipiv - Retrieve factorization using @get_cacheval()[1] pattern - This ensures consistent behavior between the two implementations
1 parent 09603e7 commit 3fdb37a

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

ext/LinearSolveRecursiveFactorizationExt.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ function LinearSolve.init_cacheval(alg::RF32MixedLUFactorization{P, T}, A, b, u,
5151
A_32 = rand(Float32, 0, 0)
5252
end
5353
luinst = ArrayInterface.lu_instance(A_32)
54-
luinst
54+
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
55+
(luinst, ipiv)
5556
end
5657

5758
function SciMLBase.solve!(
@@ -63,6 +64,7 @@ function SciMLBase.solve!(
6364
# Check if we have complex numbers
6465
iscomplex = eltype(A) <: Complex
6566

67+
fact, ipiv = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
6668
if cache.isfresh
6769
# Convert to appropriate 32-bit type for factorization
6870
if iscomplex
@@ -71,11 +73,13 @@ function SciMLBase.solve!(
7173
A_f32 = Float32.(A)
7274
end
7375

74-
# Allocate pivot vector for the factorization
75-
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A_f32)...))
76+
# Ensure ipiv is the right size
77+
if length(ipiv) != min(size(A_f32)...)
78+
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A_f32)...))
79+
end
7680

7781
fact = RecursiveFactorization.lu!(A_f32, ipiv, Val(P), Val(T), check = false)
78-
cache.cacheval = fact
82+
cache.cacheval = (fact, ipiv)
7983

8084
if !LinearAlgebra.issuccess(fact)
8185
return SciMLBase.build_linear_solution(
@@ -85,8 +89,9 @@ function SciMLBase.solve!(
8589
cache.isfresh = false
8690
end
8791

88-
fact = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
89-
92+
# Get the factorization from the cache
93+
fact_cached = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)[1]
94+
9095
# Convert b to appropriate 32-bit type for solving
9196
if iscomplex
9297
b_f32 = ComplexF32.(cache.b)
@@ -97,7 +102,7 @@ function SciMLBase.solve!(
97102
end
98103

99104
# Solve in 32-bit precision
100-
ldiv!(u_f32, fact, b_f32)
105+
ldiv!(u_f32, fact_cached, b_f32)
101106

102107
# Convert back to original precision
103108
T_orig = eltype(cache.u)

0 commit comments

Comments
 (0)