@@ -46,7 +46,7 @@ function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
46
46
return GaussNewton {CJ} (ad, alg. linsolve, alg. precs)
47
47
end
48
48
49
- function GaussNewton (; concrete_jac = nothing , linsolve = CholeskyFactorization () ,
49
+ function GaussNewton (; concrete_jac = nothing , linsolve = nothing ,
50
50
precs = DEFAULT_PRECS, adkwargs... )
51
51
ad = default_adargs_to_adtype (; adkwargs... )
52
52
return GaussNewton {_unwrap_val(concrete_jac)} (ad, linsolve, precs)
@@ -81,15 +81,31 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
81
81
kwargs... ) where {uType, iip}
82
82
alg = get_concrete_algorithm (alg_, prob)
83
83
@unpack f, u0, p = prob
84
+
85
+ # Use QR if the user did not specify a linear solver
86
+ if alg. linsolve === nothing || alg. linsolve isa QRFactorization ||
87
+ alg. linsolve isa FastQRFactorization
88
+ linsolve_with_JᵀJ = Val (false )
89
+ else
90
+ linsolve_with_JᵀJ = Val (true )
91
+ end
92
+
84
93
u = alias_u0 ? u0 : deepcopy (u0)
85
94
if iip
86
95
fu1 = f. resid_prototype === nothing ? zero (u) : f. resid_prototype
87
96
f (fu1, u, p)
88
97
else
89
98
fu1 = f (u, p)
90
99
end
91
- uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches (alg, f, u, p, Val (iip);
92
- linsolve_with_JᵀJ = Val (true ))
100
+
101
+ if SciMLBase. _unwrap_val (linsolve_with_JᵀJ)
102
+ uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches (alg, f, u, p,
103
+ Val (iip); linsolve_with_JᵀJ)
104
+ else
105
+ uf, linsolve, J, fu2, jac_cache, du = jacobian_caches (alg, f, u, p,
106
+ Val (iip); linsolve_with_JᵀJ)
107
+ JᵀJ, Jᵀf = nothing , nothing
108
+ end
93
109
94
110
return GaussNewtonCache {iip} (f, alg, u, fu1, fu2, zero (fu1), du, p, uf, linsolve, J,
95
111
JᵀJ, Jᵀf, jac_cache, false , maxiters, internalnorm, ReturnCode. Default, abstol,
99
115
function perform_step! (cache:: GaussNewtonCache{true} )
100
116
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
101
117
jacobian!! (J, cache)
102
- __matmul! (JᵀJ, J' , J)
103
- __matmul! (Jᵀf, J' , fu1)
104
118
105
- # u = u - J \ fu
106
- linres = dolinsolve (alg. precs, linsolve; A = __maybe_symmetric (JᵀJ), b = _vec (Jᵀf),
107
- linu = _vec (du), p, reltol = cache. abstol)
119
+ if JᵀJ != = nothing
120
+ __matmul! (JᵀJ, J' , J)
121
+ __matmul! (Jᵀf, J' , fu1)
122
+ end
123
+
124
+ # u = u - JᵀJ \ Jᵀfu
125
+ if cache. JᵀJ === nothing
126
+ linres = dolinsolve (alg. precs, linsolve; A = J, b = _vec (fu1), linu = _vec (du),
127
+ p, reltol = cache. abstol)
128
+ else
129
+ linres = dolinsolve (alg. precs, linsolve; A = __maybe_symmetric (JᵀJ), b = _vec (Jᵀf),
130
+ linu = _vec (du), p, reltol = cache. abstol)
131
+ end
108
132
cache. linsolve = linres. cache
109
133
@. u = u - du
110
134
f (cache. fu_new, u, p)
@@ -125,14 +149,22 @@ function perform_step!(cache::GaussNewtonCache{false})
125
149
126
150
cache. J = jacobian!! (cache. J, cache)
127
151
128
- cache. JᵀJ = cache. J' * cache. J
129
- cache. Jᵀf = cache. J' * fu1
152
+ if cache. JᵀJ != = nothing
153
+ cache. JᵀJ = cache. J' * cache. J
154
+ cache. Jᵀf = cache. J' * fu1
155
+ end
156
+
130
157
# u = u - J \ fu
131
158
if linsolve === nothing
132
159
cache. du = fu1 / cache. J
133
160
else
134
- linres = dolinsolve (alg. precs, linsolve; A = __maybe_symmetric (cache. JᵀJ),
135
- b = _vec (cache. Jᵀf), linu = _vec (cache. du), p, reltol = cache. abstol)
161
+ if cache. JᵀJ === nothing
162
+ linres = dolinsolve (alg. precs, linsolve; A = cache. J, b = _vec (fu1),
163
+ linu = _vec (cache. du), p, reltol = cache. abstol)
164
+ else
165
+ linres = dolinsolve (alg. precs, linsolve; A = __maybe_symmetric (cache. JᵀJ),
166
+ b = _vec (cache. Jᵀf), linu = _vec (cache. du), p, reltol = cache. abstol)
167
+ end
136
168
cache. linsolve = linres. cache
137
169
end
138
170
cache. u = @. u - cache. du # `u` might not support mutation
0 commit comments