Skip to content

Commit 3bf5853

Browse files
committed
Fix line search direction for some algorithms
1 parent c7ca39a commit 3bf5853

File tree

9 files changed

+54
-81
lines changed

9 files changed

+54
-81
lines changed

src/broyden.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ get_fu(cache::GeneralBroydenCache) = cache.fu
6161

6262
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...;
6363
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
64-
termination_condition = nothing, internalnorm = DEFAULT_NORM,
65-
kwargs...) where {uType, iip}
64+
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
65+
kwargs...) where {uType, iip, F}
6666
@unpack f, u0, p = prob
6767
u = alias_u0 ? u0 : deepcopy(u0)
6868
fu = evaluate_f(prob, u)
@@ -71,10 +71,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
7171
alg.reset_tolerance
7272
reset_check = x -> abs(x) ≤ reset_tolerance
7373

74-
abstol, reltol, termination_condition = _init_termination_elements(abstol,
75-
reltol,
76-
termination_condition,
77-
eltype(u))
74+
abstol, reltol, termination_condition = _init_termination_elements(abstol, reltol,
75+
termination_condition, eltype(u))
7876

7977
mode = DiffEqBase.get_termination_mode(termination_condition)
8078

@@ -83,8 +81,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
8381
return GeneralBroydenCache{iip}(f, alg, u, zero(u), _mutable_zero(u), fu, zero(fu),
8482
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
8583
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
86-
reset_tolerance,
87-
reset_check, prob, NLStats(1, 0, 0, 0, 0),
84+
reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
8885
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition,
8986
storage)
9087
end
@@ -95,9 +92,9 @@ function perform_step!(cache::GeneralBroydenCache{true})
9592
termination_condition = cache.termination_condition(tc_storage)
9693
T = eltype(u)
9794

98-
mul!(_vec(du), J⁻¹, -_vec(fu))
95+
mul!(_vec(du), J⁻¹, _vec(fu))
9996
α = perform_linesearch!(cache.lscache, u, du)
100-
_axpy!(α, du, u)
97+
_axpy!(-α, du, u)
10198
f(fu2, u, p)
10299

103100
termination_condition(fu2, u, u_prev, cache.abstol, cache.reltol) &&
@@ -119,6 +116,7 @@ function perform_step!(cache::GeneralBroydenCache{true})
119116
J⁻¹[diagind(J⁻¹)] .= T(1)
120117
cache.resets += 1
121118
else
119+
du .*= -1
122120
mul!(_vec(J⁻¹df), J⁻¹, _vec(dfu))
123121
mul!(J⁻¹₂, _vec(du)', J⁻¹)
124122
denom = dot(du, J⁻¹df)
@@ -138,9 +136,9 @@ function perform_step!(cache::GeneralBroydenCache{false})
138136

139137
T = eltype(cache.u)
140138

141-
cache.du = _restructure(cache.du, cache.J⁻¹ * -_vec(cache.fu))
139+
cache.du = _restructure(cache.du, cache.J⁻¹ * _vec(cache.fu))
142140
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
143-
cache.u = cache.u .+ α * cache.du
141+
cache.u = cache.u .- α * cache.du
144142
cache.fu2 = f(cache.u, p)
145143

146144
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
@@ -160,6 +158,7 @@ function perform_step!(cache::GeneralBroydenCache{false})
160158
cache.J⁻¹ = __init_identity_jacobian(cache.u, cache.fu)
161159
cache.resets += 1
162160
else
161+
cache.du = -cache.du
163162
cache.J⁻¹df = _restructure(cache.J⁻¹df, cache.J⁻¹ * _vec(cache.dfu))
164163
cache.J⁻¹₂ = _vec(cache.du)' * cache.J⁻¹
165164
denom = dot(cache.du, cache.J⁻¹df)

src/dfsane.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ end
9797

9898
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...;
9999
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
100-
termination_condition = nothing, internalnorm = DEFAULT_NORM,
101-
kwargs...) where {uType, iip}
100+
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
101+
kwargs...) where {uType, iip, F}
102102
uₙ = alias_u0 ? prob.u0 : deepcopy(prob.u0)
103103

104104
p = prob.p

src/gaussnewton.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,7 @@ end
4949
function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
5050
precs = DEFAULT_PRECS, adkwargs...)
5151
ad = default_adargs_to_adtype(; adkwargs...)
52-
return GaussNewton{_unwrap_val(concrete_jac)}(ad,
53-
linsolve,
54-
precs)
52+
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
5553
end
5654

