Skip to content

Commit f4bbc08

Browse files
Merge pull request #454 from SciML/operator_defaults
Fix and test solvers for non-square operators
2 parents 30ff6eb + 40743e1 commit f4bbc08

File tree

4 files changed

+66
-9
lines changed

4 files changed

+66
-9
lines changed

src/LinearSolve.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ EnumX.@enumx DefaultAlgorithmChoice begin
103103
AppleAccelerateLUFactorization
104104
MKLLUFactorization
105105
QRFactorizationPivoted
106+
KrylovJL_CRAIGMR
107+
KrylovJL_LSMR
106108
end
107109

108110
struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm

src/default.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
needs_concrete_A(alg::DefaultLinearSolver) = true
22
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
3-
T13, T14, T15, T16, T17, T18, T19}
3+
T13, T14, T15, T16, T17, T18, T19, T20, T21}
44
LUFactorization::T1
55
QRFactorization::T2
66
DiagonalFactorization::T3
@@ -20,6 +20,8 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
2020
AppleAccelerateLUFactorization::T17
2121
MKLLUFactorization::T18
2222
QRFactorizationPivoted::T19
23+
KrylovJL_CRAIGMR::T20
24+
KrylovJL_LSMR::T21
2325
end
2426

2527
# Legacy fallback
@@ -254,11 +256,11 @@ function algchoice_to_alg(alg::Symbol)
254256
elseif alg === :AppleAccelerateLUFactorization
255257
AppleAccelerateLUFactorization()
256258
elseif alg === :QRFactorizationPivoted
257-
@static if VERSION v"1.7beta"
258-
QRFactorization(ColumnNorm())
259-
else
260-
QRFactorization(Val(true))
261-
end
259+
QRFactorization(ColumnNorm())
260+
elseif alg === :KrylovJL_CRAIGMR
261+
KrylovJL_CRAIGMR()
262+
elseif alg === :KrylovJL_LSMR
263+
KrylovJL_LSMR()
262264
else
263265
error("Algorithm choice symbol $alg not allowed in the default")
264266
end
@@ -387,7 +389,7 @@ end
387389
quote
388390
getproperty(cache.cacheval,$(Meta.quot(alg)))' \ dy
389391
end
390-
elseif alg in Symbol.((DefaultAlgorithmChoice.KrylovJL_GMRES,))
392+
elseif alg in Symbol.((DefaultAlgorithmChoice.KrylovJL_GMRES,DefaultAlgorithmChoice.KrylovJL_LSMR, DefaultAlgorithmChoice.KrylovJL_CRAIGMR))
391393
quote
392394
invprob = LinearSolve.LinearProblem(transpose(cache.A), dy)
393395
solve(invprob, cache.alg;

src/iterative_wrappers.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,21 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
243243
itmax = cache.maxiters
244244
verbose = cache.verbose ? 1 : 0
245245

246-
args = (@get_cacheval(cache, :KrylovJL_GMRES), cache.A, cache.b)
246+
cacheval = if cache.alg isa DefaultLinearSolver
247+
if alg.KrylovAlg === Krylov.gmres!
248+
@get_cacheval(cache, :KrylovJL_GMRES)
249+
elseif alg.KrylovAlg === Krylov.craigmr!
250+
@get_cacheval(cache, :KrylovJL_CRAIGMR)
251+
elseif alg.KrylovAlg === Krylov.lsmr!
252+
@get_cacheval(cache, :KrylovJL_LSMR)
253+
else
254+
error("Default linear solver can only be these three choices! Report this bug!")
255+
end
256+
else
257+
cache.cacheval
258+
end
259+
260+
args = (cacheval, cache.A, cache.b)
247261
kwargs = (atol = atol, rtol = rtol, itmax = itmax, verbose = verbose,
248262
ldiv = true, history = true, alg.kwargs...)
249263

@@ -268,7 +282,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
268282
end
269283

270284
stats = @get_cacheval(cache, :KrylovJL_GMRES).stats
271-
resid = stats.residuals |> last
285+
resid = !isempty(stats.residuals) ? last(stats.residuals) : zero(eltype(stats.residuals))
272286

273287
retcode = if !stats.solved
274288
if stats.status == "maximum number of iterations exceeded"

test/default_algs.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,42 @@ prob = LinearProblem(sparse(A), b)
6767

6868
prob = LinearProblem(big.(rand(10, 10)), big.(zeros(10)))
6969
solve(prob)
70+
71+
## Operator defaults
72+
## https://github.com/SciML/LinearSolve.jl/issues/414
73+
74+
m, n = 2, 2
75+
A = rand(m, n)
76+
b = rand(m)
77+
x = rand(n)
78+
f = (du, u, p, t) -> mul!(du, A, u)
79+
fadj = (du, u, p, t) -> mul!(du, A', u)
80+
fo = FunctionOperator(f, x, b; op_adjoint = fadj)
81+
prob = LinearProblem(fo, b)
82+
sol1 = solve(prob)
83+
sol2 = solve(prob, LinearSolve.KrylovJL_GMRES())
84+
@test sol1.u == sol2.u
85+
86+
m, n = 3, 2
87+
A = rand(m, n)
88+
b = rand(m)
89+
x = rand(n)
90+
f = (du, u, p, t) -> mul!(du, A, u)
91+
fadj = (du, u, p, t) -> mul!(du, A', u)
92+
fo = FunctionOperator(f, x, b; op_adjoint = fadj)
93+
prob = LinearProblem(fo, b)
94+
sol1 = solve(prob)
95+
sol2 = solve(prob, LinearSolve.KrylovJL_LSMR())
96+
@test sol1.u == sol2.u
97+
98+
m, n = 2, 3
99+
A = rand(m, n)
100+
b = rand(m)
101+
x = rand(n)
102+
f = (du, u, p, t) -> mul!(du, A, u)
103+
fadj = (du, u, p, t) -> mul!(du, A', u)
104+
fo = FunctionOperator(f, x, b; op_adjoint = fadj)
105+
prob = LinearProblem(fo, b)
106+
sol1 = solve(prob)
107+
sol2 = solve(prob, LinearSolve.KrylovJL_CRAIGMR())
108+
@test sol1.u == sol2.u

0 commit comments

Comments
 (0)