Skip to content

Commit 786e8d3

Browse files
Merge pull request #258 from avik-pal/ap/fix_gauss_newton
Gauss Newton & LM Robustness Fixes
2 parents b300050 + 28ea63a commit 786e8d3

19 files changed

+276
-124
lines changed

.github/workflows/CI.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@ jobs:
1414
test:
1515
runs-on: ubuntu-latest
1616
strategy:
17+
fail-fast: false
1718
matrix:
1819
group:
19-
- Core
20-
- 23TestProblems
20+
- All
2121
version:
2222
- '1'
23+
- '~1.10.0-0'
2324
steps:
2425
- uses: actions/checkout@v4
2526
- uses: julia-actions/setup-julia@v1

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "2.4.0"
4+
version = "2.5.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -44,7 +44,7 @@ FiniteDiff = "2"
4444
ForwardDiff = "0.10.3"
4545
LeastSquaresOptim = "0.8"
4646
LineSearches = "7"
47-
LinearSolve = "2"
47+
LinearSolve = "2.12"
4848
NonlinearProblemLibrary = "0.1"
4949
PrecompileTools = "1"
5050
RecursiveArrayTools = "2"
@@ -65,6 +65,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
6565
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
6666
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
6767
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
68+
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
6869
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
6970
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
7071
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -76,4 +77,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7677
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7778

7879
[targets]
79-
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt"]
80+
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath"]

docs/src/solvers/NonlinearLeastSquaresSolvers.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@ Solves the nonlinear least squares problem defined by `prob` using the algorithm
1919
handling of sparse matrices via colored automatic differentiation and preconditioned
2020
linear solvers. Designed for large-scale and numerically-difficult nonlinear least squares
2121
problems.
22-
- `SimpleNewtonRaphson()`: Newton Raphson implementation that uses Linear Least Squares
23-
solution at every step to compute the descent direction. **WARNING**: This method is not
24-
a robust solver for nonlinear least squares problems. The computed delta step might not
25-
be the correct descent direction!
22+
- `SimpleNewtonRaphson()`: Simple Gauss-Newton implementation that uses `QRFactorization` to
23+
solve a linear least squares problem at each step.
2624

2725
## Example usage
2826

docs/src/solvers/NonlinearSystemSolvers.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ features, but have a bit of overhead on very small problems.
7171
- `GeneralKlement()`: Generalization of Klement's Quasi-Newton Method with Line Search and
7272
Automatic Jacobian Resetting. This is a fast method but unstable when the condition number of
7373
the Jacobian matrix is sufficiently large.
74+
- `LimitedMemoryBroyden()`: An advanced version of `LBroyden` which uses a limited memory
75+
Broyden method. This is a fast method but unstable when the condition number of
76+
the Jacobian matrix is sufficiently large. It is recommended to use `GeneralBroyden` or
77+
`GeneralKlement` instead, unless memory usage is a concern.
7478

7579
### SimpleNonlinearSolve.jl
7680

