Skip to content

Commit 872681b

Browse files
Fix RF32MixedLUFactorization segfault issue
- Simplified cache initialization to only store the LU factorization object
- RecursiveFactorization.lu! returns an LU object that contains its own pivot vector
- Fixed improper pivot vector handling that was causing segfaults
1 parent 568360e commit 872681b

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

ext/LinearSolveRecursiveFactorizationExt.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function LinearSolve.init_cacheval(alg::RF32MixedLUFactorization{P, T}, A, b, u,
5151
A_32 = rand(Float32, 0, 0)
5252
end
5353
luinst = ArrayInterface.lu_instance(A_32)
54-
(luinst, Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...)))
54+
luinst
5555
end
5656

5757
function SciMLBase.solve!(
@@ -64,22 +64,18 @@ function SciMLBase.solve!(
6464
iscomplex = eltype(A) <: Complex
6565

6666
if cache.isfresh
67-
fact, ipiv = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
68-
6967
# Convert to appropriate 32-bit type for factorization
7068
if iscomplex
7169
A_f32 = ComplexF32.(A)
7270
else
7371
A_f32 = Float32.(A)
7472
end
7573

76-
# Ensure ipiv is the right size
77-
if length(ipiv) != min(size(A_f32)...)
78-
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A_f32)...))
79-
end
74+
# Allocate pivot vector for the factorization
75+
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A_f32)...))
8076

8177
fact = RecursiveFactorization.lu!(A_f32, ipiv, Val(P), Val(T), check = false)
82-
cache.cacheval = (fact, ipiv)
78+
cache.cacheval = fact
8379

8480
if !LinearAlgebra.issuccess(fact)
8581
return SciMLBase.build_linear_solution(
@@ -89,17 +85,18 @@ function SciMLBase.solve!(
8985
cache.isfresh = false
9086
end
9187

92-
fact, ipiv = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
88+
fact = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
9389

9490
# Convert b to appropriate 32-bit type for solving
9591
if iscomplex
9692
b_f32 = ComplexF32.(cache.b)
93+
u_f32 = similar(b_f32)
9794
else
9895
b_f32 = Float32.(cache.b)
96+
u_f32 = similar(b_f32)
9997
end
10098

10199
# Solve in 32-bit precision
102-
u_f32 = similar(b_f32)
103100
ldiv!(u_f32, fact, b_f32)
104101

105102
# Convert back to original precision

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[deps]

0 commit comments

Comments
 (0)