Skip to content

Commit f151a0a

Browse files
committed
Use symmetric linear solve if possible
1 parent 8f68ef1 commit f151a0a

File tree

4 files changed

+32
-21
lines changed

4 files changed

+32
-21
lines changed

src/gaussnewton.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
2-
GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
3-
adkwargs...)
2+
GaussNewton(; concrete_jac = nothing, linsolve = nothing,
3+
precs = DEFAULT_PRECS, adkwargs...)
44
55
An advanced GaussNewton implementation with support for efficient handling of sparse
66
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
@@ -41,8 +41,8 @@ for large-scale and numerically-difficult nonlinear least squares problems.
4141
precs
4242
end
4343

44-
function GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
45-
adkwargs...)
44+
function GaussNewton(; concrete_jac = nothing, linsolve = CholeskyFactorization(),
45+
precs = DEFAULT_PRECS, adkwargs...)
4646
ad = default_adargs_to_adtype(; adkwargs...)
4747
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
4848
end
@@ -97,8 +97,8 @@ function perform_step!(cache::GaussNewtonCache{true})
9797
__matmul!(Jᵀf, J', fu1)
9898

9999
# u = u - J \ fu
100-
linres = dolinsolve(alg.precs, linsolve; A = JᵀJ, b = _vec(Jᵀf), linu = _vec(du),
101-
p, reltol = cache.abstol)
100+
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),
101+
linu = _vec(du), p, reltol = cache.abstol)
102102
cache.linsolve = linres.cache
103103
@. u = u - du
104104
f(cache.fu_new, u, p)
@@ -125,8 +125,8 @@ function perform_step!(cache::GaussNewtonCache{false})
125125
if linsolve === nothing
126126
cache.du = fu1 / cache.J
127127
else
128-
linres = dolinsolve(alg.precs, linsolve; A = cache.JᵀJ, b = _vec(cache.Jᵀf),
129-
linu = _vec(cache.du), p, reltol = cache.abstol)
128+
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.JᵀJ),
129+
b = _vec(cache.Jᵀf), linu = _vec(cache.du), p, reltol = cache.abstol)
130130
cache.linsolve = linres.cache
131131
end
132132
cache.u = @. u - cache.du # `u` might not support mutation

src/jacobian.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,14 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
9595
Jᵀfu = J' * fu
9696
end
9797

98-
linprob = LinearProblem(needsJᵀJ ? JᵀJ : J, needsJᵀJ ? _vec(Jᵀfu) : _vec(fu);
99-
u0 = _vec(du))
98+
linprob = LinearProblem(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J,
99+
needsJᵀJ ? _vec(Jᵀfu) : _vec(fu); u0 = _vec(du))
100100

101101
weight = similar(u)
102102
recursivefill!(weight, true)
103103

104-
Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
105-
nothing)..., weight)
104+
Pl, Pr = wrapprecs(alg.precs(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J, nothing, u, p,
105+
nothing, nothing, nothing, nothing, nothing)..., weight)
106106
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
107107
linsolve_kwargs...)
108108

@@ -119,6 +119,12 @@ __init_JᵀJ(J::Number) = zero(J)
119119
__init_JᵀJ(J::AbstractArray) = J' * J
120120
__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
121121

122+
__maybe_symmetric(x) = Symmetric(x)
123+
__maybe_symmetric(x::Number) = x
124+
# LinearSolve with `nothing` doesn't dispatch correctly here
125+
__maybe_symmetric(x::StaticArray) = x
126+
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x
127+
122128
## Special Handling for Scalars
123129
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
124130
::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),

src/levenberg.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ numerically-difficult nonlinear systems.
1313
### Keyword Arguments
1414
1515
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
16-
ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
17-
`AutoForwardDiff()`. Valid choices are types from ADTypes.jl.
16+
ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
17+
`AutoForwardDiff()`. Valid choices are types from ADTypes.jl.
1818
- `concrete_jac`: whether to build a concrete Jacobian. If a Krylov-subspace method is used,
1919
then the Jacobian will not be constructed and instead direct Jacobian-vector products
2020
`J*v` are computed using forward-mode automatic differentiation or finite differencing
@@ -203,8 +203,8 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
203203
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
204204
mul!(cache.u_tmp, J', fu1)
205205
@. cache.mat_tmp = JᵀJ + λ * DᵀD
206-
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.u_tmp),
207-
linu = _vec(cache.du), p = p, reltol = cache.abstol)
206+
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
207+
b = _vec(cache.u_tmp), linu = _vec(cache.du), p = p, reltol = cache.abstol)
208208
cache.linsolve = linres.cache
209209
@. cache.v = -cache.du
210210

@@ -280,8 +280,8 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
280280
if linsolve === nothing
281281
cache.v = -cache.mat_tmp \ (J' * fu1)
282282
else
283-
linres = dolinsolve(alg.precs, linsolve; A = -cache.mat_tmp, b = _vec(J' * fu1),
284-
linu = _vec(cache.v), p, reltol = cache.abstol)
283+
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),
284+
b = _vec(J' * fu1), linu = _vec(cache.v), p, reltol = cache.abstol)
285285
cache.linsolve = linres.cache
286286
end
287287

@@ -291,7 +291,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
291291
cache.a = -cache.mat_tmp \
292292
_vec(J' * ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))
293293
else
294-
linres = dolinsolve(alg.precs, linsolve; A = -cache.mat_tmp,
294+
linres = dolinsolve(alg.precs, linsolve;
295295
b = _mutable(_vec(J' *
296296
((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))),
297297
linu = _vec(cache.a), p, reltol = cache.abstol)

test/nonlinear_least_squares.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,15 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
2727
resid_prototype = zero(y_target)), θ_init, x)
2828

2929
nlls_problems = [prob_oop, prob_iip]
30-
solvers = [GaussNewton(), LevenbergMarquardt(), LSOptimSolver(:lm), LSOptimSolver(:dogleg)]
30+
solvers = [
31+
GaussNewton(),
32+
LevenbergMarquardt(),
33+
LSOptimSolver(:lm),
34+
LSOptimSolver(:dogleg),
35+
]
3136

3237
for prob in nlls_problems, solver in solvers
33-
@time sol = solve(prob, solver; maxiters = 1000, abstol = 1e-8)
38+
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
3439
@test SciMLBase.successful_retcode(sol)
3540
@test norm(sol.resid) < 1e-6
3641
end

0 commit comments

Comments
 (0)