Skip to content

Commit 7e26d18

Browse files
committed
Add support for line search in Newton Raphson
1 parent de8086c commit 7e26d18

File tree

9 files changed

+241
-60
lines changed

9 files changed

+241
-60
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1111
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1212
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1313
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
14+
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
1415
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1516
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1617
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
@@ -33,6 +34,7 @@ Enzyme = "0.11"
3334
FiniteDiff = "2"
3435
ForwardDiff = "0.10.3"
3536
LinearSolve = "2"
37+
LineSearches = "7"
3638
PrecompileTools = "1"
3739
RecursiveArrayTools = "2"
3840
Reexport = "0.2, 1"

src/NonlinearSolve.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isi
2020
import StaticArraysCore: StaticArray, SVector, SArray, MArray
2121
import UnPack: @unpack
2222

23-
@reexport using ADTypes, SciMLBase, SimpleNonlinearSolve
23+
@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
2424

2525
const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
2626
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}
@@ -35,6 +35,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAl
3535
end
3636

3737
include("utils.jl")
38+
include("linesearch.jl")
3839
include("raphson.jl")
3940
include("trustRegion.jl")
4041
include("levenberg.jl")
@@ -69,4 +70,6 @@ export RadiusUpdateSchemes
6970

7071
export NewtonRaphson, TrustRegion, LevenbergMarquardt
7172

73+
export LineSearch
74+
7275
end # module

src/jacobian.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ end
99
(uf::JacobianWrapper{false})(res, u) = (vec(res) .= vec(uf.f(u, uf.p)))
1010
(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p)
1111

12-
sparsity_detection_alg(f, ad) = NoSparsityDetection()
12+
sparsity_detection_alg(_, _) = NoSparsityDetection()
1313
function sparsity_detection_alg(f, ad::AbstractSparseADType)
1414
if f.sparsity === nothing
1515
if f.jac_prototype === nothing
@@ -49,8 +49,8 @@ end
4949
jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u))
5050

5151
# Build Jacobian Caches
52-
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p,
53-
::Val{iip}) where {iip}
52+
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{iip};
53+
linsolve_kwargs=(;)) where {iip}
5454
uf = JacobianWrapper{iip}(f, p)
5555

5656
haslinsolve = hasfield(typeof(alg), :linsolve)
@@ -92,14 +92,15 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p,
9292

9393
Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
9494
nothing)..., weight)
95-
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)
95+
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
96+
linsolve_kwargs...)
9697

9798
return uf, linsolve, J, fu, jac_cache, du
9899
end
99100

100101
## Special Handling for Scalars
101102
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
102-
::Val{false})
103+
::Val{false}; kwargs...)
103104
# NOTE: Scalar `u` assumes scalar output from `f`
104105
uf = JacobianWrapper{false}(f, p)
105106
return uf, nothing, u, nothing, nothing, u

src/levenberg.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,16 +142,12 @@ isinplace(::LevenbergMarquardtCache{iip}) where {iip} = iip
142142

143143
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarquardt,
144144
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
145-
kwargs...) where {uType, iip}
145+
linsolve_kwargs=(;), kwargs...) where {uType, iip}
146146
@unpack f, u0, p = prob
147147
u = alias_u0 ? u0 : deepcopy(u0)
148-
if iip
149-
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
150-
f(fu1, u, p)
151-
else
152-
fu1 = f(u, p)
153-
end
154-
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip))
148+
fu1 = evaluate_f(prob, u)
149+
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
150+
linsolve_kwargs)
155151

156152
λ = convert(eltype(u), alg.damping_initial)
157153
λ_factor = convert(eltype(u), alg.damping_increase_factor)

