Commit 031639f

Fix DFSane
1 parent 0e3efd7 commit 031639f

3 files changed: +63 −183 lines changed


src/NonlinearSolve.jl

Lines changed: 1 addition & 1 deletion

@@ -172,7 +172,7 @@ include("raphson.jl")
 # include("trustRegion.jl")
 # include("levenberg.jl")
 include("gaussnewton.jl")
-# include("dfsane.jl")
+include("dfsane.jl")
 include("pseudotransient.jl")
 include("broyden.jl")
 include("klement.jl")
src/dfsane.jl

Lines changed: 62 additions & 167 deletions

@@ -1,8 +1,8 @@
 """
-    DFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0,
-        M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
-        n_exp::Int = 2, η_strategy::Function = (fn_1, n, x_n, f_n) -> fn_1 / n^2,
-        max_inner_iterations::Int = 1000)
+    DFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0, M::Int = 10,
+        γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5, n_exp::Int = 2,
+        η_strategy::Function = (fn_1, n, x_n, f_n) -> fn_1 / n^2,
+        max_inner_iterations::Int = 100)

 A low-overhead and allocation-free implementation of the df-sane method for solving large-scale nonlinear
 systems of equations. For in depth information about all the parameters and the algorithm,
@@ -39,34 +39,27 @@ Computation, 75, 1429-1448.](https://www.researchgate.net/publication/220576479_
   `f_n` the current residual. Should satisfy ``η > 0`` and ``∑ₖ ηₖ < ∞``. Defaults to
   ``fn_1 / n^2``.
 - `max_inner_iterations`: the maximum number of iterations allowed for the inner loop of the
-  algorithm. Defaults to `1000`.
+  algorithm. Defaults to `100`.
 """
-@concrete struct DFSane <: AbstractNonlinearSolveAlgorithm
-    σ_min
-    σ_max
-    σ_1
-    M::Int
-    γ
-    τ_min
-    τ_max
-    n_exp::Int
-    η_strategy
-    max_inner_iterations::Int
-end
-
-function DFSane(; σ_min = 1e-10, σ_max = 1e+10, σ_1 = 1.0, M = 10, γ = 1e-4, τ_min = 0.1,
-        τ_max = 0.5, n_exp = 2, η_strategy::F = (fn_1, n, x_n, f_n) -> fn_1 / n^2,
-        max_inner_iterations = 1000) where {F}
-    return DFSane(σ_min, σ_max, σ_1, M, γ, τ_min, τ_max, n_exp, η_strategy,
-        max_inner_iterations)
+@kwdef @concrete struct DFSane <: AbstractNonlinearSolveAlgorithm
+    σ_min = 1e-10
+    σ_max = 1e10
+    σ_1 = 1.0
+    M::Int = 10
+    γ = 1e-4
+    τ_min = 0.1
+    τ_max = 0.5
+    n_exp::Int = 2
+    η_strategy = (fn_1, n, x_n, f_n) -> fn_1 / n^2
+    max_inner_iterations::Int = 100
 end

 @concrete mutable struct DFSaneCache{iip} <: AbstractNonlinearSolveCache{iip}
     alg
     u
-    uprev
+    u_cache
     fu
-    fuprev
+    fu_cache
     du
     history
     f_norm
@@ -93,36 +86,35 @@ end
     trace
 end

-get_fu(cache::DFSaneCache) = cache.fu
-set_fu!(cache::DFSaneCache, fu) = (cache.fu = fu)
-
 function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...;
         alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
         termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
         kwargs...) where {uType, iip, F}
-    u = alias_u0 ? prob.u0 : deepcopy(prob.u0)
+    u = __maybe_unaliased(prob.u0, alias_u0)
     T = eltype(u)

-    du, uprev = copy(u), copy(u)
+    @bb du = similar(u)
+    @bb u_cache = copy(u)
+
     fu = evaluate_f(prob, u)
-    fuprev = copy(fu)
+    @bb fu_cache = copy(fu)

     f_norm = internalnorm(fu)^alg.n_exp
     f_norm_0 = f_norm

     history = fill(f_norm, alg.M)

-    abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, uprev,
+    abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u_cache,
         termination_condition)
     trace = init_nonlinearsolve_trace(alg, u, fu, nothing, du; kwargs...)

-    return DFSaneCache{iip}(alg, u, uprev, fu, fuprev, du, history, f_norm, f_norm_0, alg.M,
-        T(alg.σ_1), T(alg.σ_min), T(alg.σ_max), one(T), T(alg.γ), T(alg.τ_min),
+    return DFSaneCache{iip}(alg, u, u_cache, fu, fu_cache, du, history, f_norm, f_norm_0,
+        alg.M, T(alg.σ_1), T(alg.σ_min), T(alg.σ_max), one(T), T(alg.γ), T(alg.τ_min),
         T(alg.τ_max), alg.n_exp, prob.p, false, maxiters, internalnorm, ReturnCode.Default,
         abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache, trace)
 end

-function perform_step!(cache::DFSaneCache{true})
+function perform_step!(cache::DFSaneCache{iip}) where {iip}
     @unpack alg, f_norm, σ_n, σ_min, σ_max, α_1, γ, τ_min, τ_max, n_exp, M, prob = cache
     T = eltype(cache.u)
     f_norm_old = f_norm
@@ -131,128 +123,64 @@ function perform_step!(cache::DFSaneCache{true})
     σ_n = sign(σ_n) * clamp(abs(σ_n), σ_min, σ_max)

     # Line search direction
-    @. cache.du = -σ_n * cache.fuprev
+    @bb @. cache.du = -σ_n * cache.fu

     η = alg.η_strategy(cache.f_norm_0, cache.stats.nsteps, cache.u, cache.fu)

     f_bar = maximum(cache.history)
     α₊ = α_1
     α₋ = α_1
-    _axpy!(α₊, cache.du, cache.u)
-
-    prob.f(cache.fu, cache.u, cache.p)
-    f_norm = cache.internalnorm(cache.fu)^n_exp
-
-    # TODO: Failure mode with inner line search failed?
-    for _ in 1:(cache.alg.max_inner_iterations)
-        c = f_bar + η - γ * α₊^2 * f_norm_old
-
-        f_norm ≤ c && break
-
-        α₊ = α₊ * clamp(α₊ * f_norm_old / (f_norm + (T(2) * α₊ - T(1)) * f_norm_old),
-            τ_min, τ_max)
-        @. cache.u = cache.uprev - α₋ * cache.du
-
-        prob.f(cache.fu, cache.u, cache.p)
-        f_norm = cache.internalnorm(cache.fu)^n_exp
-
-        f_norm ≤ c && break
-
-        α₋ = α₋ * clamp(α₋ * f_norm_old / (f_norm + (T(2) * α₋ - T(1)) * f_norm_old),
-            τ_min, τ_max)
-        @. cache.u = cache.uprev + α₊ * cache.du
-
-        prob.f(cache.fu, cache.u, cache.p)
-        f_norm = cache.internalnorm(cache.fu)^n_exp
-    end
-
-    update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), nothing,
-        cache.du, α₊)

-    check_and_update!(cache, cache.fu, cache.u, cache.uprev)
+    @bb axpy!(α₊, cache.du, cache.u)

-    # Update spectral parameter
-    @. cache.uprev = cache.u - cache.uprev
-    @. cache.fuprev = cache.fu - cache.fuprev
-
-    α₊ = sum(abs2, cache.uprev)
-    @. cache.uprev *= cache.fuprev
-    α₋ = sum(cache.uprev)
-    cache.σ_n = α₊ / α₋
-
-    # Spectral parameter bounds check
-    if !(σ_min ≤ abs(cache.σ_n) ≤ σ_max)
-        test_norm = sqrt(sum(abs2, cache.fuprev))
-        cache.σ_n = clamp(inv(test_norm), T(1), T(1e5))
-    end
-
-    # Take step
-    @. cache.uprev = cache.u
-    @. cache.fuprev = cache.fu
-    cache.f_norm = f_norm
-
-    # Update history
-    cache.history[cache.stats.nsteps % M + 1] = f_norm
-    cache.stats.nf += 1
-    return nothing
-end
-
-function perform_step!(cache::DFSaneCache{false})
-    @unpack alg, f_norm, σ_n, σ_min, σ_max, α_1, γ, τ_min, τ_max, n_exp, M, prob = cache
-    T = eltype(cache.u)
-    f_norm_old = f_norm
-
-    # Spectral parameter range check
-    σ_n = sign(σ_n) * clamp(abs(σ_n), σ_min, σ_max)
-
-    # Line search direction
-    cache.du = @. -σ_n * cache.fuprev
-
-    η = alg.η_strategy(cache.f_norm_0, cache.stats.nsteps, cache.u, cache.fu)
-
-    f_bar = maximum(cache.history)
-    α₊ = α_1
-    α₋ = α_1
-    cache.u = @. cache.uprev + α₊ * cache.du
-
-    cache.fu = prob.f(cache.u, cache.p)
+    evaluate_f(cache, cache.u, cache.p)
     f_norm = cache.internalnorm(cache.fu)^n_exp
+    α = α₊

-    # TODO: Failure mode with inner line search failed?
-    for _ in 1:(cache.alg.max_inner_iterations)
-        c = f_bar + η - γ * α₊^2 * f_norm_old
-
-        f_norm ≤ c && break
+    inner_converged = false
+    for k in 1:(cache.alg.max_inner_iterations)
+        if f_norm ≤ f_bar + η - γ * α₊^2 * f_norm_old
+            α = α₊
+            inner_converged = true
+            break
+        end

         α₊ = α₊ * clamp(α₊ * f_norm_old / (f_norm + (T(2) * α₊ - T(1)) * f_norm_old),
             τ_min, τ_max)
-        cache.u = @. cache.uprev - α₋ * cache.du
+        @bb axpy!(-α₋, cache.du, cache.u)

-        cache.fu = prob.f(cache.u, cache.p)
+        evaluate_f(cache, cache.u, cache.p)
         f_norm = cache.internalnorm(cache.fu)^n_exp

-        f_norm ≤ c && break
+        if f_norm ≤ f_bar + η - γ * α₋^2 * f_norm_old
+            α = α₋
+            inner_converged = true
+            break
+        end

         α₋ = α₋ * clamp(α₋ * f_norm_old / (f_norm + (T(2) * α₋ - T(1)) * f_norm_old),
             τ_min, τ_max)
-        cache.u = @. cache.uprev + α₊ * cache.du
+        @bb axpy!(α₊, cache.du, cache.u)

-        cache.fu = prob.f(cache.u, cache.p)
+        evaluate_f(cache, cache.u, cache.p)
         f_norm = cache.internalnorm(cache.fu)^n_exp
     end

-    update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), nothing,
-        cache.du, α₊)
+    if !inner_converged
+        cache.retcode = ReturnCode.ConvergenceFailure
+        cache.force_stop = true
+    end

-    check_and_update!(cache, cache.fu, cache.u, cache.uprev)
+    update_trace!(cache, α)
+    check_and_update!(cache, cache.fu, cache.u, cache.u_cache)

     # Update spectral parameter
-    cache.uprev = @. cache.u - cache.uprev
-    cache.fuprev = @. cache.fu - cache.fuprev
+    @bb @. cache.u_cache = cache.u - cache.u_cache
+    @bb @. cache.fu_cache = cache.fu - cache.fu_cache

-    α₊ = sum(abs2, cache.uprev)
-    cache.uprev = @. cache.uprev * cache.fuprev
-    α₋ = sum(cache.uprev)
+    α₊ = sum(abs2, cache.u_cache)
+    @bb @. cache.u_cache *= cache.fu_cache
+    α₋ = sum(cache.u_cache)
     cache.σ_n = α₊ / α₋

     # Spectral parameter bounds check
@@ -262,8 +190,8 @@ function perform_step!(cache::DFSaneCache{false})
     end

     # Take step
-    cache.uprev = cache.u
-    cache.fuprev = cache.fu
+    @bb copyto!(cache.u_cache, cache.u)
+    @bb copyto!(cache.fu_cache, cache.fu)
     cache.f_norm = f_norm

     # Update history
@@ -272,41 +200,8 @@ function perform_step!(cache::DFSaneCache{false})
     return nothing
 end

-function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.u; p = cache.p,
-        abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
-        termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
-    cache.p = p
-    if iip
-        recursivecopy!(cache.u, u0)
-        recursivecopy!(cache.uprev, u0)
-        cache.prob.f(cache.fu, cache.u, p)
-        cache.prob.f(cache.fuprev, cache.uprev, p)
-    else
-        cache.u = u0
-        cache.uprev = u0
-        cache.fu = cache.prob.f(cache.u, p)
-        cache.fuprev = cache.prob.f(cache.uprev, p)
-    end
-
+function __reinit_internal!(cache::DFSaneCache; kwargs...)
     cache.f_norm = cache.internalnorm(cache.fu)^cache.n_exp
     cache.f_norm_0 = cache.f_norm
-
-    fill!(cache.history, cache.f_norm)
-
-    T = eltype(cache.u)
-    cache.σ_n = T(cache.alg.σ_1)
-
-    reset!(cache.trace)
-    abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
-        termination_condition)
-
-    cache.abstol = abstol
-    cache.reltol = reltol
-    cache.tc_cache = tc_cache
-    cache.maxiters = maxiters
-    cache.stats.nf = 1
-    cache.stats.nsteps = 1
-    cache.force_stop = false
-    cache.retcode = ReturnCode.Default
-    return cache
+    return
 end
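
For context on the re-enabled solver: the diff above restores DFSane as a keyword-constructed algorithm (via @kwdef) and collapses the separate in-place and out-of-place paths into a single perform_step! built on MaybeInplace's @bb, with an explicit ConvergenceFailure return code when the non-monotone inner line search exhausts max_inner_iterations. Below is a minimal usage sketch, assuming the standard NonlinearSolve.jl NonlinearProblem/solve interface; the residual function, initial guess, and keyword values are illustrative choices, not part of this commit.

using NonlinearSolve

# Out-of-place residual: solve u.^2 - p = 0 for u.
f(u, p) = u .* u .- p
prob = NonlinearProblem(f, [1.0, 1.0], 2.0)

# Keyword names and defaults follow the updated docstring above
# (e.g. M = 10, σ_1 = 1.0, max_inner_iterations = 100); they are
# spelled out here only for illustration.
sol = solve(prob, DFSane(; M = 10, σ_1 = 1.0, max_inner_iterations = 100))

sol.u        # ≈ [√2, √2]
sol.retcode  # ReturnCode.Success on convergence; ReturnCode.ConvergenceFailure
             # if the inner line search fails, per the new perform_step! above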

src/utils.jl

Lines changed: 0 additions & 15 deletions

@@ -206,14 +206,6 @@ function evaluate_f(cache, u, p)
     return nothing
 end

-"""
-    __matmul!(C, A, B)
-
-Defaults to `mul!(C, A, B)`. However, for sparse matrices uses `C .= A * B`.
-"""
-__matmul!(C, A, B) = mul!(C, A, B)
-__matmul!(C::AbstractSparseMatrix, A, B) = C .= A * B
-
 # Concretize Algorithms
 function get_concrete_algorithm(alg, prob)
     !hasfield(typeof(alg), :ad) && return alg
@@ -381,15 +373,8 @@ function __try_factorize_and_check_singular!(linsolve, X)
 end
 __try_factorize_and_check_singular!(::FakeLinearSolveJLCache, x) = _issingular(x), false

-# TODO: Remove. handled in MaybeInplace.jl
-@generated function _axpy!(α, x, y)
-    hasmethod(axpy!, Tuple{α, x, y}) && return :(axpy!(α, x, y))
-    return :(@. y += α * x)
-end
-
 # Non-square matrix
 @inline __needs_square_A(_, ::Number) = true
-# @inline __needs_square_A(_, ::StaticArray) = true
 @inline __needs_square_A(alg, _) = LinearSolve.needs_square_A(alg.linsolve)

 # Define special concatenation for certain Array combinations
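On the utils.jl removals: the deleted _axpy! helper was a generated-function fallback for the update y ← y + α·x, calling axpy! when an applicable method existed and emitting a broadcasted update otherwise. As its TODO noted, that dispatch now lives in MaybeInplace.jl's @bb macro, which the new dfsane.jl uses directly (e.g. @bb axpy!(α₊, cache.du, cache.u)). A small illustration of the underlying operation, with arbitrary example values, not code from the commit:

using LinearAlgebra: axpy!

α, x, y = 0.5, [3.0, 4.0], [1.0, 2.0]
axpy!(α, x, y)   # in-place y .+= α .* x, so y == [2.5, 4.0]
@. y += α * x    # the broadcast form the removed fallback generated when no
                 # applicable axpy! method existed; y == [4.0, 6.0]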