src/broyden.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
6666
alg.reset_tolerance
6767
reset_check = x -> abs(x) ≤ reset_tolerance
6868
return GeneralBroydenCache{iip}(f, alg, u, _mutable_zero(u), fu, zero(fu),
69-
zero(fu), p, J⁻¹, zero(_vec(fu)'), _mutable_zero(u), false, 0, alg.max_resets,
70-
maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance,
69+
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
70+
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance,
7171
reset_check, prob, NLStats(1, 0, 0, 0, 0),
7272
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
7373
end
@@ -78,7 +78,7 @@ function perform_step!(cache::GeneralBroydenCache{true})
7878

7979
mul!(_vec(du), J⁻¹, -_vec(fu))
8080
α = perform_linesearch!(cache.lscache, u, du)
81-
axpy!(α, du, u)
81+
_axpy!(α, du, u)
8282
f(fu2, u, p)
8383

8484
cache.internalnorm(fu2) < cache.abstol && (cache.force_stop = true)
@@ -101,7 +101,8 @@ function perform_step!(cache::GeneralBroydenCache{true})
101101
else
102102
mul!(_vec(J⁻¹df), J⁻¹, _vec(dfu))
103103
mul!(J⁻¹₂, _vec(du)', J⁻¹)
104-
du .= (du .- J⁻¹df) ./ (dot(du, J⁻¹df) .+ T(1e-5))
104+
denom = dot(du, J⁻¹df)
105+
du .= (du .- J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom)
105106
mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1)
106107
end
107108
fu .= fu2
@@ -136,7 +137,8 @@ function perform_step!(cache::GeneralBroydenCache{false})
136137
else
137138
cache.J⁻¹df = _restructure(cache.J⁻¹df, cache.J⁻¹ * _vec(cache.dfu))
138139
cache.J⁻¹₂ = _vec(cache.du)' * cache.J⁻¹
139-
cache.du = (cache.du .- cache.J⁻¹df) ./ (dot(cache.du, cache.J⁻¹df) .+ T(1e-5))
140+
denom = dot(cache.du, cache.J⁻¹df)
141+
cache.du = (cache.du .- cache.J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom)
140142
cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂
141143
end
142144
cache.fu = cache.fu2

src/default.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ end
159159
]
160160
else
161161
[
162-
:(GeneralBroyden()),
163162
:(GeneralKlement()),
163+
:(GeneralBroyden()),
164164
:(NewtonRaphson(; linsolve, precs, adkwargs...)),
165165
:(NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...)),
166166
:(TrustRegion(; linsolve, precs, adkwargs...)),

src/gaussnewton.jl

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
4646
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs)
4747
end
4848

49-
function GaussNewton(; concrete_jac = nothing, linsolve = CholeskyFactorization(),
49+
function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
5050
precs = DEFAULT_PRECS, adkwargs...)
5151
ad = default_adargs_to_adtype(; adkwargs...)
5252
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
@@ -81,15 +81,25 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
8181
kwargs...) where {uType, iip}
8282
alg = get_concrete_algorithm(alg_, prob)
8383
@unpack f, u0, p = prob
84+
85+
linsolve_with_JᵀJ = Val(_needs_square_A(alg, u0))
86+
8487
u = alias_u0 ? u0 : deepcopy(u0)
8588
if iip
8689
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
8790
f(fu1, u, p)
8891
else
8992
fu1 = f(u, p)
9093
end
91-
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip);
92-
linsolve_with_JᵀJ = Val(true))
94+
95+
if SciMLBase._unwrap_val(linsolve_with_JᵀJ)
96+
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p,
97+
Val(iip); linsolve_with_JᵀJ)
98+
else
99+
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p,
100+
Val(iip); linsolve_with_JᵀJ)
101+
JᵀJ, Jᵀf = nothing, nothing
102+
end
93103

94104
return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
95105
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
@@ -99,12 +109,20 @@ end
99109
function perform_step!(cache::GaussNewtonCache{true})
100110
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
101111
jacobian!!(J, cache)
102-
__matmul!(JᵀJ, J', J)
103-
__matmul!(Jᵀf, J', fu1)
104112

105-
# u = u - J \ fu
106-
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),
107-
linu = _vec(du), p, reltol = cache.abstol)
113+
if JᵀJ !== nothing
114+
__matmul!(JᵀJ, J', J)
115+
__matmul!(Jᵀf, J', fu1)
116+
end
117+
118+
# u = u - JᵀJ \ Jᵀfu
119+
if cache.JᵀJ === nothing
120+
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du),
121+
p, reltol = cache.abstol)
122+
else
123+
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),
124+
linu = _vec(du), p, reltol = cache.abstol)
125+
end
108126
cache.linsolve = linres.cache
109127
@. u = u - du
110128
f(cache.fu_new, u, p)
@@ -125,14 +143,22 @@ function perform_step!(cache::GaussNewtonCache{false})
125143

126144
cache.J = jacobian!!(cache.J, cache)
127145

128-
cache.JᵀJ = cache.J' * cache.J
129-
cache.Jᵀf = cache.J' * fu1
146+
if cache.JᵀJ !== nothing
147+
cache.JᵀJ = cache.J' * cache.J
148+
cache.Jᵀf = cache.J' * fu1
149+
end
150+
130151
# u = u - J \ fu
131152
if linsolve === nothing
132153
cache.du = fu1 / cache.J
133154
else
134-
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.JᵀJ),
135-
b = _vec(cache.Jᵀf), linu = _vec(cache.du), p, reltol = cache.abstol)
155+
if cache.JᵀJ === nothing
156+
linres = dolinsolve(alg.precs, linsolve; A = cache.J, b = _vec(fu1),
157+
linu = _vec(cache.du), p, reltol = cache.abstol)
158+
else
159+
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.JᵀJ),
160+
b = _vec(cache.Jᵀf), linu = _vec(cache.du), p, reltol = cache.abstol)
161+
end
136162
cache.linsolve = linres.cache
137163
end
138164
cache.u = @. u - cache.du # `u` might not support mutation

