Skip to content

Commit f4232c0

Browse files
Merge branch 'main' into bug/fix-tol-eltype
2 parents 6315b55 + 1c679c6 commit f4232c0

15 files changed

+296
-37
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ jobs:
2828
group: 'LinearSolveHYPRE'
2929
- version: '1'
3030
group: 'LinearSolvePardiso'
31+
- version: '1'
32+
group: 'LinearSolveBandedMatrices'
3133
steps:
3234
- uses: actions/checkout@v4
3335
- uses: julia-actions/setup-julia@v1

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "2.11.1"
4+
version = "2.13.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -67,7 +67,7 @@ EnzymeCore = "0.5, 0.6"
6767
FastLapackInterface = "1, 2"
6868
GPUArraysCore = "0.1"
6969
HYPRE = "1.4.0"
70-
IterativeSolvers = "0.9.2"
70+
IterativeSolvers = "0.9.3"
7171
KLU = "0.3.0, 0.4"
7272
KernelAbstractions = "0.9"
7373
Krylov = "0.9"
@@ -107,4 +107,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
107107
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
108108

109109
[targets]
110-
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff"]
110+
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices"]

docs/src/solvers/solvers.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ IterativeSolversJL_CG
221221
IterativeSolversJL_GMRES
222222
IterativeSolversJL_BICGSTAB
223223
IterativeSolversJL_MINRES
224+
IterativeSolversJL_IDRS
224225
IterativeSolversJL
225226
```
226227

ext/LinearSolveBandedMatricesExt.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,14 @@ import LinearSolve: defaultalg,
55
do_factorization, init_cacheval, DefaultLinearSolver, DefaultAlgorithmChoice
66

77
# Defaults for BandedMatrices
8-
function defaultalg(A::BandedMatrix, b, ::OperatorAssumptions)
9-
return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
8+
function defaultalg(A::BandedMatrix, b, oa::OperatorAssumptions)
9+
if oa.issq
10+
return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
11+
elseif LinearSolve.is_underdetermined(A)
12+
error("No solver for underdetermined `A::BandedMatrix` is currently implemented!")
13+
else
14+
return DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
15+
end
1016
end
1117

1218
function defaultalg(A::Symmetric{<:Number, <:BandedMatrix}, b, ::OperatorAssumptions)

ext/LinearSolveEnzymeExt.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,14 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.i
5858
d_b .= 0
5959
end
6060
else
61-
for i in 1:EnzymeRules.width(config)
62-
if d_A !== prob_d_A[i]
63-
prob_d_A[i] .+= d_A[i]
64-
d_A[i] .= 0
61+
for (_prob_d_A,_d_A,_prob_d_b, _d_b) in zip(prob_d_A, d_A, prob_d_b, d_b)
62+
if _d_A !== _prob_d_A
63+
_prob_d_A .+= _d_A
64+
_d_A .= 0
6565
end
66-
if d_b !== prob_d_b[i]
67-
prob_d_b[i] .+= d_b[i]
68-
d_b[i] .= 0
66+
if _d_b !== _prob_d_b
67+
_prob_d_b .+= _d_b
68+
_d_b .= 0
6969
end
7070
end
7171
end
@@ -144,13 +144,15 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
144144
_linsolve.cacheval' \ dy
145145
elseif _linsolve.cacheval isa Tuple && _linsolve.cacheval[1] isa Factorization
146146
_linsolve.cacheval[1]' \ dy
147-
elseif _linsolve.alg isa AbstractKrylovSubspaceMethod
147+
elseif _linsolve.alg isa LinearSolve.AbstractKrylovSubspaceMethod
148148
# Doesn't modify `A`, so it's safe to just reuse it
149149
invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy)
150-
solve(invprob;
150+
solve(invprob, _linearsolve.alg;
151151
abstol = _linsolve.val.abstol,
152152
reltol = _linsolve.val.reltol,
153153
verbose = _linsolve.val.verbose)
154+
elseif _linsolve.alg isa LinearSolve.DefaultLinearSolver
155+
LinearSolve.defaultalg_adjoint_eval(_linsolve, dy)
154156
else
155157
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
156158
end
@@ -163,4 +165,4 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
163165
return (nothing,)
164166
end
165167

166-
end
168+
end

ext/LinearSolveIterativeSolversExt.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ function LinearSolve.IterativeSolversJL_GMRES(args...; kwargs...)
2727
generate_iterator = IterativeSolvers.gmres_iterable!,
2828
kwargs...)
2929
end
30+
function LinearSolve.IterativeSolversJL_IDRS(args...; kwargs...)
31+
IterativeSolversJL(args...;
32+
generate_iterator = IterativeSolvers.idrs_iterable!,
33+
kwargs...)
34+
end
35+
3036
function LinearSolve.IterativeSolversJL_BICGSTAB(args...; kwargs...)
3137
IterativeSolversJL(args...;
3238
generate_iterator = IterativeSolvers.bicgstabl_iterator!,
@@ -47,6 +53,7 @@ function LinearSolve.init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, max
4753
reltol,
4854
verbose::Bool, assumptions::OperatorAssumptions)
4955
restart = (alg.gmres_restart == 0) ? min(20, size(A, 1)) : alg.gmres_restart
56+
s = :idrs_s in keys(alg.kwargs) ? alg.kwargs.idrs_s : 4 # shadow space
5057

5158
kwargs = (abstol = abstol, reltol = reltol, maxiter = maxiters,
5259
alg.kwargs...)
@@ -59,6 +66,14 @@ function LinearSolve.init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, max
5966
elseif alg.generate_iterator === IterativeSolvers.gmres_iterable!
6067
alg.generate_iterator(u, A, b; Pl = Pl, Pr = Pr, restart = restart,
6168
kwargs...)
69+
elseif alg.generate_iterator === IterativeSolvers.idrs_iterable!
70+
!!LinearSolve._isidentity_struct(Pr) &&
71+
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
72+
history = IterativeSolvers.ConvergenceHistory(partial=true)
73+
history[:abstol] = abstol
74+
history[:reltol] = reltol
75+
IterativeSolvers.idrs_iterable!(history, u, A, b, s, Pl, abstol, reltol, maxiters;
76+
alg.kwargs...)
6277
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
6378
!!LinearSolve._isidentity_struct(Pr) &&
6479
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
@@ -95,7 +110,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::IterativeSolversJL; kwargs...
95110
end
96111
cache.verbose && println()
97112

98-
resid = cache.cacheval.residual
113+
resid = cache.cacheval isa IterativeSolvers.IDRSIterable ? cache.cacheval.R : cache.cacheval.residual
99114
if resid isa IterativeSolvers.Residual
100115
resid = resid.current
101116
end

src/LinearSolve.jl

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
6262
needs_concrete_A(alg::AbstractSolveFunction) = false
6363

6464
# Util
65+
is_underdetermined(x) = false
66+
is_underdetermined(A::AbstractMatrix) = size(A, 1) < size(A, 2)
67+
is_underdetermined(A::AbstractSciMLOperator) = size(A, 1) < size(A, 2)
6568

6669
_isidentity_struct(A) = false
6770
_isidentity_struct::Number) = isone(λ)
@@ -96,6 +99,7 @@ EnumX.@enumx DefaultAlgorithmChoice begin
9699
NormalCholeskyFactorization
97100
AppleAccelerateLUFactorization
98101
MKLLUFactorization
102+
QRFactorizationPivoted
99103
end
100104

101105
struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
@@ -143,6 +147,31 @@ end
143147
include("factorization_sparse.jl")
144148
end
145149

150+
# Solver Specific Traits
151+
## Needs Square Matrix
152+
"""
153+
needs_square_A(alg)
154+
155+
Returns `true` if the algorithm requires a square matrix.
156+
"""
157+
needs_square_A(::Nothing) = false # Linear Solve automatically will use a correct alg!
158+
needs_square_A(alg::SciMLLinearSolveAlgorithm) = true
159+
for alg in (:QRFactorization, :FastQRFactorization, :NormalCholeskyFactorization,
160+
:NormalBunchKaufmanFactorization)
161+
@eval needs_square_A(::$(alg)) = false
162+
end
163+
for kralg in (Krylov.lsmr!, Krylov.craigmr!)
164+
@eval needs_square_A(::KrylovJL{$(typeof(kralg))}) = false
165+
end
166+
for alg in (:LUFactorization, :FastLUFactorization, :SVDFactorization,
167+
:GenericFactorization, :GenericLUFactorization, :SimpleLUFactorization,
168+
:RFLUFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization,
169+
:DiagonalFactorization, :CholeskyFactorization, :BunchKaufmanFactorization,
170+
:CHOLMODFactorization, :LDLtFactorization, :AppleAccelerateLUFactorization,
171+
:MKLLUFactorization, :MetalLUFactorization)
172+
@eval needs_square_A(::$(alg)) = true
173+
end
174+
146175
const IS_OPENBLAS = Ref(true)
147176
isopenblas() = IS_OPENBLAS[]
148177

@@ -188,7 +217,7 @@ export LinearSolveFunction, DirectLdiv!
188217
export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
189218
KrylovJL_BICGSTAB, KrylovJL_LSMR, KrylovJL_CRAIGMR,
190219
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
191-
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES,
220+
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES, IterativeSolversJL_IDRS,
192221
KrylovKitJL, KrylovKitJL_CG, KrylovKitJL_GMRES
193222

194223
export SimpleGMRES

src/default.jl

Lines changed: 86 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
needs_concrete_A(alg::DefaultLinearSolver) = true
22
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
3-
T13, T14, T15, T16, T17, T18}
3+
T13, T14, T15, T16, T17, T18, T19}
44
LUFactorization::T1
55
QRFactorization::T2
66
DiagonalFactorization::T3
@@ -19,6 +19,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
1919
NormalCholeskyFactorization::T16
2020
AppleAccelerateLUFactorization::T17
2121
MKLLUFactorization::T18
22+
QRFactorizationPivoted::T19
2223
end
2324

2425
# Legacy fallback
@@ -168,8 +169,8 @@ function defaultalg(A, b, assump::OperatorAssumptions)
168169
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
169170
eltype(A) <: Union{Float32, Float64})
170171
DefaultAlgorithmChoice.RFLUFactorization
171-
#elseif A === nothing || A isa Matrix
172-
# alg = FastLUFactorization()
172+
#elseif A === nothing || A isa Matrix
173+
# alg = FastLUFactorization()
173174
elseif usemkl && (A === nothing ? eltype(b) <: Union{Float32, Float64} :
174175
eltype(A) <: Union{Float32, Float64})
175176
DefaultAlgorithmChoice.MKLLUFactorization
@@ -199,9 +200,19 @@ function defaultalg(A, b, assump::OperatorAssumptions)
199200
elseif assump.condition === OperatorCondition.WellConditioned
200201
DefaultAlgorithmChoice.NormalCholeskyFactorization
201202
elseif assump.condition === OperatorCondition.IllConditioned
202-
DefaultAlgorithmChoice.QRFactorization
203+
if is_underdetermined(A)
204+
# Underdetermined
205+
DefaultAlgorithmChoice.QRFactorizationPivoted
206+
else
207+
DefaultAlgorithmChoice.QRFactorization
208+
end
203209
elseif assump.condition === OperatorCondition.VeryIllConditioned
204-
DefaultAlgorithmChoice.QRFactorization
210+
if is_underdetermined(A)
211+
# Underdetermined
212+
DefaultAlgorithmChoice.QRFactorizationPivoted
213+
else
214+
DefaultAlgorithmChoice.QRFactorization
215+
end
205216
elseif assump.condition === OperatorCondition.SuperIllConditioned
206217
DefaultAlgorithmChoice.SVDFactorization
207218
else
@@ -247,6 +258,12 @@ function algchoice_to_alg(alg::Symbol)
247258
NormalCholeskyFactorization()
248259
elseif alg === :AppleAccelerateLUFactorization
249260
AppleAccelerateLUFactorization()
261+
elseif alg === :QRFactorizationPivoted
262+
@static if VERSION v"1.7beta"
263+
QRFactorization(ColumnNorm())
264+
else
265+
QRFactorization(Val(true))
266+
end
250267
else
251268
error("Algorithm choice symbol $alg not allowed in the default")
252269
end
@@ -311,6 +328,12 @@ function defaultalg_symbol(::Type{T}) where {T}
311328
end
312329
defaultalg_symbol(::Type{<:GenericFactorization{typeof(ldlt!)}}) = :LDLtFactorization
313330

331+
@static if VERSION >= v"1.7"
332+
defaultalg_symbol(::Type{<:QRFactorization{ColumnNorm}}) = :QRFactorizationPivoted
333+
else
334+
defaultalg_symbol(::Type{<:QRFactorization{Val{true}}}) = :QRFactorizationPivoted
335+
end
336+
314337
"""
315338
if alg.alg === DefaultAlgorithmChoice.LUFactorization
316339
SciMLBase.solve!(cache, LUFactorization(), args...; kwargs...))
@@ -339,3 +362,61 @@ end
339362
end
340363
ex = Expr(:if, ex.args...)
341364
end
365+
366+
"""
367+
```
368+
elseif DefaultAlgorithmChoice.LUFactorization === cache.alg
369+
(cache.cacheval.LUFactorization)' \\ dy
370+
else
371+
...
372+
end
373+
```
374+
"""
375+
@generated function defaultalg_adjoint_eval(cache::LinearCache, dy)
376+
ex = :()
377+
for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
378+
newex = if alg in Symbol.((DefaultAlgorithmChoice.MKLLUFactorization,
379+
DefaultAlgorithmChoice.AppleAccelerateLUFactorization,
380+
DefaultAlgorithmChoice.RFLUFactorization))
381+
quote
382+
getproperty(cache.cacheval,$(Meta.quot(alg)))[1]' \ dy
383+
end
384+
elseif alg in Symbol.((DefaultAlgorithmChoice.LUFactorization,
385+
DefaultAlgorithmChoice.QRFactorization,
386+
DefaultAlgorithmChoice.KLUFactorization,
387+
DefaultAlgorithmChoice.UMFPACKFactorization,
388+
DefaultAlgorithmChoice.LDLtFactorization,
389+
DefaultAlgorithmChoice.SparspakFactorization,
390+
DefaultAlgorithmChoice.BunchKaufmanFactorization,
391+
DefaultAlgorithmChoice.CHOLMODFactorization,
392+
DefaultAlgorithmChoice.SVDFactorization,
393+
DefaultAlgorithmChoice.CholeskyFactorization,
394+
DefaultAlgorithmChoice.NormalCholeskyFactorization,
395+
DefaultAlgorithmChoice.QRFactorizationPivoted,
396+
DefaultAlgorithmChoice.GenericLUFactorization))
397+
quote
398+
getproperty(cache.cacheval,$(Meta.quot(alg)))' \ dy
399+
end
400+
elseif alg in Symbol.((DefaultAlgorithmChoice.KrylovJL_GMRES,))
401+
quote
402+
invprob = LinearSolve.LinearProblem(transpose(cache.A), dy)
403+
solve(invprob, cache.alg;
404+
abstol = cache.val.abstol,
405+
reltol = cache.val.reltol,
406+
verbose = cache.val.verbose)
407+
end
408+
else
409+
quote
410+
error("Default linear solver with algorithm $(alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
411+
end
412+
end
413+
414+
ex = if ex == :()
415+
Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg.alg), newex,
416+
:(error("Algorithm Choice not Allowed")))
417+
else
418+
Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg.alg), newex, ex)
419+
end
420+
end
421+
ex = Expr(:if, ex.args...)
422+
end

src/extension_algs.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,21 @@ A wrapper over the IterativeSolvers.jl GMRES.
309309
"""
310310
function IterativeSolversJL_GMRES end
311311

312+
"""
313+
```julia
314+
IterativeSolversJL_IDRS(args...; Pl = nothing, kwargs...)
315+
```
316+
317+
A wrapper over the IterativeSolvers.jl IDR(S).
318+
319+
320+
!!! note
321+
322+
Using this solver requires adding the package IterativeSolvers.jl, i.e. `using IterativeSolvers`
323+
324+
"""
325+
function IterativeSolversJL_IDRS end
326+
312327
"""
313328
```julia
314329
IterativeSolversJL_BICGSTAB(args...; Pl = nothing, Pr = nothing, kwargs...)

src/factorization.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,16 @@ function QRFactorization(inplace = true)
158158
QRFactorization(pivot, 16, inplace)
159159
end
160160

161+
@static if VERSION v"1.7beta"
162+
function QRFactorization(pivot::LinearAlgebra.PivotingStrategy, inplace::Bool = true)
163+
QRFactorization(pivot, 16, inplace)
164+
end
165+
else
166+
function QRFactorization(pivot::Val, inplace::Bool = true)
167+
QRFactorization(pivot, 16, inplace)
168+
end
169+
end
170+
161171
function do_factorization(alg::QRFactorization, A, b, u)
162172
A = convert(AbstractMatrix, A)
163173
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)

0 commit comments

Comments
 (0)