Skip to content

Commit 59a8713

Browse files
committed
Krylov Version for Trust Region
1 parent 0026bc1 commit 59a8713

File tree

4 files changed

+106
-28
lines changed

4 files changed

+106
-28
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "2.8.2"
4+
version = "2.9.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/gaussnewton.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,6 @@ An advanced GaussNewton implementation with support for efficient handling of sp
66
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
77
for large-scale and numerically-difficult nonlinear least squares problems.
88
9-
!!! note
10-
In most practical situations, users should prefer using `LevenbergMarquardt` instead! It
11-
is a more general extension of `Gauss-Newton` Method.
12-
139
### Keyword Arguments
1410
1511
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
@@ -33,11 +29,6 @@ for large-scale and numerically-difficult nonlinear least squares problems.
3329
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
3430
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
3531
used here directly, and they will be converted to the correct `LineSearch`.
36-
37-
!!! warning
38-
39-
Jacobian-Free version of `GaussNewton` doesn't work yet, and it forces jacobian
40-
construction. This will be fixed in the near future.
4132
"""
4233
@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
4334
ad::AD

src/jacobian.jl

Lines changed: 88 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
5454
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
5555
fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) :
5656
(iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
57-
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
57+
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac) # || needsJᵀJ)
5858
sd = sparsity_detection_alg(f, alg.ad)
5959
ad = alg.ad
6060
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) :
@@ -92,9 +92,9 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
9292
du = _mutable_zero(u)
9393

9494
if needsJᵀJ
95-
JᵀJ = __init_JᵀJ(J)
96-
# FIXME: This needs to be handled better for JacVec Operator
97-
Jᵀfu = J' * _vec(fu)
95+
# TODO: Pass in `jac_transpose_autodiff`
96+
JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u;
97+
jac_autodiff = __get_nonsparse_ad(alg.ad))
9898
end
9999

100100
if linsolve_init
@@ -120,21 +120,68 @@ function __setup_linsolve(A, b, u, p, alg)
120120
nothing)..., weight)
121121
return init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)
122122
end
123+
# Linear-solver setup for the Krylov wrapper: unwrap the stored `JᵀJ` field and forward to
# the `__setup_linsolve` method for that object with all other arguments unchanged.
function __setup_linsolve(A::KrylovJᵀJ, b, u, p, alg)
    return __setup_linsolve(A.JᵀJ, b, u, p, alg)
end
123124

124125
# Map a sparse AD backend selector to its dense counterpart. Any backend without a
# dedicated method below is returned as-is.
__get_nonsparse_ad(ad) = ad
__get_nonsparse_ad(::AutoSparseZygote) = AutoZygote()
__get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff()
__get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff()
128129

129-
__init_JᵀJ(J::Number) = zero(J)
130-
__init_JᵀJ(J::AbstractArray) = J' * J
131-
__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
130+
# Scalar case: both the `JᵀJ` and the `Jᵀfu` placeholders are a zero of `J`'s type. Extra
# positional and keyword arguments are accepted and ignored.
function __init_JᵀJ(J::Number, args...; kwargs...)
    return (zero(J), zero(J))
end
131+
# Dense/generic array case: materialise `JᵀJ` and `Jᵀfu` from the Jacobian `J` and the
# residual `fu`. Extra positional and keyword arguments are accepted and ignored.
function __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...)
    Jt = J'
    return Jt * J, Jt * fu
end
136+
# Static array case: allocate an uninitialised mutable `n × n` buffer for `JᵀJ`, where `n`
# is the number of columns of `J`, and compute `Jᵀfu` directly.
function __init_JᵀJ(J::StaticArray, fu, args...; kwargs...)
    ncols = size(J, 2)
    H = MArray{Tuple{ncols, ncols}, eltype(J)}(undef)
    return H, J' * fu
end
140+
# Operator (Jacobian-free) case: build a `VecJac` operator for `Jᵀ`, compose it with `J`
# through `SciMLOperators.cache_operator`, and wrap both in a `KrylovJᵀJ`.
#
# The backend for `Jᵀ` is chosen by `__concrete_jac_transpose_autodiff` from the two
# keyword arguments and `uf`. Returns the `KrylovJᵀJ` wrapper together with `Jᵀ * fu`.
function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...;
        jac_transpose_autodiff = nothing, jac_autodiff = nothing, kwargs...)
    vjp_backend = __concrete_jac_transpose_autodiff(jac_transpose_autodiff,
        jac_autodiff, uf)
    Jᵀ = VecJac(uf, u; autodiff = vjp_backend)
    normal_op = SciMLOperators.cache_operator(Jᵀ * J, u)
    return KrylovJᵀJ(normal_op, Jᵀ), Jᵀ * fu
end
149+
150+
# Container pairing a `JᵀJ` object with the `Jᵀ` object it was built from. In this file it
# is constructed by the `FunctionOperator` method of `__init_JᵀJ`, which passes the cached
# `Jᵀ * J` operator and the `VecJac` operator respectively.
@concrete struct KrylovJᵀJ
151+
# The composed `Jᵀ * J` object (what `__setup_linsolve` and `__lr_mul` read).
JᵀJ
152+
# The transposed-Jacobian object (what `__update_Jᵀf!` and `isinplace` read).
Jᵀ
153+
end
154+
155+
# A `KrylovJᵀJ` reports the same in-place flag as its stored `Jᵀ` field.
function SciMLBase.isinplace(JᵀJ::KrylovJᵀJ)
    return isinplace(JᵀJ.Jᵀ)
end
156+
157+
# Pick the AD backend used to build the `Jᵀ` (VecJac) operator.
#
# - An explicit `jac_transpose_autodiff` always wins (after stripping sparsity).
# - Otherwise in-place problems get `AutoFiniteDiff()`.
# - Otherwise a `jac_autodiff` that is already `AutoFiniteDiff` is reused.
# - Otherwise `AutoZygote()` is used when Zygote is loaded, else `AutoFiniteDiff()`.
function __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, uf)
    jac_transpose_autodiff === nothing ||
        return __get_nonsparse_ad(jac_transpose_autodiff)

    # VecJac can only use FiniteDiff for in-place functions
    isinplace(uf) && return AutoFiniteDiff()

    # Short circuit if FiniteDiff was already used for the J computation
    jac_autodiff isa AutoFiniteDiff && return jac_autodiff

    # Prefer Zygote when it is already loaded in the session, else fall back to FiniteDiff
    zygote_id = Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote")
    return haskey(Base.loaded_modules, zygote_id) ? AutoZygote() : AutoFiniteDiff()
end
132177

133178
# Wrap `x` in `Symmetric` by default; the listed types are returned unwrapped.
__maybe_symmetric(x) = Symmetric(x)
__maybe_symmetric(x::Number) = x
__maybe_symmetric(x::KrylovJᵀJ) = x
__maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x
# LinearSolve with `nothing` doesn't dispatch correctly here
__maybe_symmetric(x::StaticArray) = x
138185

139186
## Special Handling for Scalars
140187
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p,
@@ -145,3 +192,37 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number,
145192
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
146193
return uf, nothing, u, nothing, nothing, u
147194
end
195+
196+
# Entry point: read the current value of `cache.<sym>` and dispatch on its type so that
# specialised containers (e.g. `KrylovJᵀJ`) can take a different update path.
function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J)
    current = getproperty(cache, sym)
    return __update_JᵀJ!(iip, cache, sym, current, J)
end
199+
# Out-of-place: rebuild `JᵀJ` and store it on the cache under `sym`.
function __update_JᵀJ!(::Val{false}, cache, sym::Symbol, _, J)
    return setproperty!(cache, sym, J' * J)
end

# In-place: overwrite the buffer already stored under `sym`.
function __update_JᵀJ!(::Val{true}, cache, sym::Symbol, _, J)
    return mul!(getproperty(cache, sym), J', J)
end

# `KrylovJᵀJ`: returned unchanged; no product is formed here. The two methods are kept
# separate (instead of a single `::Val` method) so they stay more specific than the
# generic methods above and no dispatch ambiguity arises.
function __update_JᵀJ!(::Val{false}, cache, sym::Symbol, H::KrylovJᵀJ, J)
    return H
end
function __update_JᵀJ!(::Val{true}, cache, sym::Symbol, H::KrylovJᵀJ, J)
    return H
end
203+
204+
# Entry point: `sym1` names the cache field that receives `Jᵀfu`; `sym2` names the field
# whose current value is read and passed on so the update can dispatch on its type.
function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu)
    hessian_like = getproperty(cache, sym2)
    return __update_Jᵀf!(iip, cache, sym1, sym2, hessian_like, J, fu)
end
207+
# Generic (materialised Jacobian) update of `Jᵀfu`.
#
# `sym1` names the cache field that holds `Jᵀfu`. The fifth positional argument (the value
# stored under `sym2`) is only used for dispatch and is ignored here.
#
# The residual and the destination buffer are flattened with `_vec` before multiplying.
# This matches the cache initialisation, which passes `_vec(fu)` to `__init_JᵀJ`, so `fu`
# may be a non-vector array without triggering a dimension mismatch.

# Out-of-place: compute a fresh `Jᵀ * vec(fu)` and store it under `sym1`.
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
    return setproperty!(cache, sym1, J' * _vec(fu))
end

# In-place: write `Jᵀ * vec(fu)` into the (flattened view of the) existing buffer.
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
    return mul!(_vec(getproperty(cache, sym1)), J', _vec(fu))
end
213+
# `KrylovJᵀJ` update of `Jᵀfu`: apply the stored `Jᵀ` field to `fu` instead of using the
# `J` argument.

# Out-of-place: store the freshly computed product under `sym1`.
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol,
        H::KrylovJᵀJ, J, fu)
    Jᵀfu = H.Jᵀ * fu
    return setproperty!(cache, sym1, Jᵀfu)
end

# In-place: overwrite the buffer already stored under `sym1`.
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol,
        H::KrylovJᵀJ, J, fu)
    dest = getproperty(cache, sym1)
    return mul!(dest, H.Jᵀ, fu)
end
219+
220+
# Left-Right Multiplication
221+
# Quadratic form `gᵀ H g` via the three-argument `dot`. The `Val` argument is only used
# for dispatch.
function __lr_mul(::Val, H, g)
    return dot(g, H, g)
end
222+
## TODO: Use a cache here to avoid allocations
223+
# Quadratic form `gᵀ (JᵀJ) g` for a `KrylovJᵀJ`, using its stored `JᵀJ` field.

# Out-of-place: three-argument `dot` on the stored `JᵀJ`.
function __lr_mul(::Val{false}, H::KrylovJᵀJ, g)
    return dot(g, H.JᵀJ, g)
end

# In-place: apply `JᵀJ` into a temporary, then take the inner product with `g`.
# The temporary is allocated on every call.
function __lr_mul(::Val{true}, H::KrylovJᵀJ, g)
    Hg = similar(g)
    mul!(Hg, H.JᵀJ, g)
    return dot(g, Hg)
end

src/trustRegion.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -239,15 +239,19 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
239239
fu_prev = zero(fu1)
240240

241241
loss = get_loss(fu1)
242-
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
243-
linsolve_kwargs)
242+
# uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
243+
# linsolve_kwargs)
244+
uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip);
245+
linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false))
246+
linsolve = u isa Number ? nothing : __setup_linsolve(J, fu2, u, p, alg)
247+
244248
u_tmp = zero(u)
245249
u_cauchy = zero(u)
246250
u_gauss_newton = _mutable_zero(u)
247251

248252
loss_new = loss
249-
H = zero(J' * J)
250-
g = _mutable_zero(fu1)
253+
# H = zero(J' * J)
254+
# g = _mutable_zero(fu1)
251255
shrink_counter = 0
252256
fu_new = zero(fu1)
253257
make_new_J = true
@@ -346,8 +350,10 @@ function perform_step!(cache::TrustRegionCache{true})
346350
@unpack make_new_J, J, fu, f, u, p, u_gauss_newton, alg, linsolve = cache
347351
if cache.make_new_J
348352
jacobian!!(J, cache)
349-
mul!(cache.H, J', J)
350-
mul!(_vec(cache.g), J', _vec(fu))
353+
__update_JᵀJ!(Val{true}(), cache, :H, J)
354+
# mul!(cache.H, J', J)
355+
__update_Jᵀf!(Val{true}(), cache, :g, :H, J, fu)
356+
# mul!(_vec(cache.g), J', _vec(fu))
351357
cache.stats.njacs += 1
352358

353359
# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
@@ -376,8 +382,8 @@ function perform_step!(cache::TrustRegionCache{false})
376382

377383
if make_new_J
378384
J = jacobian!!(cache.J, cache)
379-
cache.H = J' * J
380-
cache.g = _restructure(fu, J' * _vec(fu))
385+
__update_JᵀJ!(Val{false}(), cache, :H, J)
386+
__update_Jᵀf!(Val{false}(), cache, :g, :H, J, fu)
381387
cache.stats.njacs += 1
382388

383389
if cache.linsolve === nothing
@@ -431,7 +437,7 @@ function trust_region_step!(cache::TrustRegionCache)
431437

432438
# Compute the ratio of the actual reduction to the predicted reduction.
433439
cache.r = -(loss - cache.loss_new) /
434-
(dot(_vec(du), _vec(g)) + dot(_vec(du), H, _vec(du)) / 2)
440+
(dot(_vec(du), _vec(g)) + __lr_mul(Val(isinplace(cache)), H, _vec(du)) / 2)
435441
@unpack r = cache
436442

437443
if radius_update_scheme === RadiusUpdateSchemes.Simple
@@ -594,7 +600,7 @@ function dogleg!(cache::TrustRegionCache{true})
594600

595601
# Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region
596602
l_grad = norm(cache.g) # length of the gradient
597-
d_cauchy = l_grad^3 / dot(_vec(cache.g), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate
603+
d_cauchy = l_grad^3 / __lr_mul(Val{true}(), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate
598604
if d_cauchy >= trust_r
599605
@. cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region
600606
return
@@ -624,7 +630,7 @@ function dogleg!(cache::TrustRegionCache{false})
624630

625631
## Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region
626632
l_grad = norm(cache.g)
627-
d_cauchy = l_grad^3 / dot(_vec(cache.g), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate
633+
d_cauchy = l_grad^3 / __lr_mul(Val{false}(), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate
628634
if d_cauchy > trust_r # cauchy point lies outside of trust region
629635
cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region
630636
return

0 commit comments

Comments
 (0)