Skip to content

Commit 960dad5

Browse files
committed
Default to QR for GaussNewton
1 parent b300050 commit 960dad5

File tree

2 files changed

+45
-13
lines changed

2 files changed

+45
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "2.4.0"
4+
version = "2.5.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/gaussnewton.jl

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
4646
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs)
4747
end
4848

49-
function GaussNewton(; concrete_jac = nothing, linsolve = CholeskyFactorization(),
49+
function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
5050
precs = DEFAULT_PRECS, adkwargs...)
5151
ad = default_adargs_to_adtype(; adkwargs...)
5252
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
@@ -81,15 +81,31 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
8181
kwargs...) where {uType, iip}
8282
alg = get_concrete_algorithm(alg_, prob)
8383
@unpack f, u0, p = prob
84+
85+
# Use QR if the user did not specify a linear solver
86+
if alg.linsolve === nothing || alg.linsolve isa QRFactorization ||
87+
alg.linsolve isa FastQRFactorization
88+
linsolve_with_JᵀJ = Val(false)
89+
else
90+
linsolve_with_JᵀJ = Val(true)
91+
end
92+
8493
u = alias_u0 ? u0 : deepcopy(u0)
8594
if iip
8695
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
8796
f(fu1, u, p)
8897
else
8998
fu1 = f(u, p)
9099
end
91-
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip);
92-
linsolve_with_JᵀJ = Val(true))
100+
101+
if SciMLBase._unwrap_val(linsolve_with_JᵀJ)
102+
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p,
103+
Val(iip); linsolve_with_JᵀJ)
104+
else
105+
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p,
106+
Val(iip); linsolve_with_JᵀJ)
107+
JᵀJ, Jᵀf = nothing, nothing
108+
end
93109

94110
return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
95111
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
@@ -99,12 +115,20 @@ end
99115
function perform_step!(cache::GaussNewtonCache{true})
100116
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
101117
jacobian!!(J, cache)
102-
__matmul!(JᵀJ, J', J)
103-
__matmul!(Jᵀf, J', fu1)
104118

105-
# u = u - J \ fu
106-
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),
107-
linu = _vec(du), p, reltol = cache.abstol)
119+
if JᵀJ !== nothing
120+
__matmul!(JᵀJ, J', J)
121+
__matmul!(Jᵀf, J', fu1)
122+
end
123+
124+
# u = u - J \ fu when JᵀJ is not cached (QR path); otherwise u = u - JᵀJ \ Jᵀf
125+
if cache.JᵀJ === nothing
126+
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du),
127+
p, reltol = cache.abstol)
128+
else
129+
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),
130+
linu = _vec(du), p, reltol = cache.abstol)
131+
end
108132
cache.linsolve = linres.cache
109133
@. u = u - du
110134
f(cache.fu_new, u, p)
@@ -125,14 +149,22 @@ function perform_step!(cache::GaussNewtonCache{false})
125149

126150
cache.J = jacobian!!(cache.J, cache)
127151

128-
cache.JᵀJ = cache.J' * cache.J
129-
cache.Jᵀf = cache.J' * fu1
152+
if cache.JᵀJ !== nothing
153+
cache.JᵀJ = cache.J' * cache.J
154+
cache.Jᵀf = cache.J' * fu1
155+
end
156+
130157
# u = u - J \ fu when JᵀJ is not cached; otherwise u = u - JᵀJ \ Jᵀf
131158
if linsolve === nothing
132159
cache.du = fu1 / cache.J
133160
else
134-
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.JᵀJ),
135-
b = _vec(cache.Jᵀf), linu = _vec(cache.du), p, reltol = cache.abstol)
161+
if cache.JᵀJ === nothing
162+
linres = dolinsolve(alg.precs, linsolve; A = cache.J, b = _vec(fu1),
163+
linu = _vec(cache.du), p, reltol = cache.abstol)
164+
else
165+
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.JᵀJ),
166+
b = _vec(cache.Jᵀf), linu = _vec(cache.du), p, reltol = cache.abstol)
167+
end
136168
cache.linsolve = linres.cache
137169
end
138170
cache.u = @. u - cache.du # `u` might not support mutation

0 commit comments

Comments
 (0)