5755
@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
@@ -84,21 +82,15 @@ end
8482

8583
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
8684
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
87-
termination_condition = nothing,
88-
internalnorm = DEFAULT_NORM,
89-
kwargs...) where {uType, iip}
85+
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
86+
kwargs...) where {uType, iip, F}
9087
alg = get_concrete_algorithm(alg_, prob)
9188
@unpack f, u0, p = prob
9289

9390
linsolve_with_JᵀJ = Val(_needs_square_A(alg, u0))
9491

9592
u = alias_u0 ? u0 : deepcopy(u0)
96-
if iip
97-
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
98-
f(fu1, u, p)
99-
else
100-
fu1 = f(u, p)
101-
end
93+
fu1 = evaluate_f(prob, u)
10294

10395
if SciMLBase._unwrap_val(linsolve_with_JᵀJ)
10496
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p,

src/klement.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ get_fu(cache::GeneralKlementCache) = cache.fu
7070

7171
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKlement, args...;
7272
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
73-
termination_condition = nothing, internalnorm = DEFAULT_NORM,
74-
linsolve_kwargs = (;), kwargs...) where {uType, iip}
73+
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
74+
linsolve_kwargs = (;), kwargs...) where {uType, iip, F}
7575
@unpack f, u0, p = prob
7676
u = alias_u0 ? u0 : deepcopy(u0)
7777
fu = evaluate_f(prob, u)
@@ -89,10 +89,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKleme
8989
linsolve = __setup_linsolve(J, _vec(fu), _vec(du), p, alg)
9090
end
9191

92-
abstol, reltol, termination_condition = _init_termination_elements(abstol,
93-
reltol,
94-
termination_condition,
95-
eltype(u))
92+
abstol, reltol, termination_condition = _init_termination_elements(abstol, reltol,
93+
termination_condition, eltype(u))
9694

9795
mode = DiffEqBase.get_termination_mode(termination_condition)
9896

@@ -129,12 +127,12 @@ function perform_step!(cache::GeneralKlementCache{true})
129127

130128
# u = u - J \ fu
131129
linres = dolinsolve(alg.precs, linsolve; A = ifelse(fact_done, nothing, J),
132-
b = -_vec(fu), linu = _vec(du), p, reltol = cache.abstol)
130+
b = _vec(fu), linu = _vec(du), p, reltol = cache.abstol)
133131
cache.linsolve = linres.cache
134132

135133
# Line Search
136134
α = perform_linesearch!(cache.lscache, u, du)
137-
_axpy!(α, du, u)
135+
_axpy!(-α, du, u)
138136
f(cache.fu2, u, p)
139137

140138
termination_condition(cache.fu2, u, u_prev, cache.abstol, cache.reltol) &&
@@ -146,6 +144,7 @@ function perform_step!(cache::GeneralKlementCache{true})
146144
cache.force_stop && return nothing
147145

148146
# Update the Jacobian
147+
cache.du .*= -1
149148
cache.J_cache .= cache.J' .^ 2
150149
cache.Jdu .= _vec(du) .^ 2
151150
mul!(cache.Jᵀ²du, cache.J_cache, cache.Jdu)
@@ -186,29 +185,30 @@ function perform_step!(cache::GeneralKlementCache{false})
186185

187186
# u = u - J \ fu
188187
if linsolve === nothing
189-
cache.du = -fu / cache.J
188+
cache.du = fu / cache.J
190189
else
191190
linres = dolinsolve(alg.precs, linsolve; A = ifelse(fact_done, nothing, J),
192-
b = -_vec(fu), linu = _vec(cache.du), p, reltol = cache.abstol)
191+
b = _vec(fu), linu = _vec(cache.du), p, reltol = cache.abstol)
193192
cache.linsolve = linres.cache
194193
end
195194

196195
# Line Search
197196
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
198-
cache.u = @. cache.u + α * cache.du # `u` might not support mutation
197+
cache.u = @. cache.u - α * cache.du # `u` might not support mutation
199198
cache.fu2 = f(cache.u, p)
200199

201200
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
202201
(cache.force_stop = true)
203202

204-
cache.u_prev = @. cache.u
203+
cache.u_prev = cache.u
205204
cache.stats.nf += 1
206205
cache.stats.nsolve += 1
207206
cache.stats.nfactors += 1
208207

209208
cache.force_stop && return nothing
210209

211210
# Update the Jacobian
211+
cache.du = -cache.du
212212
cache.J_cache = cache.J' .^ 2
213213
cache.Jdu = _vec(cache.du) .^ 2
214214
cache.Jᵀ²du = cache.J_cache * cache.Jdu

