Skip to content

Commit 7cc8f3b

Browse files
committed
commiting updates
1 parent 73d8808 commit 7cc8f3b

File tree

3 files changed

+93
-33
lines changed

3 files changed

+93
-33
lines changed

src/LinearSolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ module LinearSolve
22

33
using ArrayInterface
44
using RecursiveFactorization
5-
using Base: cache_dependencies, Bool, eltype
5+
using Base: cache_dependencies, Bool
6+
import Base: eltype, *
67
using LinearAlgebra
78
using SparseArrays
89
using SciMLBase: AbstractDiffEqOperator, AbstractLinearAlgorithm

src/wrappers.jl

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,75 +2,82 @@
22
## Preconditioners
33

44
"""
5+
LEFT
56
P * x = x .* (1/P.s)
6-
77
Pi * x = x .* (P.s)
8+
9+
Right
10+
P * x = x .* (P.s)
11+
Pi * x = x .* (1/P.s)
812
"""
913
struct ScaleVector{T}
1014
s::T
1115
isleft::Bool
1216
end
1317

18+
Base.eltype(A::ScaleVector) = eltype(A.s)
19+
20+
#function Base.*(A::ScaleVector, x)
21+
# y = similar(x)
22+
# mul!(y, A, x)
23+
#end
24+
1425
# y = A x
1526
function LinearAlgebra.mul!(y, A::ScaleVector, x)
1627
A.s == one(eltype(A.s)) && return y = x
1728

18-
if A.isleft
19-
@. x = x / A.s
20-
else
21-
@. x = x * A.s
22-
end
29+
s = A.isleft ? 1/A.s : A.s
30+
mul!(y, s, x)
31+
2332
end
2433

2534
# A B α + C β
2635
function LinearAlgebra.mul!(C, A::ScaleVector, B, α, β)
27-
28-
tmp = zero(B)
29-
C = β * C + α * mul!(tmp, A, B)
36+
A.s == one(eltype(A.s)) && return @. C = α * B + β * C
37+
38+
s = A.isleft ? 1/A.s : A.s
39+
mul!(C, s, B, α, β)
3040
end
3141

3242
function LinearAlgebra.ldiv!(A::ScaleVector, x)
3343
A.s == one(eltype(A.s)) && return x
3444

35-
if A.isleft
36-
@. x = x * A.s
37-
else
38-
@. x = x / A.s
39-
end
45+
s = A.isleft ? A.s : 1/A.s
46+
@. x = x * s
4047
end
4148

4249
function LinearAlgebra.ldiv!(y, A::ScaleVector, x)
43-
P.s == one(eltype(A.s)) && return y = x
50+
A.s == one(eltype(A.s)) && return y = x
4451

45-
if A.isleft
46-
@. y = x * A.s
47-
else
48-
@. y = x / A.s
49-
end
52+
s = A.isleft ? A.s : 1/A.s
53+
mul!(y, s, x)
5054
end
5155

52-
Base.eltype(A::ScaleVector) = eltype(A.s)
53-
5456
"""
5557
C * x = P * Q * x
56-
5758
Ci * x = Qi * Pi * x
5859
"""
5960
struct ComposePreconditioner{Ti,To}
6061
inner::Ti
6162
outer::To
6263
end
6364

65+
Base.eltype(A::ComposePreconditioner) = Float64 #eltype(A.inner)
66+
6467
# y = A x
6568
function LinearAlgebra.mul!(y, A::ComposePreconditioner, x)
6669
@unpack inner, outer = A
67-
mul!(y, inner, x)
68-
y = outer * y
70+
tmp = similar(y)
71+
mul!(tmp, outer, x)
72+
mul!(y, inner, tmp)
6973
end
7074

7175
# A B α + C β
7276
function LinearAlgebra.mul!(C, A::ComposePreconditioner, B, α, β)
7377
@unpack inner, outer = A
78+
tmp = similar(B)
79+
mul!(tmp, inner, B)
80+
mul!(C, outer, tmp, α, β)
7481
end
7582

7683
function LinearAlgebra.ldiv!(A::ComposePreconditioner, x)
@@ -190,8 +197,11 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
190197
cache = set_cacheval(cache, solver)
191198
end
192199

193-
M = alg.Pl #ComposePreconditioner(alg.Pl, cache.Pl) # left precond
194-
N = alg.Pr #ComposePreconditioner(alg.Pr, cache.Pr) # right
200+
M = alg.Pl
201+
N = alg.Pr
202+
203+
# M = ComposePreconditioner(alg.Pl, cache.Pl) # left precond
204+
# N = ComposePreconditioner(alg.Pr, cache.Pr) # right
195205

196206
atol = cache.abstol
197207
rtol = cache.reltol
@@ -265,11 +275,8 @@ IterativeSolversJL_MINRES(args...;kwargs...) =
265275
function init_cacheval(alg::IterativeSolversJL, cache::LinearCache)
266276
@unpack A, b, u = cache
267277

268-
Pl = alg.Pl
269-
Pr = alg.Pr
270-
271-
# Pl = ComposePreconditioner(alg.Pl, cache.Pl)
272-
# Pr = ComposePreconditioner(alg.Pr, cache.Pr)
278+
Pl = ComposePreconditioner(alg.Pl, cache.Pl)
279+
Pr = ComposePreconditioner(alg.Pr, cache.Pr)
273280

274281
abstol = cache.abstol
275282
reltol = cache.reltol

test/runtests.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ function test_interface(alg, prob1, prob2)
3434
return
3535
end
3636

37+
@testset "LinearSolve" begin
38+
3739
@testset "Default Linear Solver" begin
3840
test_interface(nothing, prob1, prob2)
3941

@@ -123,3 +125,53 @@ end
123125
end
124126
end
125127
end
128+
129+
@testset "Preconditioners" begin
130+
@testset "ScaleVector" begin
131+
s = rand()
132+
α = rand()
133+
β = rand()
134+
135+
x = rand(n,n)
136+
y = rand(n,n)
137+
138+
Pl = LinearSolve.ScaleVector(s, true)
139+
Pr = LinearSolve.ScaleVector(s, false)
140+
141+
mul!(y, Pl, x)
142+
mul!(y, Pr, x)
143+
144+
mul!(y, Pl, x, α, β)
145+
mul!(y, Pr, x, α, β)
146+
147+
ldiv!(Pl, x)
148+
ldiv!(Pr, x)
149+
150+
ldiv!(y, Pl, x)
151+
ldiv!(y, Pr, x)
152+
153+
end
154+
155+
@testset "ComposePreconditioenr" begin
156+
s = rand()
157+
α = rand()
158+
β = rand()
159+
160+
x = rand(n,n)
161+
y = rand(n,n)
162+
163+
Pi = LinearSolve.ScaleVector(s, true)
164+
Po = LinearSolve.ScaleVector(s, false)
165+
166+
P = LinearSolve.ComposePreconditioner(Pi,Po)
167+
168+
mul!(y, P, x)
169+
mul!(y, P, x, α, β)
170+
171+
ldiv!(P, x)
172+
ldiv!(y, P, x)
173+
174+
end
175+
end
176+
177+
end # testset

0 commit comments

Comments
 (0)