Skip to content

Commit 1665277

Browse files
use LinearSolve.jl
1 parent a88e748 commit 1665277

File tree

4 files changed

+69
-99
lines changed

4 files changed

+69
-99
lines changed

Project.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@ version = "0.3.22"
77
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
88
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
99
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
10-
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
1110
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1211
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
13-
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
1412
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1513
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1614
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
@@ -21,9 +19,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2119
ArrayInterfaceCore = "0.1.1"
2220
FiniteDiff = "2"
2321
ForwardDiff = "0.10.3"
24-
IterativeSolvers = "0.9"
2522
RecursiveArrayTools = "2"
26-
RecursiveFactorization = "0.1, 0.2"
2723
Reexport = "0.2, 1"
2824
SciMLBase = "1.32"
2925
Setfield = "0.7, 0.8, 1"

src/NonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import RecursiveFactorization
1616

1717
abstract type AbstractNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
1818
abstract type AbstractBracketingAlgorithm <: AbstractNonlinearSolveAlgorithm end
19-
abstract type AbstractNewtonAlgorithm{CS, AD} <: AbstractNonlinearSolveAlgorithm end
19+
abstract type AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} <: AbstractNonlinearSolveAlgorithm end
2020
abstract type AbstractImmutableNonlinearSolver <: AbstractNonlinearSolveAlgorithm end
2121

2222
include("utils.jl")

src/raphson.jl

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
struct NewtonRaphson{CS, AD, DT, L} <: AbstractNewtonAlgorithm{CS, AD}
2-
diff_type::DT
1+
struct NewtonRaphson{CS, AD, FDT, L, P, ST, CJ} <: AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ}
32
linsolve::L
3+
precs::P
44
end
55

6-
function NewtonRaphson(; autodiff = true, chunk_size = 12, diff_type = Val{:forward},
7-
linsolve = DEFAULT_LINSOLVE)
8-
NewtonRaphson{chunk_size, autodiff, typeof(diff_type), typeof(linsolve)}(diff_type,
9-
linsolve)
6+
function NewtonRaphson(; chunk_size = Val{0}(), autodiff = Val{true}(),
7+
standardtag = Val{true}(), concrete_jac = nothing,
8+
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS)
9+
NewtonRaphson{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type,
10+
typeof(linsolve), typeof(precs), _unwrap_val(standardtag),
11+
_unwrap_val(concrete_jac)}(linsolve, precs)
1012
end
1113

1214
mutable struct NewtonRaphsonCache{ufType, L, jType, uType, JC}
@@ -17,10 +19,64 @@ mutable struct NewtonRaphsonCache{ufType, L, jType, uType, JC}
1719
jac_config::JC
1820
end
1921

