Skip to content

Commit b02eb70

Browse files
committed
save ipiv for rf lu
1 parent b1cf3a1 commit b02eb70

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

src/factorization.jl

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function SciMLBase.solve(cache::LinearCache, alg::AbstractFactorization; kwargs.
1414
SciMLBase.build_linear_solution(alg, y, nothing, cache)
1515
end
1616

17-
# Bad fallback: will fail if `A` is just a stand-in
17+
#RF Bad fallback: will fail if `A` is just a stand-in
1818
# This should instead just create the factorization type.
1919
function init_cacheval(alg::AbstractFactorization, A, b, u, Pl, Pr, maxiters, abstol,
2020
reltol, verbose)
@@ -328,25 +328,34 @@ end
328328

329329
## RFLUFactorization
330330

331-
struct RFWrapper{P, T}
332-
RFWrapper(::Val{P}, ::Val{T}) where {P, T} = new{P, T}()
331+
332+
struct RFLUFactorization{P, T}
333+
RFLUFactorization(::Val{P}, ::Val{T}) where {P, T} = new{P, T}()
333334
end
334-
(::RFWrapper{P, T})(A) where {P, T} = RecursiveFactorization.lu!(A, Val(P), Val(T))
335335

336336
function RFLUFactorization(; pivot = Val(true), thread = Val(true))
337-
GenericFactorization(; fact_alg = RFWrapper(pivot, thread))
337+
RFLUFactorization(pivot, thread)
338338
end
339339

340-
function init_cacheval(alg::GenericFactorization{<:RFWrapper}, A, b, u, Pl, Pr, maxiters,
340+
function init_cacheval(alg::RFLUFactorization, A, b, u, Pl, Pr, maxiters,
341341
abstol, reltol, verbose)
342-
ArrayInterfaceCore.lu_instance(convert(AbstractMatrix, A))
342+
ipiv = Vector{BlasInt}(undef, min(size(A)...));
343+
ArrayInterfaceCore.lu_instance(convert(AbstractMatrix, A)), ipiv
343344
end
344-
function init_cacheval(alg::GenericFactorization{<:RFWrapper},
345-
A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters,
346-
abstol, reltol, verbose)
347-
ArrayInterfaceCore.lu_instance(convert(AbstractMatrix, A))
345+
346+
function SciMLBase.solve(cache::LinearCache, alg::RFLUFactorization{P,T}) where {P,T}
347+
A = cache.A
348+
A = convert(AbstractMatrix, A)
349+
fact, ipiv = cache.cacheval
350+
if cache.isfresh
351+
fact = RecursiveFactorization.lu!(A, ipiv, Val(P), Val(T))
352+
cache = set_cacheval(cache, (fact, ipiv))
353+
end
354+
y = ldiv!(cache.u, cache.cacheval[1], cache.b)
355+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
348356
end
349357

358+
350359
## FastLAPACKFactorizations
351360

352361
struct WorkspaceAndFactors{W, F}

0 commit comments

Comments
 (0)