src/jacobian.jl

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u))
5050

5151
# Build Jacobian Caches
5252
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{iip};
53-
linsolve_kwargs = (;),
54-
linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false)) where {iip, needsJᵀJ}
53+
linsolve_kwargs = (;), lininit::Val{linsolve_init} = Val(true),
54+
linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false)) where {iip, needsJᵀJ, linsolve_init}
5555
uf = JacobianWrapper{iip}(f, p)
5656

5757
haslinsolve = hasfield(typeof(alg), :linsolve)
@@ -95,25 +95,28 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
9595
Jᵀfu = J' * _vec(fu)
9696
end
9797

98-
linprob = LinearProblem(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J,
99-
needsJᵀJ ? _vec(Jᵀfu) : _vec(fu); u0 = _vec(du))
100-
101-
if alg isa PseudoTransient
102-
alpha = convert(eltype(u), alg.alpha_initial)
103-
J_new = J - (1 / alpha) * I
104-
linprob = LinearProblem(J_new, _vec(fu); u0 = _vec(du))
98+
if linsolve_init
99+
linprob_A = alg isa PseudoTransient ?
100+
(J - (1 / (convert(eltype(u), alg.alpha_initial))) * I) :
101+
(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J)
102+
linsolve = __setup_linsolve(linprob_A, needsJᵀJ ? Jᵀfu : fu, du, p, alg)
103+
else
104+
linsolve = nothing
105105
end
106106

107+
needsJᵀJ && return uf, linsolve, J, fu, jac_cache, du, JᵀJ, Jᵀfu
108+
return uf, linsolve, J, fu, jac_cache, du
109+
end
110+
111+
function __setup_linsolve(A, b, u, p, alg)
112+
linprob = LinearProblem(A, _vec(b); u0 = _vec(u))
113+
107114
weight = similar(u)
108115
recursivefill!(weight, true)
109116

110-
Pl, Pr = wrapprecs(alg.precs(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J, nothing, u, p,
111-
nothing, nothing, nothing, nothing, nothing)..., weight)
112-
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
113-
linsolve_kwargs...)
114-
115-
needsJᵀJ && return uf, linsolve, J, fu, jac_cache, du, JᵀJ, Jᵀfu
116-
return uf, linsolve, J, fu, jac_cache, du
117+
Pl, Pr = wrapprecs(alg.precs(A, nothing, u, p, nothing, nothing, nothing, nothing,
118+
nothing)..., weight)
119+
return init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)
117120
end
118121

119122
__get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff()

src/klement.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ solves.
2727
linesearch
2828
end
2929

30+
function set_linsolve(alg::GeneralKlement, linsolve)
31+
return GeneralKlement(alg.max_resets, linsolve, alg.precs, alg.linesearch)
32+
end
33+
3034
function GeneralKlement(; max_resets::Int = 5, linsolve = nothing,
3135
linesearch = LineSearch(), precs = DEFAULT_PRECS)
3236
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
@@ -60,30 +64,27 @@ end
6064

6165
get_fu(cache::GeneralKlementCache) = cache.fu
6266

63-
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralKlement, args...;
67+
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKlement, args...;
6468
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
6569
linsolve_kwargs = (;), kwargs...) where {uType, iip}
6670
@unpack f, u0, p = prob
6771
u = alias_u0 ? u0 : deepcopy(u0)
6872
fu = evaluate_f(prob, u)
6973
J = __init_identity_jacobian(u, fu)
74+
du = _mutable_zero(u)
7075

