Skip to content

Commit af3e026

Browse files
committed
Make LM and GN oop versions work with linearSolve.jl
1 parent 57238ac commit af3e026

File tree

7 files changed

+126
-115
lines changed

7 files changed

+126
-115
lines changed

src/gaussnewton.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,8 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg::G
8282
else
8383
fu1 = f(u, p)
8484
end
85-
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip))
86-
87-
JᵀJ = J isa Number ? zero(J) : similar(J, size(J, 2), size(J, 2))
88-
Jᵀf = zero(u)
85+
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip);
86+
linsolve_with_JᵀJ = Val(true))
8987

9088
return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
9189
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
@@ -120,6 +118,7 @@ function perform_step!(cache::GaussNewtonCache{false})
120118
@unpack u, fu1, f, p, alg, linsolve = cache
121119

122120
cache.J = jacobian!!(cache.J, cache)
121+
123122
cache.JᵀJ = cache.J' * cache.J
124123
cache.Jᵀf = cache.J' * fu1
125124
# u = u - J \ fu

src/jacobian.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u))
5050

5151
# Build Jacobian Caches
5252
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{iip};
53-
linsolve_kwargs = (;)) where {iip}
53+
linsolve_kwargs = (;),
54+
linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false)) where {iip, needsJᵀJ}
5455
uf = JacobianWrapper{iip}(f, p)
5556

5657
haslinsolve = hasfield(typeof(alg), :linsolve)
@@ -85,7 +86,15 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
8586
end
8687

8788
du = _mutable_zero(u)
88-
linprob = LinearProblem(J, _vec(fu); u0 = _vec(du))
89+
90+
if needsJᵀJ
91+
JᵀJ = __init_JᵀJ(J)
92+
# FIXME: This needs to be handled better for JacVec Operator
93+
Jᵀfu = J' * fu
94+
end
95+
96+
linprob = LinearProblem(needsJᵀJ ? JᵀJ : J, needsJᵀJ ? _vec(Jᵀfu) : _vec(fu);
97+
u0 = _vec(du))
8998

9099
weight = similar(u)
91100
recursivefill!(weight, true)
@@ -95,6 +104,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
95104
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
96105
linsolve_kwargs...)
97106

107+
needsJᵀJ && return uf, linsolve, J, fu, jac_cache, du, JᵀJ, Jᵀfu
98108
return uf, linsolve, J, fu, jac_cache, du
99109
end
100110

@@ -103,6 +113,10 @@ __get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff()
103113
__get_nonsparse_ad(::AutoSparseZygote) = AutoZygote()
104114
__get_nonsparse_ad(ad) = ad
105115

116+
__init_JᵀJ(J::Number) = zero(J)
117+
__init_JᵀJ(J::AbstractArray) = zeros(eltype(J), size(J, 2), size(J, 2))
118+
__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
119+
106120
## Special Handling for Scalars
107121
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
108122
::Val{false}; kwargs...)

