Skip to content

Commit 2a1d570

Browse files
authored
Merge pull request #311 from avik-pal/ap/tr_nlls
2 parents b0c28f1 + 26082c7 commit 2a1d570

File tree

6 files changed

+61
-31
lines changed

6 files changed

+61
-31
lines changed

docs/src/api/nonlinearsolve.md

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,38 @@
22

33
These are the native solvers of NonlinearSolve.jl.
44

5-
## Core Nonlinear Solvers
5+
## Nonlinear Solvers
66

77
```@docs
88
NewtonRaphson
9-
TrustRegion
109
PseudoTransient
1110
DFSane
1211
Broyden
1312
Klement
1413
```
1514

16-
## Polyalgorithms
15+
## Nonlinear Least Squares Solvers
1716

1817
```@docs
19-
NonlinearSolvePolyAlgorithm
20-
FastShortcutNonlinearPolyalg
21-
FastShortcutNLLSPolyalg
22-
RobustMultiNewton
18+
GaussNewton
2319
```
2420

25-
## Nonlinear Least Squares Solvers
21+
## Both Nonlinear & Nonlinear Least Squares Solvers
22+
23+
These solvers can be used for both nonlinear and nonlinear least squares problems.
2624

2725
```@docs
26+
TrustRegion
2827
LevenbergMarquardt
29-
GaussNewton
28+
```
29+
30+
## Polyalgorithms
31+
32+
```@docs
33+
NonlinearSolvePolyAlgorithm
34+
FastShortcutNonlinearPolyalg
35+
FastShortcutNLLSPolyalg
36+
RobustMultiNewton
3037
```
3138

3239
## Radius Update Schemes for Trust Region (RadiusUpdateSchemes)

docs/src/solvers/NonlinearLeastSquaresSolvers.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ falls back to a more robust algorithm (`LevenbergMarquardt`).
2323
handling of sparse matrices via colored automatic differentiation and preconditioned
2424
linear solvers. Designed for large-scale and numerically-difficult nonlinear least
2525
squares problems.
26+
- `TrustRegion()`: A Newton Trust Region dogleg method with swappable linear solvers and
27+
autodiff methods for high performance on large and sparse systems.
2628

2729
### SimpleNonlinearSolve.jl
2830

src/jacobian.jl

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -213,30 +213,43 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
213213
end
214214

215215
# vjp (gradient) operator fallback: `GradientScalar` for scalar `u`, `VecJac` otherwise
216-
function __jacvec(uf, u; autodiff, kwargs...)
217-
if !(autodiff isa AutoForwardDiff || autodiff isa AutoFiniteDiff)
216+
function __gradient_operator(uf, u; autodiff, kwargs...)
217+
if !(autodiff isa AutoFiniteDiff || autodiff isa AutoZygote)
218218
_ad = autodiff
219-
autodiff = ifelse(ForwardDiff.can_dual(eltype(u)), AutoForwardDiff(),
219+
number_ad = ifelse(ForwardDiff.can_dual(eltype(u)), AutoForwardDiff(),
220220
AutoFiniteDiff())
221-
@warn "$(_ad) not supported for JacVec. Using $(autodiff) instead."
221+
if u isa Number
222+
autodiff = number_ad
223+
else
224+
if isinplace(uf)
225+
autodiff = AutoFiniteDiff()
226+
else
227+
autodiff = ifelse(is_extension_loaded(Val{:Zygote}()), AutoZygote(),
228+
AutoFiniteDiff())
229+
end
230+
end
231+
if _ad !== nothing && _ad !== autodiff
232+
@warn "$(_ad) not supported for VecJac. Using $(autodiff) instead."
233+
end
222234
end
223-
return u isa Number ? JVPScalar(uf, u, autodiff) : JacVec(uf, u; autodiff, kwargs...)
235+
return u isa Number ? GradientScalar(uf, u, autodiff) :
236+
VecJac(uf, u; autodiff, kwargs...)
224237
end
225238

226-
@concrete mutable struct JVPScalar
239+
@concrete mutable struct GradientScalar
227240
uf
228241
u
229242
autodiff
230243
end
231244

232-
function Base.:*(jvp::JVPScalar, v::Number)
245+
function Base.:*(jvp::GradientScalar, v::Number)
233246
if jvp.autodiff isa AutoForwardDiff
234247
T = typeof(ForwardDiff.Tag(typeof(jvp.uf), typeof(jvp.u)))
235-
out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, v))
248+
out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, one(v)))
236249
return ForwardDiff.extract_derivative(T, out)
237250
elseif jvp.autodiff isa AutoFiniteDiff
238251
J = FiniteDiff.finite_difference_derivative(jvp.uf, jvp.u, jvp.autodiff.fdtype)
239-
return J * v
252+
return J
240253
else
241254
error("Only ForwardDiff & FiniteDiff is currently supported.")
242255
end

