Skip to content

Commit 055dbf2

Browse files
Merge pull request #10 from SciML/first-branch
solve without !
2 parents df91d92 + 7a289f9 commit 055dbf2

File tree

3 files changed

+43
-24
lines changed

3 files changed

+43
-24
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1616
ArrayInterface = "3"
1717
IterativeSolvers = "0.9"
1818
Reexport = "1"
19-
SciMLBase = "1"
19+
SciMLBase = "1.18.6"
2020
Setfield = "0.7"
2121
UnPack = "1"
2222
julia = "1"

src/LinearSolve.jl

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,17 @@ function set_p(cache, p)
3737
# @set! cache.isfresh = true
3838
end
3939

40-
function SciMLBase.init(prob::LinearProblem, alg; kwargs...)
40+
function set_cacheval(cache::LinearCache,alg)
41+
if cache.isfresh
42+
@set! cache.cacheval = alg
43+
@set! cache.isfresh = false
44+
end
45+
return cache
46+
end
47+
48+
function SciMLBase.init(prob::LinearProblem, alg;
49+
alias_A = false, alias_b = false,
50+
kwargs...)
4151
@unpack A, b, p = prob
4252
if alg isa LUFactorization
4353
fact = lu_instance(A)
@@ -48,41 +58,53 @@ function SciMLBase.init(prob::LinearProblem, alg; kwargs...)
4858
end
4959
Pr = nothing
5060
Pl = nothing
61+
62+
A = alias_A ? A : copy(A)
63+
b = alias_b ? b : copy(b)
64+
5165
cache = LinearCache{typeof(A),typeof(b),typeof(p),typeof(alg),Tfact,typeof(Pr),typeof(Pl)}(
5266
A, b, p, alg, fact, true, Pr, Pl
5367
)
5468
return cache
5569
end
5670

57-
SciMLBase.solve!(prob::LinearProblem, alg; kwargs...) = solve!(init(prob, alg; kwargs...))
58-
SciMLBase.solve!(cache) = solve!(cache, cache.alg)
71+
SciMLBase.solve(prob::LinearProblem, alg; kwargs...) = solve(init(prob, alg; kwargs...))
72+
SciMLBase.solve(cache) = solve(cache, cache.alg)
5973

6074
struct LUFactorization{P} <: AbstractLinearAlgorithm
6175
pivot::P
6276
end
63-
LUFactorization() = LUFactorization(Val(true))
77+
function LUFactorization()
78+
pivot = @static if VERSION < v"1.7beta"
79+
Val(true)
80+
else
81+
RowMaximum()
82+
end
83+
LUFactorization(pivot)
84+
end
6485

65-
function SciMLBase.solve!(cache::LinearCache, alg::LUFactorization)
86+
function SciMLBase.solve(cache::LinearCache, alg::LUFactorization)
6687
cache.A isa Union{AbstractMatrix, AbstractDiffEqOperator} || error("LU is not defined for $(typeof(prob.A))")
67-
if cache.isfresh
68-
@set! cache.cacheval = lu!(cache.A, alg.pivot)
69-
@set! cache.isfresh = false
70-
end
88+
cache = set_cacheval(cache,lu!(cache.A, alg.pivot))
7189
ldiv!(cache.cacheval, cache.b)
7290
end
7391

7492
struct QRFactorization{P} <: AbstractLinearAlgorithm
7593
pivot::P
7694
blocksize::Int
7795
end
78-
QRFactorization() = QRFactorization(Val(false), 16)
96+
function QRFactorization()
97+
pivot = @static if VERSION < v"1.7beta"
98+
Val(false)
99+
else
100+
NoPivot()
101+
end
102+
QRFactorization(pivot, 16)
103+
end
79104

80-
function SciMLBase.solve!(cache::LinearCache, alg::QRFactorization)
105+
function SciMLBase.solve(cache::LinearCache, alg::QRFactorization)
81106
cache.A isa Union{AbstractMatrix, AbstractDiffEqOperator} || error("QR is not defined for $(typeof(prob.A))")
82-
if cache.isfresh
83-
@set! cache.cacheval = qr!(cache.A.A, alg.pivot; blocksize=alg.blocksize)
84-
@set! cache.isfresh = false
85-
end
107+
cache = set_cacheval(cache,qr!(cache.A.A, alg.pivot; blocksize=alg.blocksize))
86108
ldiv!(cache.cacheval, cache.b)
87109
end
88110

@@ -92,12 +114,9 @@ struct SVDFactorization{A} <: AbstractLinearAlgorithm
92114
end
93115
SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer())
94116

95-
function SciMLBase.solve!(cache::LinearCache, alg::SVDFactorization)
117+
function SciMLBase.solve(cache::LinearCache, alg::SVDFactorization)
96118
cache.A isa Union{AbstractMatrix, AbstractDiffEqOperator} || error("SVD is not defined for $(typeof(prob.A))")
97-
if cache.isfresh
98-
@set! cache.cacheval = svd!(cache.A; full=alg.full, alg=alg.alg)
99-
@set! cache.isfresh = false
100-
end
119+
cache = set_cacheval(cache,svd!(cache.A; full=alg.full, alg=alg.alg))
101120
ldiv!(cache.cacheval, cache.b)
102121
end
103122

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Test
55
A = rand(5, 5)
66
b = rand(5)
77
prob = LinearProblem(A, b)
8-
@test A * solve!(deepcopy(prob), LUFactorization()) b
9-
@test A * solve!(deepcopy(prob), QRFactorization()) b
10-
@test A * solve!(deepcopy(prob), SVDFactorization()) b
8+
@test A * solve(prob, LUFactorization();alias_A = false, alias_b = false) b
9+
@test A * solve(prob, QRFactorization();alias_A = false, alias_b = false) b
10+
@test A * solve(prob, SVDFactorization();alias_A = false, alias_b = false) b
1111
end

0 commit comments

Comments
 (0)