src/linesearch.jl

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""
2+
LineSearch(method = Static(), autodiff = AutoFiniteDiff(), alpha = true)
3+
4+
Wrapper over algorithms from
5+
[LineSeaches.jl](https://github.com/JuliaNLSolvers/LineSearches.jl/). Allows automatic
6+
construction of the objective functions for the line search algorithms utilizing automatic
7+
differentiation for fast Vector Jacobian Products.
8+
9+
### Arguments
10+
11+
- `method`: the line search algorithm to use. Defaults to `Static()`, which means that the
12+
step size is fixed to the value of `alpha`.
13+
- `autodiff`: the automatic differentiation backend to use for the line search. Defaults to
14+
`AutoFiniteDiff()`, which means that finite differencing is used to compute the VJP.
15+
`AutoZygote()` will be faster in most cases, but it requires `Zygote.jl` to be manually
16+
installed and loaded
17+
- `alpha`: the initial step size to use. Defaults to `true` (which is equivalent to `1`).
18+
"""
19+
@concrete struct LineSearch
20+
method
21+
autodiff
22+
α
23+
end
24+
25+
function LineSearch(; method = Static(), autodiff = AutoFiniteDiff(), alpha = true)
26+
return LineSearch(method, autodiff, alpha)
27+
end
28+
29+
@concrete mutable struct LineSearchCache
30+
f
31+
ϕ
32+
33+
ϕdϕ
34+
α
35+
ls
36+
end
37+
38+
function LineSearchCache(ls::LineSearch, f, u::Number, p, _, ::Val{false})
39+
eval_f(u, du, α) = eval_f(u - α * du)
40+
eval_f(u) = f(u, p)
41+
42+
ls.method isa Static && return LineSearchCache(eval_f, nothing, nothing, nothing,
43+
convert(typeof(u), ls.α), ls)
44+
45+
g(u, fu) = last(value_derivative(Base.Fix2(f, p), u)) * fu
46+
47+
function ϕ(u, du)
48+
function ϕ_internal(α)
49+
u_ = u - α * du
50+
_fu = eval_f(u_)
51+
return dot(_fu, _fu) / 2
52+
end
53+
return ϕ_internal
54+
end
55+
56+
function (u, du)
57+
function dϕ_internal(α)
58+
u_ = u - α * du
59+
_fu = eval_f(u_)
60+
g₀ = g(u_, _fu)
61+
return dot(g₀, -du)
62+
end
63+
return dϕ_internal
64+
end
65+
66+
function ϕdϕ(u, du)
67+
function ϕdϕ_internal(α)
68+
u_ = u - α * du
69+
_fu = eval_f(u_)
70+
g₀ = g(u_, _fu)
71+
return dot(_fu, _fu) / 2, dot(g₀, -du)
72+
end
73+
return ϕdϕ_internal
74+
end
75+
76+
return LineSearchCache(eval_f, ϕ, dϕ, ϕdϕ, convert(eltype(u), ls.α), ls)
77+
end
78+
79+
function LineSearchCache(ls::LineSearch, f, u, p, fu1, IIP::Val{iip}) where {iip}
80+
fu = iip ? fu1 : nothing
81+
u_ = _mutable_zero(u)
82+
83+
function eval_f(u, du, α)
84+
@. u_ = u - α * du
85+
return eval_f(u_)
86+
end
87+
eval_f(u) = evaluate_f(f, u, p, IIP; fu)
88+
89+
ls.method isa Static && return LineSearchCache(eval_f, nothing, nothing, nothing,
90+
convert(eltype(u), ls.α), ls)
91+
92+
g₀ = _mutable_zero(u)
93+
94+
function g!(u, fu)
95+
op = VecJac((args...) -> f(args..., p), u)
96+
if iip
97+
mul!(g₀, op, fu)
98+
return g₀
99+
else
100+
return op * fu
101+
end
102+
end
103+
104+
function ϕ(u, du)
105+
function ϕ_internal(α)
106+
@. u_ = u - α * du
107+
_fu = eval_f(u_)
108+
return dot(_fu, _fu) / 2
109+
end
110+
return ϕ_internal
111+
end
112+
113+
function (u, du)
114+
function dϕ_internal(α)
115+
@. u_ = u - α * du
116+
_fu = eval_f(u_)
117+
g₀ = g!(u_, _fu)
118+
return dot(g₀, -du)
119+
end
120+
return dϕ_internal
121+
end
122+
123+
function ϕdϕ(u, du)
124+
function ϕdϕ_internal(α)
125+
@. u_ = u - α * du
126+
_fu = eval_f(u_)
127+
g₀ = g!(u_, _fu)
128+
return dot(_fu, _fu) / 2, dot(g₀, -du)
129+
end
130+
return ϕdϕ_internal
131+
end
132+
133+
return LineSearchCache(eval_f, ϕ, dϕ, ϕdϕ, convert(eltype(u), ls.α), ls)
134+
end
135+
136+
function perform_linesearch!(cache::LineSearchCache, u, du)
137+
cache.ls.method isa Static && return (cache.α, cache.f(u, du, cache.α))
138+
139+
ϕ = cache.ϕ(u, du)
140+
= cache.(u, du)
141+
ϕdϕ = cache.ϕdϕ(u, du)
142+
143+
ϕ₀, dϕ₀ = ϕdϕ(zero(eltype(u)))
144+
145+
return cache.ls.method(ϕ, cache.(u, du), cache.ϕdϕ(u, du), cache.α, ϕ₀, dϕ₀)
146+
end

src/raphson.jl

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,24 @@ for large-scale and numerically-difficult nonlinear systems.
2525
preconditioners. For more information on specifying preconditioners for LinearSolve
2626
algorithms, consult the
2727
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
28+
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
29+
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
30+
used here directly, and they will be converted to the correct `LineSearch`.
2831
"""
2932
@concrete struct NewtonRaphson{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
3033
ad::AD
3134
linsolve
3235
precs
36+
linesearch
3337
end
3438

3539
concrete_jac(::NewtonRaphson{CJ}) where {CJ} = CJ
3640

3741
function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
38-
precs = DEFAULT_PRECS, adkwargs...)
42+
linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
3943
ad = default_adargs_to_adtype(; adkwargs...)
40-
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
44+
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method=linesearch)
45+
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
4146
end
4247