7176
if u isa Number
7277
linsolve = nothing
78+
alg = alg_
7379
else
7480
# For General Julia Arrays default to LU Factorization
75-
linsolve_alg = alg.linsolve === nothing && u isa Array ? LUFactorization() :
81+
linsolve_alg = alg_.linsolve === nothing && u isa Array ? LUFactorization() :
7682
nothing
77-
weight = similar(u)
78-
recursivefill!(weight, true)
79-
Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
80-
nothing)..., weight)
81-
linprob = LinearProblem(J, _vec(fu); u0 = _vec(fu))
82-
linsolve = init(linprob, linsolve_alg; alias_A = true, alias_b = true, Pl, Pr,
83-
linsolve_kwargs...)
83+
alg = set_linsolve(alg_, linsolve_alg)
84+
linsolve = __setup_linsolve(J, _vec(fu), _vec(du), p, alg)
8485
end
8586

86-
return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), _mutable_zero(u), p, linsolve,
87+
return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), du, p, linsolve,
8788
J, zero(J), zero(J), _vec(zero(fu)), _vec(zero(fu)), 0, false,
8889
maxiters, internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0),
8990
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
@@ -114,7 +115,7 @@ function perform_step!(cache::GeneralKlementCache{true})
114115

115116
# Line Search
116117
α = perform_linesearch!(cache.lscache, u, du)
117-
axpy!(α, du, u)
118+
_axpy!(α, du, u)
118119
f(cache.fu2, u, p)
119120

120121
cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)

src/lbroyden.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{true})
9393
T = eltype(u)
9494

9595
α = perform_linesearch!(cache.lscache, u, du)
96-
axpy!(α, du, u)
96+
_axpy!(α, du, u)
9797
f(cache.fu2, u, p)
9898

9999
cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
@@ -123,8 +123,8 @@ function perform_step!(cache::LimitedMemoryBroydenCache{true})
123123
__lbroyden_matvec!(_vec(cache.vᵀ_cache), cache.Ux, U_part, Vᵀ_part, _vec(cache.du))
124124
__lbroyden_rmatvec!(_vec(cache.u_cache), cache.xᵀVᵀ, U_part, Vᵀ_part,
125125
_vec(cache.dfu))
126-
cache.u_cache .= (du .- cache.u_cache) ./
127-
(dot(cache.vᵀ_cache, cache.dfu) .+ T(1e-5))
126+
denom = dot(cache.vᵀ_cache, cache.dfu)
127+
cache.u_cache .= (du .- cache.u_cache) ./ ifelse(iszero(denom), T(1e-5), denom)
128128

129129
idx = mod1(cache.iterations_since_reset + 1, size(cache.U, 1))
130130
selectdim(cache.U, 1, idx) .= _vec(cache.u_cache)
@@ -179,8 +179,8 @@ function perform_step!(cache::LimitedMemoryBroydenCache{false})
179179
__lbroyden_matvec(U_part, Vᵀ_part, _vec(cache.du)))
180180
cache.u_cache = _restructure(cache.u_cache,
181181
__lbroyden_rmatvec(U_part, Vᵀ_part, _vec(cache.dfu)))
182-
cache.u_cache = (cache.du .- cache.u_cache) ./
183-
(dot(cache.vᵀ_cache, cache.dfu) .+ T(1e-5))
182+
denom = dot(cache.vᵀ_cache, cache.dfu)
183+
cache.u_cache = (cache.du .- cache.u_cache) ./ ifelse(iszero(denom), T(1e-5), denom)
184184

185185
idx = mod1(cache.iterations_since_reset + 1, size(cache.U, 1))
186186
selectdim(cache.U, 1, idx) .= _vec(cache.u_cache)
@@ -249,12 +249,12 @@ end
249249
return nothing
250250
end
251251
mul!(xᵀVᵀ[:, 1:η], x', Vᵀ)
252-
mul!(y', xᵀVᵀ[:, 1:η], U)
252+
mul!(reshape(y, 1, :), xᵀVᵀ[:, 1:η], U)
253253
return nothing
254254
end
255255

256256
@views function __lbroyden_rmatvec(U::AbstractMatrix, Vᵀ::AbstractMatrix, x::AbstractVector)
257257
# Computes xᵀ × Vᵀ × U
258258
size(U, 1) == 0 && return x
259-
return (x' * Vᵀ) * U
259+
return (reshape(x, 1, :) * Vᵀ) * U
260260
end

0 commit comments

Comments
 (0)