Skip to content

Commit 77164c7

Browse files
committed
Fix matrix resizing
1 parent 5052a6e commit 77164c7

File tree

3 files changed

+14
-19
lines changed

3 files changed

+14
-19
lines changed

src/gaussnewton.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ function perform_step!(cache::GaussNewtonCache{true})
113113
jacobian!!(J, cache)
114114

115115
if JᵀJ !== nothing
116-
__matmul!(JᵀJ, J', J)
117-
__matmul!(Jᵀf, J', fu1)
116+
__update_JᵀJ!(Val{true}(), cache, :JᵀJ, J)
117+
__update_Jᵀf!(Val{true}(), cache, :Jᵀf, :JᵀJ, J, fu1)
118118
end
119119

120120
# u = u - JᵀJ \ Jᵀfu
@@ -151,8 +151,8 @@ function perform_step!(cache::GaussNewtonCache{false})
151151
cache.J = jacobian!!(cache.J, cache)
152152

153153
if cache.JᵀJ !== nothing
154-
cache.JᵀJ = cache.J' * cache.J
155-
cache.Jᵀf = cache.J' * fu1
154+
__update_JᵀJ!(Val{false}(), cache, :JᵀJ, cache.J)
155+
__update_Jᵀf!(Val{false}(), cache, :Jᵀf, :JᵀJ, cache.J, fu1)
156156
end
157157

158158
# u = u - J \ fu

src/jacobian.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ __maybe_symmetric(x::Number) = x
180180
__maybe_symmetric(x::StaticArray) = x
181181
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x
182182
__maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x
183-
__maybe_symmetric(x::KrylovJᵀJ) = x
183+
__maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ
184184

185185
## Special Handling for Scalars
186186
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p,
@@ -204,16 +204,16 @@ function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu)
204204
return __update_Jᵀf!(iip, cache, sym1, sym2, getproperty(cache, sym2), J, fu)
205205
end
206206
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
207-
return setproperty!(cache, sym1, J' * fu)
207+
return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), J' * fu))
208208
end
209209
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
210-
return mul!(getproperty(cache, sym1), J', fu)
210+
return mul!(_vec(getproperty(cache, sym1)), J', fu)
211211
end
212212
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu)
213-
return setproperty!(cache, sym1, H.Jᵀ * fu)
213+
return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), H.Jᵀ * fu))
214214
end
215215
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu)
216-
return mul!(getproperty(cache, sym1), H.Jᵀ, fu)
216+
return mul!(_vec(getproperty(cache, sym1)), H.Jᵀ, fu)
217217
end
218218

219219
# Left-Right Multiplication

src/trustRegion.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -239,19 +239,16 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
239239
fu_prev = zero(fu1)
240240

241241
loss = get_loss(fu1)
242-
# uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
243-
# linsolve_kwargs)
244242
uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip);
245243
linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false))
244+
g = _restructure(fu1, g)
246245
linsolve = u isa Number ? nothing : __setup_linsolve(J, fu2, du, p, alg)
247246

248247
u_tmp = zero(u)
249248
u_cauchy = zero(u)
250249
u_gauss_newton = _mutable_zero(u)
251250

252251
loss_new = loss
253-
# H = zero(J' * J)
254-
# g = _mutable_zero(fu1)
255252
shrink_counter = 0
256253
fu_new = zero(fu1)
257254
make_new_J = true
@@ -351,9 +348,7 @@ function perform_step!(cache::TrustRegionCache{true})
351348
if cache.make_new_J
352349
jacobian!!(J, cache)
353350
__update_JᵀJ!(Val{true}(), cache, :H, J)
354-
# mul!(cache.H, J', J)
355-
__update_Jᵀf!(Val{true}(), cache, :g, :H, J, fu)
356-
# mul!(_vec(cache.g), J', _vec(fu))
351+
__update_Jᵀf!(Val{true}(), cache, :g, :H, J, _vec(fu))
357352
cache.stats.njacs += 1
358353

359354
# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
@@ -383,7 +378,7 @@ function perform_step!(cache::TrustRegionCache{false})
383378
if make_new_J
384379
J = jacobian!!(cache.J, cache)
385380
__update_JᵀJ!(Val{false}(), cache, :H, J)
386-
__update_Jᵀf!(Val{false}(), cache, :g, :H, J, fu)
381+
__update_Jᵀf!(Val{false}(), cache, :g, :H, J, _vec(fu))
387382
cache.stats.njacs += 1
388383

389384
if cache.linsolve === nothing
@@ -420,8 +415,8 @@ function retrospective_step!(cache::TrustRegionCache)
420415
cache.H = J' * J
421416
cache.g = J' * fu
422417
else
423-
mul!(cache.H, J', J)
424-
mul!(cache.g, J', fu)
418+
__update_JᵀJ!(Val{isinplace(cache)}(), cache, :H, J)
419+
__update_Jᵀf!(Val{isinplace(cache)}(), cache, :g, :H, J, fu)
425420
end
426421
cache.stats.njacs += 1
427422
@unpack H, g, du = cache

0 commit comments

Comments
 (0)