22+
function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing,
23+
du = nothing, u = nothing, p = nothing, t = nothing,
24+
weight = nothing, solverdata = nothing,
25+
reltol = nothing) where P
26+
A !== nothing && (linsolve = LinearSolve.set_A(linsolve, A))
27+
b !== nothing && (linsolve = LinearSolve.set_b(linsolve, b))
28+
linu !== nothing && (linsolve = LinearSolve.set_u(linsolve, linu))
29+
30+
Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer :
31+
linsolve.Pl
32+
Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer :
33+
linsolve.Pr
34+
35+
_Pl, _Pr = precs(linsolve.A, du, u, p, nothing, A !== nothing, Plprev, Prprev,
36+
solverdata)
37+
if (_Pl !== nothing || _Pr !== nothing)
38+
_weight = weight === nothing ?
39+
(linsolve.Pr isa Diagonal ? linsolve.Pr.diag : linsolve.Pr.inner.diag) :
40+
weight
41+
Pl, Pr = wrapprecs(_Pl, _Pr, _weight)
42+
linsolve = LinearSolve.set_prec(linsolve, Pl, Pr)
43+
end
44+
45+
linres = if reltol === nothing
46+
solve(linsolve; reltol)
47+
else
48+
solve(linsolve; reltol)
49+
end
50+
51+
return linres
52+
end
53+
54+
function wrapprecs(_Pl, _Pr, weight)
55+
if _Pl !== nothing
56+
Pl = LinearSolve.ComposePreconditioner(LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
57+
_Pl)
58+
else
59+
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
60+
end
61+
62+
if _Pr !== nothing
63+
Pr = LinearSolve.ComposePreconditioner(Diagonal(_vec(weight)), _Pr)
64+
else
65+
Pr = Diagonal(_vec(weight))
66+
end
67+
Pl, Pr
68+
end
69+
2070
function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{true})
21-
uf = JacobianWrapper(f, p)
22-
linsolve = alg.linsolve(Val{:init}, f, u)
71+
uf = JacobianWrapper(f,p)
2372
J = false .* u .* u'
73+
74+
linprob = LinearProblem(W, _vec(zero(u)); u0 = _vec(zero(u)))
75+
Pl, Pr = wrapprecs(alg.precs(W, nothing, u, p, nothing, nothing, nothing, nothing,
76+
nothing)..., weight)
77+
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
78+
Pl = Pl, Pr = Pr)
79+
2480
du1 = zero(u)
2581
tmp = zero(u)
2682
if alg_autodiff(alg)
@@ -47,9 +103,12 @@ function perform_step(solver::NewtonImmutableSolver, alg::NewtonRaphson, ::Val{t
47103
@unpack J, linsolve, du1 = cache
48104
calc_J!(J, solver, cache)
49105
# u = u - J \ fu
50-
linsolve(du1, J, fu, true)
106+
linsolve = dolinsolve(alg.precs, solver.linsolve, A = J, b = fu, u = du1,
107+
p = p, reltol = solver.tol)
108+
@set! cache.linsolve = linsolve
51109
@. u = u - du1
52110
f(fu, u, p)
111+
53112
if solver.internalnorm(solver.fu) < solver.tol
54113
@set! solver.force_stop = true
55114
end

src/utils.jl

Lines changed: 0 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -100,91 +100,6 @@ function num_types_in_tuple(sig::UnionAll)
100100
length(Base.unwrap_unionall(sig).parameters)
101101
end
102102

103-
### Default Linsolve
104-
105-
# Try to be as smart as possible
106-
# lu! if Matrix
107-
# lu if sparse
108-
# gmres if operator
109-
110-
mutable struct DefaultLinSolve
111-
A::Any
112-
iterable::Any
113-
end
114-
DefaultLinSolve() = DefaultLinSolve(nothing, nothing)
115-
116-
function (p::DefaultLinSolve)(x, A, b, update_matrix = false; tol = nothing, kwargs...)
117-
if p.iterable isa Vector && eltype(p.iterable) <: LinearAlgebra.BlasInt # `iterable` here is the pivoting vector
118-
F = LU{eltype(A)}(A, p.iterable, zero(LinearAlgebra.BlasInt))
119-
ldiv!(x, F, b)
120-
return nothing
121-
end
122-
if update_matrix
123-
if typeof(A) <: Matrix
124-
blasvendor = BLAS.vendor()
125-
# if the user doesn't use OpenBLAS, we assume that is a better BLAS
126-
# implementation like MKL
127-
#
128-
# RecursiveFactorization seems to be consistantly winning below 100
129-
# https://discourse.julialang.org/t/ann-recursivefactorization-jl/39213
130-
if ArrayInterfaceCore.can_setindex(x) && (size(A, 1) <= 100 ||
131-
((blasvendor === :openblas || blasvendor === :openblas64) &&
132-
size(A, 1) <= 500))
133-
p.A = RecursiveFactorization.lu!(A)
134-
else
135-
p.A = lu!(A)
136-
end
137-
elseif typeof(A) <: Tridiagonal
138-
p.A = lu!(A)
139-
elseif typeof(A) <: Union{SymTridiagonal}
140-
p.A = ldlt!(A)
141-
elseif typeof(A) <: Union{Symmetric, Hermitian}
142-
p.A = bunchkaufman!(A)
143-
elseif typeof(A) <: SparseMatrixCSC
144-
p.A = lu(A)
145-
elseif ArrayInterfaceCore.isstructured(A)
146-
p.A = factorize(A)
147-
elseif !(typeof(A) <: AbstractDiffEqOperator)
148-
# Most likely QR is the one that is overloaded
149-
# Works on things like CuArrays
150-
p.A = qr(A)
151-
end
152-
end
153-
154-
if typeof(A) <: Union{Matrix, SymTridiagonal, Tridiagonal, Symmetric, Hermitian} # No 2-arg form for SparseArrays!
155-
x .= b
156-
ldiv!(p.A, x)
157-
# Missing a little bit of efficiency in a rare case
158-
#elseif typeof(A) <: DiffEqArrayOperator
159-
# ldiv!(x,p.A,b)
160-
elseif ArrayInterfaceCore.isstructured(A) || A isa SparseMatrixCSC
161-
ldiv!(x, p.A, b)
162-
elseif typeof(A) <: AbstractDiffEqOperator
163-
# No good starting guess, so guess zero
164-
if p.iterable === nothing
165-
p.iterable = IterativeSolvers.gmres_iterable!(x, A, b; initially_zero = true,
166-
restart = 5, maxiter = 5,
167-
tol = 1e-16, kwargs...)
168-
p.iterable.reltol = tol
169-
end
170-
x .= false
171-
iter = p.iterable
172-
purge_history!(iter, x, b)
173-
174-
for residual in iter
175-
end
176-
else
177-
ldiv!(x, p.A, b)
178-
end
179-
return nothing
180-
end
181-
182-
function (p::DefaultLinSolve)(::Type{Val{:init}}, f, u0_prototype)
183-
DefaultLinSolve()
184-
end
185-
186-
const DEFAULT_LINSOLVE = DefaultLinSolve()
187-
188103
@inline UNITLESS_ABS2(x) = real(abs2(x))
189104
@inline DEFAULT_NORM(u::Union{AbstractFloat, Complex}) = @fastmath abs(u)
190105
@inline function DEFAULT_NORM(u::Array{T}) where {T <: Union{AbstractFloat, Complex}}

0 commit comments

Comments
 (0)