Skip to content

Commit 68e726e

Browse files
Merge pull request #155 from chriselrod/saverfluipiv
save ipiv for rf lu
2 parents bd6a370 + eb5c23e commit 68e726e

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

src/factorization.jl

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function SciMLBase.solve(cache::LinearCache, alg::AbstractFactorization; kwargs.
1414
SciMLBase.build_linear_solution(alg, y, nothing, cache)
1515
end
1616

17-
# Bad fallback: will fail if `A` is just a stand-in
17+
#RF Bad fallback: will fail if `A` is just a stand-in
1818
# This should instead just create the factorization type.
1919
function init_cacheval(alg::AbstractFactorization, A, b, u, Pl, Pr, maxiters, abstol,
2020
reltol, verbose)
@@ -328,23 +328,31 @@ end
328328

329329
## RFLUFactorization
330330

331-
struct RFWrapper{P, T}
332-
RFWrapper(::Val{P}, ::Val{T}) where {P, T} = new{P, T}()
331+
struct RFLUFactorization{P, T}
332+
RFLUFactorization(::Val{P}, ::Val{T}) where {P, T} = new{P, T}()
333333
end
334-
(::RFWrapper{P, T})(A) where {P, T} = RecursiveFactorization.lu!(A, Val(P), Val(T))
335334

336335
function RFLUFactorization(; pivot = Val(true), thread = Val(true))
337-
GenericFactorization(; fact_alg = RFWrapper(pivot, thread))
336+
RFLUFactorization(pivot, thread)
338337
end
339338

340-
function init_cacheval(alg::GenericFactorization{<:RFWrapper}, A, b, u, Pl, Pr, maxiters,
339+
function init_cacheval(alg::RFLUFactorization, A, b, u, Pl, Pr, maxiters,
341340
abstol, reltol, verbose)
342-
ArrayInterfaceCore.lu_instance(convert(AbstractMatrix, A))
341+
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
342+
ArrayInterfaceCore.lu_instance(convert(AbstractMatrix, A)), ipiv
343343
end
344-
function init_cacheval(alg::GenericFactorization{<:RFWrapper},
345-
A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters,
346-
abstol, reltol, verbose)
347-
ArrayInterfaceCore.lu_instance(convert(AbstractMatrix, A))
344+
345+
function SciMLBase.solve(cache::LinearCache, alg::RFLUFactorization{P, T};
346+
kwargs...) where {P, T}
347+
A = cache.A
348+
A = convert(AbstractMatrix, A)
349+
fact, ipiv = cache.cacheval
350+
if cache.isfresh
351+
fact = RecursiveFactorization.lu!(A, ipiv, Val(P), Val(T))
352+
cache = set_cacheval(cache, (fact, ipiv))
353+
end
354+
y = ldiv!(cache.u, cache.cacheval[1], cache.b)
355+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
348356
end
349357

350358
## FastLAPACKFactorizations

0 commit comments

Comments
 (0)