Skip to content

Commit 6022080

Browse files
committed
in-place
1 parent 8a4f192 commit 6022080

File tree

4 files changed

+23
-12
lines changed

4 files changed

+23
-12
lines changed

src/common.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
struct LinearCache{TA,Tb,Tp,Talg,Tc,Tr,Tl}
1+
struct LinearCache{TA,Tb,Tu,Tp,Talg,Tc,Tr,Tl}
22
A::TA
33
b::Tb
4+
u::Tu
45
p::Tp
56
alg::Talg
67
cacheval::Tc
78
isfresh::Bool
89
Pr::Tr
910
Pl::Tl
11+
# k::Tk # iteration count
1012
end
1113

1214
function set_A(cache, A)
@@ -39,23 +41,28 @@ function SciMLBase.init(
3941
alias_b = false,
4042
kwargs...,
4143
)
42-
@unpack A, b, p = prob
44+
@unpack A, b, u0, p = prob
4345
if alg isa LUFactorization
4446
fact = lu_instance(A)
4547
Tfact = typeof(fact)
4648
else
4749
fact = nothing
4850
Tfact = Any
4951
end
50-
Pr = nothing
51-
Pl = nothing
52+
Pr = LinearAlgebra.I
53+
Pl = LinearAlgebra.I
5254

5355
A = alias_A ? A : copy(A)
5456
b = alias_b ? b : copy(b)
5557

58+
if u0 == nothing
59+
u0 = zero(b)
60+
end
61+
5662
cache = LinearCache{
5763
typeof(A),
5864
typeof(b),
65+
typeof(u0),
5966
typeof(p),
6067
typeof(alg),
6168
Tfact,
@@ -64,6 +71,7 @@ function SciMLBase.init(
6471
}(
6572
A,
6673
b,
74+
u0,
6775
p,
6876
alg,
6977
fact,

src/factorization.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function SciMLBase.solve(cache::LinearCache, alg::LUFactorization)
1616
cache.A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
1717
error("LU is not defined for $(typeof(prob.A))")
1818
cache = set_cacheval(cache, lu!(cache.A, alg.pivot))
19-
ldiv!(cache.cacheval, cache.b)
19+
ldiv!(cache.u,cache.cacheval, cache.b)
2020
end
2121

2222
struct QRFactorization{P} <: SciMLLinearSolveAlgorithm
@@ -40,7 +40,7 @@ function SciMLBase.solve(cache::LinearCache, alg::QRFactorization)
4040
cache,
4141
qr!(cache.A.A, alg.pivot; blocksize = alg.blocksize),
4242
)
43-
ldiv!(cache.cacheval, cache.b)
43+
ldiv!(cache.u,cache.cacheval, cache.b)
4444
end
4545

4646
struct SVDFactorization{A} <: SciMLLinearSolveAlgorithm
@@ -54,5 +54,5 @@ function SciMLBase.solve(cache::LinearCache, alg::SVDFactorization)
5454
cache.A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
5555
error("SVD is not defined for $(typeof(prob.A))")
5656
cache = set_cacheval(cache, svd!(cache.A; full = alg.full, alg = alg.alg))
57-
ldiv!(cache.cacheval, cache.b)
57+
ldiv!(cache.u,cache.cacheval, cache.b)
5858
end

src/krylov.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ function KrylovJL(args...; solver = gmres, kwargs...)
99
end
1010

1111
function SciMLBase.solve(cache::LinearCache, alg::KrylovJL,args...;kwargs...)
12-
@unpack A, b, Pl,Pr = cache
13-
x, stats = alg.solver(A, b, args...; M=Pl, N=Pr, kwargs...)
14-
resid = A * x - b
12+
@unpack A, b, u, Pr, Pl = cache
13+
u, stats = alg.solver(A, b, args...; M=Pl, N=Pr, kwargs...)
14+
resid = A * u - b
1515
retcode = stats.solved ? :Success : :Failure
16-
return x #SciMLBase.build_solution(prob, alg, x, resid; retcode = retcode)
16+
return u #SciMLBase.build_solution(prob, alg, x, resid; retcode = retcode)
1717
end

test/runtests.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,8 @@ using Test
2020
# make algorithm callable - interoperable with DiffEq ecosystem
2121
@test A * LUFactorization()(x,A,b) b
2222
@test A * KrylovJL()(x,A,b) b
23-
@test_broken A * x b # in place
23+
24+
# in place
25+
KrylovJL()(x,A,b)
26+
@test A * x b
2427
end

0 commit comments

Comments
 (0)