4348
@concrete mutable struct NewtonRaphsonCache{iip}
@@ -59,26 +64,23 @@ end
5964
abstol
6065
prob
6166
stats::NLStats
67+
lscache
6268
end
6369

6470
isinplace(::NewtonRaphsonCache{iip}) where {iip} = iip
6571

6672
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson, args...;
6773
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
68-
kwargs...) where {uType, iip}
74+
linsolve_kwargs=(;), kwargs...) where {uType, iip}
6975
@unpack f, u0, p = prob
7076
u = alias_u0 ? u0 : deepcopy(u0)
71-
if iip
72-
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
73-
f(fu1, u, p)
74-
else
75-
fu1 = _mutable(f(u, p))
76-
end
77-
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip))
77+
fu1 = evaluate_f(prob, u)
78+
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
79+
linsolve_kwargs)
7880

7981
return NewtonRaphsonCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
8082
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob,
81-
NLStats(1, 0, 0, 0, 0))
83+
NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip)))
8284
end
8385

8486
function perform_step!(cache::NewtonRaphsonCache{true})
@@ -89,8 +91,10 @@ function perform_step!(cache::NewtonRaphsonCache{true})
8991
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du),
9092
p, reltol = cache.abstol)
9193
cache.linsolve = linres.cache
92-
@. u = u - du
93-
f(fu1, u, p)
94+
95+
# Line Search
96+
α, _ = perform_linesearch!(cache.lscache, u, du)
97+
@. u = u - α * du
9498

9599
cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true)
96100
cache.stats.nf += 1
@@ -112,7 +116,10 @@ function perform_step!(cache::NewtonRaphsonCache{false})
112116
linu = _vec(cache.du), p, reltol = cache.abstol)
113117
cache.linsolve = linres.cache
114118
end
115-
cache.u = @. u - cache.du # `u` might not support mutation
119+
120+
# Line Search
121+
α, _fu = perform_linesearch!(cache.lscache, u, cache.du)
122+
cache.u = @. u - α * cache.du # `u` might not support mutation
116123
cache.fu1 = f(cache.u, p)
117124

118125
cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true)

src/trustRegion.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -202,20 +202,15 @@ end
202202

203203
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion, args...;
204204
alias_u0 = false, maxiters = 1000, abstol = 1e-8, internalnorm = DEFAULT_NORM,
205-
kwargs...) where {uType, iip}
205+
linsolve_kwargs=(;), kwargs...) where {uType, iip}
206206
@unpack f, u0, p = prob
207207
u = alias_u0 ? u0 : deepcopy(u0)
208208
u_prev = zero(u)
209-
if iip
210-
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
211-
f(fu1, u, p)
212-
else
213-
fu1 = f(u, p)
214-
end
209+
fu1 = evaluate_f(prob, u)
215210
fu_prev = zero(fu1)
216211

217212
loss = get_loss(fu1)
218-
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip))
213+
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); linsolve_kwargs)
219214

220215
radius_update_scheme = alg.radius_update_scheme
221216
max_trust_radius = convert(eltype(u), alg.max_trust_radius)

src/utils.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,26 @@ _maybe_mutable(x, ::AbstractFiniteDifferencesMode) = _mutable(x)
142142
# The shadow allocated for Enzyme needs to be mutable
143143
_maybe_mutable(x, ::AutoSparseEnzyme) = _mutable(x)
144144
_maybe_mutable(x, _) = x
145+
146+
# Helper function to get value of `f(u, p)`
147+
function evaluate_f(prob::NonlinearProblem{uType, iip}, u) where {uType, iip}
148+
@unpack f, u0, p = prob
149+
if iip
150+
fu = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
151+
f(fu, u, p)
152+
else
153+
fu = _mutable(f(u, p))
154+
end
155+
return fu
156+
end
157+
158+
evaluate_f(cache, u; fu = nothing) = evaluate_f(cache.f, u, cache.p, Val(cache.iip); fu)
159+
160+
function evaluate_f(f, u, p, ::Val{iip}; fu = nothing) where {iip}
161+
if iip
162+
f(fu, u, p)
163+
return fu
164+
else
165+
return f(u, p)
166+
end
167+
end

0 commit comments

Comments
 (0)