src/lbroyden.jl

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ get_fu(cache::LimitedMemoryBroydenCache) = cache.fu
6868

6969
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LimitedMemoryBroyden,
7070
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
71-
termination_condition = nothing, internalnorm = DEFAULT_NORM,
72-
kwargs...) where {uType, iip}
71+
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
72+
kwargs...) where {uType, iip, F}
7373
@unpack f, u0, p = prob
7474
u = alias_u0 ? u0 : deepcopy(u0)
7575
if u isa Number
@@ -81,15 +81,13 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LimitedMemory
8181
fu = evaluate_f(prob, u)
8282
threshold = min(alg.threshold, maxiters)
8383
U, Vᵀ = __init_low_rank_jacobian(u, fu, threshold)
84-
du = -fu
84+
du = copy(fu)
8585
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(eltype(u))) :
8686
alg.reset_tolerance
8787
reset_check = x -> abs(x) ≤ reset_tolerance
8888

89-
abstol, reltol, termination_condition = _init_termination_elements(abstol,
90-
reltol,
91-
termination_condition,
92-
eltype(u))
89+
abstol, reltol, termination_condition = _init_termination_elements(abstol, reltol,
90+
termination_condition, eltype(u))
9391

9492
mode = DiffEqBase.get_termination_mode(termination_condition)
9593

@@ -112,7 +110,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{true})
112110
termination_condition = cache.termination_condition(tc_storage)
113111

114112
α = perform_linesearch!(cache.lscache, u, du)
115-
_axpy!(α, du, u)
113+
_axpy!(-α, du, u)
116114
f(cache.fu2, u, p)
117115

118116
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
@@ -134,7 +132,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{true})
134132
end
135133
cache.iterations_since_reset = 0
136134
cache.resets += 1
137-
cache.du .= -cache.fu
135+
cache.du .= cache.fu
138136
else
139137
idx = min(cache.iterations_since_reset, size(cache.U, 1))
140138
U_part = selectdim(cache.U, 1, 1:idx)
@@ -154,7 +152,6 @@ function perform_step!(cache::LimitedMemoryBroydenCache{true})
154152
U_part = selectdim(cache.U, 1, 1:idx)
155153
Vᵀ_part = selectdim(cache.Vᵀ, 2, 1:idx)
156154
__lbroyden_matvec!(_vec(cache.du), cache.Ux, U_part, Vᵀ_part, _vec(cache.fu2))
157-
cache.du .*= -1
158155
cache.iterations_since_reset += 1
159156
end
160157

@@ -172,7 +169,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{false})
172169
T = eltype(cache.u)
173170

174171
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
175-
cache.u = cache.u .+ α * cache.du
172+
cache.u = cache.u .- α * cache.du
176173
cache.fu2 = f(cache.u, p)
177174

178175
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
@@ -194,7 +191,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{false})
194191
end
195192
cache.iterations_since_reset = 0
196193
cache.resets += 1
197-
cache.du = -cache.fu
194+
cache.du = cache.fu
198195
else
199196
idx = min(cache.iterations_since_reset, size(cache.U, 1))
200197
U_part = selectdim(cache.U, 1, 1:idx)
@@ -215,7 +212,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{false})
215212
U_part = selectdim(cache.U, 1, 1:idx)
216213
Vᵀ_part = selectdim(cache.Vᵀ, 2, 1:idx)
217214
cache.du = _restructure(cache.du,
218-
-__lbroyden_matvec(U_part, Vᵀ_part, _vec(cache.fu2)))
215+
__lbroyden_matvec(U_part, Vᵀ_part, _vec(cache.fu2)))
219216
cache.iterations_since_reset += 1
220217
end
221218

src/levenberg.jl

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,8 @@ end
163163
function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
164164
NonlinearLeastSquaresProblem{uType, iip}}, alg_::LevenbergMarquardt,
165165
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
166-
termination_condition = nothing,
167-
internalnorm = DEFAULT_NORM,
168-
linsolve_kwargs = (;), kwargs...) where {uType, iip}
166+
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
167+
linsolve_kwargs = (;), kwargs...) where {uType, iip, F}
169168
alg = get_concrete_algorithm(alg_, prob)
170169
@unpack f, u0, p = prob
171170
u = alias_u0 ? u0 : deepcopy(u0)
@@ -231,10 +230,8 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
231230
end
232231

