Skip to content

Commit ebde14c

Browse files
Merge pull request #399 from ma-sadeghi/bug/fix-tol-eltype
Fix type promote in case b is complex
2 parents 9574afc + d3bc1ce commit ebde14c

File tree

3 files changed

+54
-32
lines changed

3 files changed

+54
-32
lines changed

src/common.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
115115
args...;
116116
alias_A = default_alias_A(alg, prob.A, prob.b),
117117
alias_b = default_alias_b(alg, prob.A, prob.b),
118-
abstol = default_tol(eltype(prob.b)),
119-
reltol = default_tol(eltype(prob.b)),
118+
abstol = default_tol(real(eltype(prob.b))),
119+
reltol = default_tol(real(eltype(prob.b))),
120120
maxiters::Int = length(prob.b),
121121
verbose::Bool = false,
122122
Pl = IdentityOperator(size(prob.A)[1]),
@@ -151,8 +151,8 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
151151
end
152152

153153
# Guard against type mismatch for user-specified reltol/abstol
154-
reltol = eltype(prob.b)(reltol)
155-
abstol = eltype(prob.b)(abstol)
154+
reltol = real(eltype(prob.b))(reltol)
155+
abstol = real(eltype(prob.b))(abstol)
156156

157157
cacheval = init_cacheval(alg, A, b, u0, Pl, Pr, maxiters, abstol, reltol, verbose,
158158
assumptions)

src/factorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function _ldiv!(x::Vector, A::Factorization, b::Vector)
1616
ldiv!(A, x)
1717
end
1818

19-
#RF Bad fallback: will fail if `A` is just a stand-in
19+
# RF Bad fallback: will fail if `A` is just a stand-in
2020
# This should instead just create the factorization type.
2121
function init_cacheval(alg::AbstractFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
2222
reltol, verbose::Bool, assumptions::OperatorAssumptions)

test/basictests.jl

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,37 +9,52 @@ const Dual64 = ForwardDiff.Dual{Nothing, Float64, 1}
99
n = 8
1010
A = Matrix(I, n, n)
1111
b = ones(n)
12+
# Real-valued systems
1213
A1 = A / 1;
1314
b1 = rand(n);
1415
x1 = zero(b);
16+
# A2 is similar to A1; created to test cache reuse
1517
A2 = A / 2;
1618
b2 = rand(n);
1719
x2 = zero(b);
20+
# Complex systems + mismatched types with eltype(tol)
21+
A3 = A1 .|> ComplexF32
22+
b3 = b1 .|> ComplexF32
23+
x3 = x1 .|> ComplexF32
24+
# A4 is similar to A3; created to test cache reuse
25+
A4 = A2 .|> ComplexF32
26+
b4 = b2 .|> ComplexF32
27+
x4 = x2 .|> ComplexF32
1828

1929
prob1 = LinearProblem(A1, b1; u0 = x1)
2030
prob2 = LinearProblem(A2, b2; u0 = x2)
31+
prob3 = LinearProblem(A3, b3; u0 = x3)
32+
prob4 = LinearProblem(A4, b4; u0 = x4)
2133

2234
cache_kwargs = (; verbose = true, abstol = 1e-8, reltol = 1e-8, maxiter = 30)
2335

2436
function test_interface(alg, prob1, prob2)
25-
A1 = prob1.A
26-
b1 = prob1.b
27-
x1 = prob1.u0
28-
A2 = prob2.A
29-
b2 = prob2.b
30-
x2 = prob2.u0
37+
A1, b1 = prob1.A, prob1.b
38+
A2, b2 = prob2.A, prob2.b
3139

3240
sol = solve(prob1, alg; cache_kwargs...)
3341
@test A1 * sol.u b1
3442

43+
sol = solve(prob2, alg; cache_kwargs...)
44+
@test A2 * sol.u b2
45+
46+
# Test cache resue: base mechanism
3547
cache = SciMLBase.init(prob1, alg; cache_kwargs...) # initialize cache
3648
sol = solve!(cache)
3749
@test A1 * sol.u b1
50+
51+
# Test cache resue: only A changes
3852
cache.A = deepcopy(A2)
3953
sol = solve!(cache; cache_kwargs...)
4054
@test A2 * sol.u b1
4155

