
Commit 081a18a

Add gauss newton

1 parent b8aca89 commit 081a18a

7 files changed (+209, -8 lines)

docs/pages.jl

Lines changed: 2 additions & 1 deletion

@@ -12,7 +12,8 @@ pages = ["index.md",
         "basics/FAQ.md"],
     "Solver Summaries and Recommendations" => Any["solvers/NonlinearSystemSolvers.md",
         "solvers/BracketingSolvers.md",
-        "solvers/SteadyStateSolvers.md"],
+        "solvers/SteadyStateSolvers.md",
+        "solvers/NonlinearLeastSquaresSolvers.md"],
     "Detailed Solver APIs" => Any["api/nonlinearsolve.md",
         "api/simplenonlinearsolve.md",
         "api/minpack.md",

docs/src/api/nonlinearsolve.md

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,8 @@ These are the native solvers of NonlinearSolve.jl.
 ```@docs
 NewtonRaphson
 TrustRegion
+LevenbergMarquardt
+GaussNewton
 ```

 ## Radius Update Schemes for Trust Region (RadiusUpdateSchemes)
docs/src/solvers/NonlinearLeastSquaresSolvers.md (new file)

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
# Nonlinear Least Squares Solvers

`solve(prob::NonlinearLeastSquaresProblem, alg; kwargs...)`

Solves the nonlinear least squares problem defined by `prob` using the algorithm
`alg`. If no algorithm is given, a default algorithm will be chosen.

## Recommended Methods

`LevenbergMarquardt` is a good choice for most problems.

## Full List of Methods

  - `LevenbergMarquardt()`: An advanced Levenberg-Marquardt implementation with the
    improvements suggested in the [paper](https://arxiv.org/abs/1201.5885) "Improvements to
    the Levenberg-Marquardt algorithm for nonlinear least-squares minimization". Designed
    for large-scale and numerically-difficult nonlinear systems.
  - `GaussNewton()`: An advanced Gauss-Newton implementation with support for efficient
    handling of sparse matrices via colored automatic differentiation and preconditioned
    linear solvers. Designed for large-scale and numerically-difficult nonlinear least
    squares problems.

## Example usage

```julia
using NonlinearSolve
sol = solve(prob, LevenbergMarquardt())
```
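The page's usage snippet assumes a `prob` that has already been constructed. For orientation, here is a self-contained sketch of building and solving a small `NonlinearLeastSquaresProblem`; the residual function, data, and names below are illustrative assumptions, not part of this commit:

```julia
using NonlinearSolve

# Hypothetical curve-fitting setup: find θ such that θ[1] * exp(θ[2] * x) ≈ y_data.
xs = collect(range(0.0, 1.0; length = 20))
y_data = 2.0 .* exp.(0.5 .* xs)

# The residual has more entries (20) than unknowns (2), so this is a least-squares problem.
residual(θ, p) = θ[1] .* exp.(θ[2] .* p) .- y_data

θ0 = [1.0, 1.0]
prob = NonlinearLeastSquaresProblem(residual, θ0, xs)

sol = solve(prob, LevenbergMarquardt())  # or GaussNewton()
```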

docs/src/solvers/NonlinearSystemSolvers.md

Lines changed: 4 additions & 0 deletions

@@ -42,6 +42,10 @@ features, but have a bit of overhead on very small problems.
   methods for high performance on large and sparse systems.
 - `TrustRegion()`: A Newton Trust Region dogleg method with swappable nonlinear solvers and
   autodiff methods for high performance on large and sparse systems.
+- `LevenbergMarquardt()`: An advanced Levenberg-Marquardt implementation with the
+  improvements suggested in the [paper](https://arxiv.org/abs/1201.5885) "Improvements to
+  the Levenberg-Marquardt algorithm for nonlinear least-squares minimization". Designed for
+  large-scale and numerically-difficult nonlinear systems.

 ### SimpleNonlinearSolve.jl
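As a quick illustration of the `LevenbergMarquardt()` entry added above, the following sketch solves an ordinary `NonlinearProblem` with it; the system and values are assumptions chosen for demonstration, not part of the commit:

```julia
using NonlinearSolve

# Simple square system u .^ 2 .- p == 0.
f(u, p) = u .^ 2 .- p
u0 = [1.0, 1.0]
p = [2.0, 3.0]

prob = NonlinearProblem(f, u0, p)
sol = solve(prob, LevenbergMarquardt())  # converges toward [sqrt(2), sqrt(3)]
```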

src/NonlinearSolve.jl

Lines changed: 2 additions & 1 deletion

@@ -64,6 +64,7 @@ include("linesearch.jl")
 include("raphson.jl")
 include("trustRegion.jl")
 include("levenberg.jl")
+include("gaussnewton.jl")
 include("jacobian.jl")
 include("ad.jl")

@@ -91,7 +92,7 @@ end

 export RadiusUpdateSchemes

-export NewtonRaphson, TrustRegion, LevenbergMarquardt
+export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton

 export LineSearch

src/gaussnewton.jl (new file)

Lines changed: 165 additions & 0 deletions

@@ -0,0 +1,165 @@
"""
    GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(),
        precs = DEFAULT_PRECS, adkwargs...)

An advanced Gauss-Newton implementation with support for efficient handling of sparse
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
for large-scale and numerically-difficult nonlinear least squares problems.

!!! note
    In most practical situations, users should prefer `LevenbergMarquardt` instead! It is
    a more general extension of the Gauss-Newton method.

### Keyword Arguments

  - `autodiff`: determines the backend used for the Jacobian. Note that this argument is
    ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
    `AutoForwardDiff()`. Valid choices are types from ADTypes.jl.
  - `concrete_jac`: whether to build a concrete Jacobian. If a Krylov-subspace method is
    used, then the Jacobian will not be constructed and instead direct Jacobian-vector
    products `J*v` are computed using forward-mode automatic differentiation or finite
    differencing tricks (without ever constructing the Jacobian). However, if the Jacobian
    is still needed, for example for a preconditioner, `concrete_jac = true` can be passed
    in order to force the construction of the Jacobian.
  - `linsolve`: the [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl) algorithm
    used for the linear solves within the Gauss-Newton step. Defaults to
    `NormalCholeskyFactorization()`; the step is computed from the normal equations
    `JᵀJ δu = Jᵀfu`. For more information on available algorithm choices, see the
    [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
  - `precs`: the choice of preconditioners for the linear solver. Defaults to using no
    preconditioners. For more information on specifying preconditioners for LinearSolve
    algorithms, consult the
    [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).

!!! warning

    The Jacobian-free version of `GaussNewton` doesn't work yet; it currently forces
    Jacobian construction. This will be fixed in the near future.
"""
@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
    ad::AD
    linsolve
    precs
end

function GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(),
    precs = DEFAULT_PRECS, adkwargs...)
    ad = default_adargs_to_adtype(; adkwargs...)
    return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
end

@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
    f
    alg
    u
    fu1
    fu2
    fu_new
    du
    p
    uf
    linsolve
    J
    JᵀJ
    Jᵀf
    jac_cache
    force_stop
    maxiters::Int
    internalnorm
    retcode::ReturnCode.T
    abstol
    prob
    stats::NLStats
end

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg::GaussNewton,
    args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
    kwargs...) where {uType, iip}
    @unpack f, u0, p = prob
    u = alias_u0 ? u0 : deepcopy(u0)
    if iip
        fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
        f(fu1, u, p)
    else
        fu1 = f(u, p)
    end
    uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip))

    JᵀJ = J isa Number ? zero(J) : similar(J, size(J, 2), size(J, 2))
    Jᵀf = zero(u)

    return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
        JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
        prob, NLStats(1, 0, 0, 0, 0))
end

function perform_step!(cache::GaussNewtonCache{true})
    @unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
    jacobian!!(J, cache)
    mul!(JᵀJ, J', J)
    mul!(Jᵀf, J', fu1)

    # u = u - J \ fu
    linres = dolinsolve(alg.precs, linsolve; A = JᵀJ, b = _vec(Jᵀf), linu = _vec(du),
        p, reltol = cache.abstol)
    cache.linsolve = linres.cache
    @. u = u - du
    f(cache.fu_new, u, p)

    (cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
     cache.internalnorm(cache.fu_new) < cache.abstol) &&
        (cache.force_stop = true)
    cache.fu1 .= cache.fu_new
    cache.stats.nf += 1
    cache.stats.njacs += 1
    cache.stats.nsolve += 1
    cache.stats.nfactors += 1
    return nothing
end

function perform_step!(cache::GaussNewtonCache{false})
    @unpack u, fu1, f, p, alg, linsolve = cache

    cache.J = jacobian!!(cache.J, cache)
    cache.JᵀJ = cache.J' * cache.J
    cache.Jᵀf = cache.J' * fu1
    # u = u - J \ fu
    if linsolve === nothing
        cache.du = fu1 / cache.J
    else
        linres = dolinsolve(alg.precs, linsolve; A = cache.JᵀJ, b = _vec(cache.Jᵀf),
            linu = _vec(cache.du), p, reltol = cache.abstol)
        cache.linsolve = linres.cache
    end
    cache.u = @. u - cache.du  # `u` might not support mutation
    cache.fu_new = f(cache.u, p)

    (cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
     cache.internalnorm(cache.fu_new) < cache.abstol) &&
        (cache.force_stop = true)
    cache.fu1 = cache.fu_new
    cache.stats.nf += 1
    cache.stats.njacs += 1
    cache.stats.nsolve += 1
    cache.stats.nfactors += 1
    return nothing
end

function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache.p,
    abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
    cache.p = p
    if iip
        recursivecopy!(cache.u, u0)
        cache.f(cache.fu1, cache.u, p)
    else
        # don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
        cache.u = u0
        cache.fu1 = cache.f(cache.u, p)
    end
    cache.abstol = abstol
    cache.maxiters = maxiters
    cache.stats.nf = 1
    cache.stats.nsteps = 1
    cache.force_stop = false
    cache.retcode = ReturnCode.Default
    return cache
end
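To connect the docstring to usage: each step is computed from the normal equations `JᵀJ δu = Jᵀfu`, and the keyword arguments select the AD backend, the linear solver, and the preconditioners. A minimal sketch follows, assuming the keyword names behave as documented above; the problem and values are illustrative, and `AutoForwardDiff` is taken from ADTypes.jl:

```julia
using NonlinearSolve, ADTypes

# Tiny illustrative least-squares problem: fit a line θ[1] + θ[2] * p to `data`.
data = [1.0, 2.0, 3.0, 4.0]
residual(θ, p) = θ[1] .+ θ[2] .* p .- data
prob = NonlinearLeastSquaresProblem(residual, [0.0, 0.0], [0.0, 1.0, 2.0, 3.0])

# Explicitly pick the AD backend; `linsolve` and `precs` can likewise be set to any
# LinearSolve.jl algorithm / preconditioner pair if the default is not suitable.
alg = GaussNewton(; autodiff = AutoForwardDiff())

sol = solve(prob, alg; maxiters = 1000, abstol = 1e-8)
```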

test/nonlinear_least_squares.jl

Lines changed: 6 additions & 6 deletions

@@ -25,13 +25,13 @@ prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)
 prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
         resid_prototype = zero(y_target)), θ_init, x)

-# sol = solve(prob_oop, GaussNewton(); maxiters = 1000, abstol = 1e-8)
-# @test SciMLBase.successful_retcode(sol)
-# @test norm(sol.resid) < 1e-6
+sol = solve(prob_oop, GaussNewton(); maxiters = 1000, abstol = 1e-8)
+@test SciMLBase.successful_retcode(sol)
+@test norm(sol.resid) < 1e-6

-# sol = solve(prob_iip, GaussNewton(); maxiters = 1000, abstol = 1e-8)
-# @test SciMLBase.successful_retcode(sol)
-# @test norm(sol.resid) < 1e-6
+sol = solve(prob_iip, GaussNewton(); maxiters = 1000, abstol = 1e-8)
+@test SciMLBase.successful_retcode(sol)
+@test norm(sol.resid) < 1e-6

 sol = solve(prob_oop, LevenbergMarquardt(); maxiters = 1000, abstol = 1e-8)
 @test SciMLBase.successful_retcode(sol)
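The tests above reference `loss_function`, `θ_init`, `x`, and `y_target`, which are defined earlier in the test file and are not part of this diff. Purely as a hypothetical illustration of what such a setup could look like (names kept, definitions and values invented):

```julia
# Hypothetical preamble for the tests above; the repository's actual definitions differ.
using NonlinearSolve, LinearAlgebra

model(θ, x) = θ[1] .* exp.(θ[2] .* x)

x = collect(range(0.0, 2.0; length = 50))
θ_true = [1.5, -0.7]
y_target = model(θ_true, x)

# Out-of-place and in-place residual methods, as used by prob_oop and prob_iip.
loss_function(θ, p) = model(θ, p) .- y_target
loss_function(resid, θ, p) = (resid .= model(θ, p) .- y_target)

θ_init = [1.0, 0.0]
```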
