Skip to content

Commit 4f1676c

Browse files
Merge pull request #268 from avik-pal/ap/gn_linesearch
Gauss Newton with Line Search
2 parents b8d43a3 + 77be3a2 commit 4f1676c

File tree

6 files changed

+30
-19
lines changed

6 files changed

+30
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ Reexport = "0.2, 1"
5858
SciMLBase = "2.4"
5959
SimpleNonlinearSolve = "0.1.23"
6060
SparseArrays = "1.9"
61-
SparseDiffTools = "2.9"
61+
SparseDiffTools = "2.11"
6262
StaticArraysCore = "1.4"
6363
UnPack = "1.0"
6464
Zygote = "0.6"

src/gaussnewton.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
GaussNewton(; concrete_jac = nothing, linsolve = nothing,
2+
GaussNewton(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
33
precs = DEFAULT_PRECS, adkwargs...)
44
55
An advanced GaussNewton implementation with support for efficient handling of sparse
@@ -30,6 +30,9 @@ for large-scale and numerically-difficult nonlinear least squares problems.
3030
preconditioners. For more information on specifying preconditioners for LinearSolve
3131
algorithms, consult the
3232
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
33+
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
34+
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
35+
used here directly, and they will be converted to the correct `LineSearch`.
3336
3437
!!! warning
3538
@@ -40,16 +43,18 @@ for large-scale and numerically-difficult nonlinear least squares problems.
4043
ad::AD
4144
linsolve
4245
precs
46+
linesearch
4347
end
4448

4549
function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
46-
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs)
50+
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch)
4751
end
4852

4953
function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
50-
precs = DEFAULT_PRECS, adkwargs...)
54+
linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
5155
ad = default_adargs_to_adtype(; adkwargs...)
52-
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
56+
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
57+
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
5358
end
5459

5560
@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
@@ -78,6 +83,7 @@ end
7883
stats::NLStats
7984
tc_cache_1
8085
tc_cache_2
86+
ls_cache
8187
end
8288

8389
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
@@ -107,7 +113,8 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
107113

108114
return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
109115
linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default,
110-
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2)
116+
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2,
117+
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)))
111118
end
112119

113120
function perform_step!(cache::GaussNewtonCache{true})
@@ -128,7 +135,8 @@ function perform_step!(cache::GaussNewtonCache{true})
128135
linu = _vec(du), p, reltol = cache.abstol)
129136
end
130137
cache.linsolve = linres.cache
131-
@. u = u - du
138+
α = perform_linesearch!(cache.ls_cache, u, du)
139+
_axpy!(-α, du, u)
132140
f(cache.fu_new, u, p)
133141

134142
check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
@@ -169,7 +177,8 @@ function perform_step!(cache::GaussNewtonCache{false})
169177
end
170178
cache.linsolve = linres.cache
171179
end
172-
cache.u = @. u - cache.du # `u` might not support mutation
180+
α = perform_linesearch!(cache.ls_cache, u, cache.du)
181+
cache.u = @. u - α * cache.du # `u` might not support mutation
173182
cache.fu_new = f(cache.u, p)
174183

175184
check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)

src/linesearch.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe
122122
end
123123

124124
function g!(u, fu)
125-
op = VecJac((args...) -> f(args..., p), u; autodiff)
125+
op = VecJac(f, u, p; fu = fu1, autodiff)
126126
if iip
127127
mul!(g₀, op, fu)
128128
return g₀

src/raphson.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
2+
NewtonRaphson(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
33
precs = DEFAULT_PRECS, adkwargs...)
44
55
An advanced NewtonRaphson implementation with support for efficient handling of sparse

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ function __get_concrete_algorithm(alg, prob)
198198
use_sparse_ad ? AutoSparseFiniteDiff() : AutoFiniteDiff()
199199
else
200200
(use_sparse_ad ? AutoSparseForwardDiff : AutoForwardDiff)(;
201-
tag = NonlinearSolveTag())
201+
tag = ForwardDiff.Tag(NonlinearSolveTag(), eltype(prob.u0)))
202202
end
203203
return set_ad(alg, ad)
204204
end

test/nonlinear_least_squares.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,16 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
2727
resid_prototype = zero(y_target)), θ_init, x)
2828

2929
nlls_problems = [prob_oop, prob_iip]
30-
solvers = [
31-
GaussNewton(),
32-
GaussNewton(; linsolve = LUFactorization()),
33-
LevenbergMarquardt(),
34-
LevenbergMarquardt(; linsolve = LUFactorization()),
35-
LeastSquaresOptimJL(:lm),
36-
LeastSquaresOptimJL(:dogleg),
37-
]
30+
solvers = vec(Any[GaussNewton(; linsolve, linesearch)
31+
for linsolve in [nothing, LUFactorization()],
32+
linesearch in [Static(), BackTracking(), HagerZhang(), StrongWolfe(), MoreThuente()]])
33+
append!(solvers,
34+
[
35+
LevenbergMarquardt(),
36+
LevenbergMarquardt(; linsolve = LUFactorization()),
37+
LeastSquaresOptimJL(:lm),
38+
LeastSquaresOptimJL(:dogleg),
39+
])
3840

3941
for prob in nlls_problems, solver in solvers
4042
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)

0 commit comments

Comments
 (0)