src/trace.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ end
6060
6161
## Arguments
6262
63-
- `freq`: Sets both `print_frequency` and `store_frequency` to `freq`.
63+
- `freq`: Sets both `print_frequency` and `store_frequency` to `freq`.
6464
6565
## Keyword Arguments
6666
67-
- `print_frequency`: Print the trace every `print_frequency` iterations if
67+
- `print_frequency`: Print the trace every `print_frequency` iterations if
6868
`show_trace == Val(true)`.
69-
- `store_frequency`: Store the trace every `store_frequency` iterations if
69+
- `store_frequency`: Store the trace every `store_frequency` iterations if
7070
`store_trace == Val(true)`.
7171
"""
7272
@kwdef struct TraceAll <: AbstractNonlinearSolveTraceLevel

src/trustRegion.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -247,13 +247,14 @@ end
247247
p3
248248
p4
249249
ϵ
250-
jvp_operator # For Yuan
250+
vjp_operator # For Yuan
251251
stats::NLStats
252252
tc_cache
253253
trace
254254
end
255255

256-
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, args...;
256+
function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
257+
NonlinearLeastSquaresProblem{uType, iip}}, alg_::TrustRegion, args...;
257258
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
258259
termination_condition = nothing, internalnorm = DEFAULT_NORM,
259260
linsolve_kwargs = (;), kwargs...) where {uType, iip}
@@ -317,7 +318,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
317318
p3 = convert(floatType, 0.0)
318319
p4 = convert(floatType, 0.0)
319320
ϵ = convert(floatType, 1.0e-8)
320-
jvp_operator = nothing
321+
vjp_operator = nothing
321322
if radius_update_scheme === RadiusUpdateSchemes.NLsolve
322323
p1 = convert(floatType, 0.5)
323324
elseif radius_update_scheme === RadiusUpdateSchemes.Hei
@@ -336,8 +337,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
336337
p1 = convert(floatType, 2.0) # μ
337338
p2 = convert(floatType, 1 / 6) # c5
338339
p3 = convert(floatType, 6.0) # c6
339-
jvp_operator = __jacvec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad))
340-
@bb Jᵀf = jvp_operator × fu
340+
vjp_operator = __gradient_operator(uf, u; fu,
341+
autodiff = __get_nonsparse_ad(alg.vjp_autodiff))
342+
@bb Jᵀf = vjp_operator × fu
341343
initial_trust_radius = convert(trustType, p1 * internalnorm(Jᵀf))
342344
elseif radius_update_scheme === RadiusUpdateSchemes.Fan
343345
step_threshold = convert(trustType, 0.0001)
@@ -366,7 +368,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
366368
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
367369
radius_update_scheme, initial_trust_radius, max_trust_radius, step_threshold,
368370
shrink_threshold, expand_threshold, shrink_factor, expand_factor, loss, loss_new,
369-
shrink_counter, make_new_J, r, p1, p2, p3, p4, ϵ, jvp_operator,
371+
shrink_counter, make_new_J, r, p1, p2, p3, p4, ϵ, vjp_operator,
370372
NLStats(1, 0, 0, 0, 0), tc_cache, trace)
371373
end
372374

@@ -479,7 +481,7 @@ function trust_region_step!(cache::TrustRegionCache)
479481
cache.shrink_counter = 0
480482
end
481483

482-
@bb cache.Jᵀf = cache.jvp_operator × vec(cache.fu)
484+
@bb cache.Jᵀf = cache.vjp_operator × vec(cache.fu)
483485
cache.trust_r = cache.p1 * cache.internalnorm(cache.Jᵀf)
484486

485487
cache.internalnorm(cache.Jᵀf) < cache.ϵ && (cache.force_stop = true)
@@ -567,10 +569,10 @@ end
567569

568570
# FIXME: Reinit `JᵀJ` operator if `p` is changed
569571
function __reinit_internal!(cache::TrustRegionCache; kwargs...)
570-
if cache.jvp_operator !== nothing
571-
cache.jvp_operator = __jacvec(cache.uf, cache.u; cache.fu,
572+
if cache.vjp_operator !== nothing
573+
cache.vjp_operator = __gradient_operator(cache.uf, cache.u; cache.fu,
572574
autodiff = __get_nonsparse_ad(cache.alg.ad))
573-
@bb cache.Jᵀf = cache.jvp_operator × cache.fu
575+
@bb cache.Jᵀf = cache.vjp_operator × cache.fu
574576
end
575577
cache.loss = __trust_region_loss(cache, cache.fu)
576578
cache.loss_new = cache.loss

test/nonlinear_least_squares.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
2828
resid_prototype = zero(y_target)), θ_init, x)
2929

3030
nlls_problems = [prob_oop, prob_iip]
31+
3132
solvers = []
3233
for linsolve in [nothing, LUFactorization(), KrylovJL_GMRES()]
3334
vjp_autodiffs = linsolve isa KrylovJL ? [nothing, AutoZygote(), AutoFiniteDiff()] :
@@ -46,6 +47,11 @@ append!(solvers,
4647
LeastSquaresOptimJL(:dogleg),
4748
nothing,
4849
])
50+
for radius_update_scheme in [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.NocedalWright,
51+
RadiusUpdateSchemes.NLsolve, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan,
52+
RadiusUpdateSchemes.Fan, RadiusUpdateSchemes.Bastin]
53+
push!(solvers, TrustRegion(; radius_update_scheme))
54+
end
4955

5056
for prob in nlls_problems, solver in solvers
5157
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)

0 commit comments

Comments
 (0)