233232
return LevenbergMarquardtCache{iip, !_unwrap_val(linsolve_with_JᵀJ)}(f, alg, u, copy(u),
234-
fu1,
235-
fu2, du, p, uf, linsolve, J,
236-
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
237-
DᵀD,
233+
fu1, fu2, du, p, uf, linsolve, J, jac_cache, false, maxiters, internalnorm,
234+
ReturnCode.Default, abstol, reltol, prob, DᵀD,
238235
JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
239236
b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
240237
zero(u), zero(fu1), mat_tmp, rhs_tmp, J², NLStats(1, 0, 0, 0, 0),
@@ -321,11 +318,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastls}) where {fast
321318
if (1 - β)^b_uphill * loss ≤ loss_old
322319
# Accept step.
323320
cache.u .+= δ
324-
if termination_condition(cache.fu_tmp,
325-
cache.u,
326-
u_prev,
327-
cache.abstol,
328-
cache.reltol)
321+
if termination_condition(cache.fu_tmp, u, u_prev, cache.abstol, cache.reltol)
329322
cache.force_stop = true
330323
return nothing
331324
end

src/linesearch.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -228,21 +228,21 @@ function perform_linesearch!(cache::LiFukushimaLineSearchCache{iip}, u, du) wher
228228

229229
# Early Terminate based on Eq. 2.7
230230
if iip
231-
cache.u_cache .= u .+ du
231+
cache.u_cache .= u .- du
232232
cache.f(cache.fu_cache, cache.u_cache, cache.p)
233233
fxλ_norm = norm(cache.fu_cache, 2)
234234
else
235-
fxλ_norm = norm(cache.f(u .+ du, cache.p), 2)
235+
fxλ_norm = norm(cache.f(u .- du, cache.p), 2)
236236
end
237237

238238
fxλ_norm ≤ ρ * fx_norm - σ₂ * norm(du, 2)^2 && return cache.α
239239

240240
if iip
241-
cache.u_cache .= u .+ λ₂ .* du
241+
cache.u_cache .= u .- λ₂ .* du
242242
cache.f(cache.fu_cache, cache.u_cache, cache.p)
243243
fxλp_norm = norm(cache.fu_cache, 2)
244244
else
245-
fxλp_norm = norm(cache.f(u .+ λ₂ .* du, cache.p), 2)
245+
fxλp_norm = norm(cache.f(u .- λ₂ .* du, cache.p), 2)
246246
end
247247

248248
if !isfinite(fxλp_norm)
@@ -252,11 +252,11 @@ function perform_linesearch!(cache::LiFukushimaLineSearchCache{iip}, u, du) wher
252252
λ₁, λ₂ = λ₂, β * λ₂
253253

254254
if iip
255-
cache.u_cache .= u .+ λ₂ .* du
255+
cache.u_cache .= u .- λ₂ .* du
256256
cache.f(cache.fu_cache, cache.u_cache, cache.p)
257257
fxλp_norm = norm(cache.fu_cache, 2)
258258
else
259-
fxλp_norm = norm(cache.f(u .+ λ₂ .* du, cache.p), 2)
259+
fxλp_norm = norm(cache.f(u .- λ₂ .* du, cache.p), 2)
260260
end
261261

262262
nan_converged = isfinite(fxλp_norm)
@@ -269,11 +269,11 @@ function perform_linesearch!(cache::LiFukushimaLineSearchCache{iip}, u, du) wher
269269

270270
for _ in 1:maxiters
271271
if iip
272-
cache.u_cache .= u .+ λ₂ .* du
272+
cache.u_cache .= u .- λ₂ .* du
273273
cache.f(cache.fu_cache, cache.u_cache, cache.p)
274274
fxλp_norm = norm(cache.fu_cache, 2)
275275
else
276-
fxλp_norm = norm(cache.f(u .+ λ₂ .* du, cache.p), 2)
276+
fxλp_norm = norm(cache.f(u .- λ₂ .* du, cache.p), 2)
277277
end
278278

279279
converged = fxλp_norm ≤ (1 + η) * fx_norm - σ₁ * λ₂^2 * norm(du, 2)^2

src/pseudotransient.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::PseudoTransi
8686

8787
@unpack f, u0, p = prob
8888
u = alias_u0 ? u0 : deepcopy(u0)
89-
if iip
90-
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
91-
f(fu1, u, p)
92-
else
93-
fu1 = _mutable(f(u, p))
94-
end
89+
fu1 = evaluate_f(prob, u)
9590
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
9691
linsolve_kwargs)
9792
alpha = convert(eltype(u), alg.alpha_initial)

0 commit comments

Comments
 (0)