src/levenberg.jl

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,6 @@ numerically-difficult nonlinear systems.
7272
where `J` is the Jacobian. It is suggested by
7373
[this paper](https://arxiv.org/abs/1201.5885) to use a minimum value of the elements in
7474
`DᵀD` to prevent the damping from being too small. Defaults to `1e-8`.
75-
76-
!!! warning
77-
78-
`linsolve` and `precs` are used exclusively for the inplace version of the algorithm.
79-
Support for the OOP version is planned!
8075
"""
8176
@concrete struct LevenbergMarquardt{CJ, AD, T} <: AbstractNewtonAlgorithm{CJ, AD}
8277
ad::AD
@@ -102,18 +97,17 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
10297
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
10398
end
10499

105-
@concrete mutable struct LevenbergMarquardtCache{iip, uType, jType, λType, lossType} <:
106-
AbstractNonlinearSolveCache{iip}
100+
@concrete mutable struct LevenbergMarquardtCache{iip} <: AbstractNonlinearSolveCache{iip}
107101
f
108102
alg
109-
u::uType
103+
u
110104
fu1
111105
fu2
112106
du
113107
p
114108
uf
115109
linsolve
116-
J::jType
110+
J
117111
jac_cache
118112
force_stop::Bool
119113
maxiters::Int
@@ -122,27 +116,27 @@ end
122116
abstol
123117
prob
124118
DᵀD
125-
JᵀJ::jType
126-
λ::λType
127-
λ_factor::λType
128-
damping_increase_factor::λType
129-
damping_decrease_factor::λType
130-
h::λType
131-
α_geodesic::λType
132-
b_uphill::λType
133-
min_damping_D::λType
134-
v::uType
135-
a::uType
136-
tmp_vec::uType
137-
v_old::uType
138-
norm_v_old::lossType
139-
δ::uType
140-
loss_old::lossType
119+
JᵀJ
120+
λ
121+
λ_factor
122+
damping_increase_factor
123+
damping_decrease_factor
124+
h
125+
α_geodesic
126+
b_uphill
127+
min_damping_D
128+
v
129+
a
130+
tmp_vec
131+
v_old
132+
norm_v_old
133+
δ
134+
loss_old
141135
make_new_J::Bool
142136
fu_tmp
143137
u_tmp
144138
Jv
145-
mat_tmp::jType
139+
mat_tmp
146140
stats::NLStats
147141
end
148142

@@ -153,8 +147,8 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
153147
@unpack f, u0, p = prob
154148
u = alias_u0 ? u0 : deepcopy(u0)
155149
fu1 = evaluate_f(prob, u)
156-
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
157-
linsolve_kwargs)
150+
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p, Val(iip);
151+
linsolve_kwargs, linsolve_with_JᵀJ=Val(true))
158152

159153
λ = convert(eltype(u), alg.damping_initial)
160154
λ_factor = convert(eltype(u), alg.damping_increase_factor)
@@ -174,12 +168,10 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
174168
end
175169

176170
loss = internalnorm(fu1)
177-
JᵀJ = J isa Number ? zero(J) : similar(J, size(J, 2), size(J, 2))
178-
v = zero(u)
179-
a = zero(u)
180-
tmp_vec = zero(u)
181-
v_old = zero(u)
182-
δ = zero(u)
171+
a = _mutable_zero(u)
172+
tmp_vec = _mutable_zero(u)
173+
v_old = _mutable_zero(u)
174+
δ = _mutable_zero(u)
183175
make_new_J = true
184176
fu_tmp = zero(fu1)
185177
mat_tmp = zero(JᵀJ)
@@ -223,7 +215,8 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
223215
# The following lines do: cache.a = -J \ cache.fu_tmp
224216
mul!(cache.Jv, J, v)
225217
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
226-
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.fu_tmp),
218+
mul!(cache.u_tmp, J', cache.fu_tmp)
219+
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.u_tmp),
227220
linu = _vec(cache.du), p = p, reltol = cache.abstol)
228221
cache.linsolve = linres.cache
229222
@. cache.a = -cache.du
@@ -279,15 +272,30 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
279272
cache.make_new_J = false
280273
cache.stats.njacs += 1
281274
end
282-
@unpack u, p, λ, JᵀJ, DᵀD, J = cache
275+
@unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache
283276

284277
cache.mat_tmp = JᵀJ + λ * DᵀD
285278
# Usual Levenberg-Marquardt step ("velocity").
286-
cache.v = -cache.mat_tmp \ (J' * fu1)
279+
if linsolve === nothing
280+
cache.v = -cache.mat_tmp \ (J' * fu1)
281+
else
282+
linres = dolinsolve(alg.precs, linsolve; A = -cache.mat_tmp, b = _vec(J' * fu1),
283+
linu = _vec(cache.v), p, reltol = cache.abstol)
284+
cache.linsolve = linres.cache
285+
end
287286

288287
@unpack v, h, α_geodesic = cache
289288
# Geodesic acceleration (step_size = v + a / 2).
290-
cache.a = -cache.mat_tmp \ ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v))
289+
if linsolve === nothing
290+
cache.a = -cache.mat_tmp \
291+
_vec(J' * ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))
292+
else
293+
linres = dolinsolve(alg.precs, linsolve; A = -cache.mat_tmp,
294+
b = _mutable(_vec(J' *
295+
((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))),
296+
linu = _vec(cache.a), p, reltol = cache.abstol)
297+
cache.linsolve = linres.cache
298+
end
291299
cache.stats.nsolve += 1
292300
cache.stats.nfactors += 1
293301

src/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ _mutable_zero(x::SArray) = MArray(x)
138138

139139
_mutable(x) = x
140140
_mutable(x::SArray) = MArray(x)
141+
141142
_maybe_mutable(x, ::AbstractFiniteDifferencesMode) = _mutable(x)
142143
# The shadow allocated for Enzyme needs to be mutable
143144
_maybe_mutable(x, ::AutoSparseEnzyme) = _mutable(x)

test/23_test_problems.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
using NonlinearSolve, LinearAlgebra, NonlinearProblemLibrary, Test
1+
using NonlinearSolve, LinearAlgebra, LinearSolve, NonlinearProblemLibrary, Test
22

33
problems = NonlinearProblemLibrary.problems
44
dicts = NonlinearProblemLibrary.dicts
55

6-
function test_on_library(problems, dicts, alg_ops, broken_tests, ϵ = 1e-5)
6+
function test_on_library(problems, dicts, alg_ops, broken_tests, ϵ = 1e-4)
77
for (idx, (problem, dict)) in enumerate(zip(problems, dicts))
88
x = dict["start"]
99
res = similar(x)
1010
nlprob = NonlinearProblem(problem, x)
1111
@testset "$(dict["title"])" begin
1212
for alg in alg_ops
13-
sol = solve(nlprob, alg, abstol = 1e-15, reltol = 1e-15)
13+
sol = solve(nlprob, alg, abstol = 1e-18, reltol = 1e-18)
1414
problem(res, sol.u, nothing)
1515
broken = idx in broken_tests[alg] ? true : false
1616
@test norm(res)ϵ broken=broken
@@ -43,19 +43,20 @@ end
4343
broken_tests[alg_ops[1]] = [6, 11, 21]
4444
broken_tests[alg_ops[2]] = [6, 11, 21]
4545
broken_tests[alg_ops[3]] = [1, 6, 11, 12, 15, 16, 21]
46-
broken_tests[alg_ops[4]] = [1, 6, 8, 11, 15, 16, 21, 22]
46+
broken_tests[alg_ops[4]] = [1, 6, 8, 11, 16, 21, 22]
4747
broken_tests[alg_ops[5]] = [6, 21]
4848
broken_tests[alg_ops[6]] = [6, 21]
4949

5050
test_on_library(problems, dicts, alg_ops, broken_tests)
5151
end
5252

5353
@testset "TrustRegion test problem library" begin
54-
alg_ops = (LevenbergMarquardt(), LevenbergMarquardt(; α_geodesic = 0.5))
54+
alg_ops = (LevenbergMarquardt(; linsolve=NormalCholeskyFactorization()),
55+
LevenbergMarquardt(; α_geodesic = 0.1, linsolve=NormalCholeskyFactorization()))
5556

5657
# dictionary with indices of test problems where method does not converge to small residual
5758
broken_tests = Dict(alg => Int[] for alg in alg_ops)
58-
broken_tests[alg_ops[1]] = [3, 6, 11, 17, 21]
59+
broken_tests[alg_ops[1]] = [3, 6, 11, 21]
5960
broken_tests[alg_ops[2]] = [3, 6, 11, 21]
6061

6162
test_on_library(problems, dicts, alg_ops, broken_tests)

0 commit comments

Comments
 (0)