42-
cache.A = A2
56+
# Test cache resue: both A and b change
57+
cache.A = deepcopy(A2)
4358
cache.b = b2
4459
sol = solve!(cache; cache_kwargs...)
4560
@test A2 * sol.u b2
@@ -50,6 +65,7 @@ end
5065
@testset "LinearSolve" begin
5166
@testset "Default Linear Solver" begin
5267
test_interface(nothing, prob1, prob2)
68+
test_interface(nothing, prob3, prob4)
5369

5470
A1 = prob1.A * prob1.A'
5571
b1 = prob1.b
@@ -202,25 +218,24 @@ end
202218
end
203219
end
204220

205-
test_algs = if VERSION >= v"1.9" && LinearSolve.usemkl
206-
(LUFactorization(),
207-
QRFactorization(),
208-
SVDFactorization(),
209-
RFLUFactorization(),
210-
MKLLUFactorization(),
211-
LinearSolve.defaultalg(prob1.A, prob1.b))
212-
else
213-
(LUFactorization(),
214-
QRFactorization(),
215-
SVDFactorization(),
216-
RFLUFactorization(),
217-
LinearSolve.defaultalg(prob1.A, prob1.b))
221+
222+
test_algs = [
223+
LUFactorization(),
224+
QRFactorization(),
225+
SVDFactorization(),
226+
RFLUFactorization(),
227+
LinearSolve.defaultalg(prob1.A, prob1.b),
228+
]
229+
230+
if VERSION >= v"1.9" && LinearSolve.usemkl
231+
push!(test_algs, MKLLUFactorization())
218232
end
219233

220234
@testset "Concrete Factorizations" begin
221235
for alg in test_algs
222236
@testset "$alg" begin
223237
test_interface(alg, prob1, prob2)
238+
VERSION >= v"1.9" && (alg isa MKLLUFactorization || test_interface(alg, prob3, prob4))
224239
end
225240
end
226241
if LinearSolve.appleaccelerate_isavailable()
@@ -232,15 +247,16 @@ end
232247
for fact_alg in (lu, lu!,
233248
qr, qr!,
234249
cholesky,
235-
#cholesky!,
236-
#ldlt, ldlt!,
250+
# cholesky!,
251+
# ldlt, ldlt!,
237252
bunchkaufman, bunchkaufman!,
238253
lq, lq!,
239254
svd, svd!,
240255
LinearAlgebra.factorize)
241256
@testset "fact_alg = $fact_alg" begin
242257
alg = GenericFactorization(fact_alg = fact_alg)
243258
test_interface(alg, prob1, prob2)
259+
test_interface(alg, prob3, prob4)
244260
end
245261
end
246262
end
@@ -251,13 +267,17 @@ end
251267

252268
@testset "KrylovJL" begin
253269
kwargs = (; gmres_restart = 5)
254-
for alg in (("Default", KrylovJL(kwargs...)),
270+
algorithms = (
271+
("Default", KrylovJL(kwargs...)),
255272
("CG", KrylovJL_CG(kwargs...)),
256273
("GMRES", KrylovJL_GMRES(kwargs...)),
257-
# ("BICGSTAB",KrylovJL_BICGSTAB(kwargs...)),
258-
("MINRES", KrylovJL_MINRES(kwargs...)))
259-
@testset "$(alg[1])" begin
260-
test_interface(alg[2], prob1, prob2)
274+
# ("BICGSTAB",KrylovJL_BICGSTAB(kwargs...)),
275+
("MINRES", KrylovJL_MINRES(kwargs...))
276+
)
277+
for (name, algorithm) in algorithms
278+
@testset "$name" begin
279+
test_interface(algorithm, prob1, prob2)
280+
test_interface(algorithm, prob3, prob4)
261281
end
262282
end
263283
end
@@ -274,6 +294,7 @@ end
274294
)
275295
@testset "$(alg[1])" begin
276296
test_interface(alg[2], prob1, prob2)
297+
test_interface(alg[2], prob3, prob4)
277298
end
278299
end
279300
end
@@ -287,6 +308,7 @@ end
287308
("GMRES", KrylovKitJL_GMRES(kwargs...)))
288309
@testset "$(alg[1])" begin
289310
test_interface(alg[2], prob1, prob2)
311+
test_interface(alg[2], prob3, prob4)
290312
end
291313
@test alg[2] isa KrylovKitJL
292314
end

0 commit comments

Comments
 (0)