Skip to content

Commit 954a799

Browse files
committed
Cleanup Normal Form Equation Construction
1 parent eadf16f commit 954a799

File tree

2 files changed

+13
-23
lines changed

2 files changed

+13
-23
lines changed

src/gaussnewton.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip}
116116

117117
# Use normal form to solve the Linear Problem
118118
if cache.JᵀJ !== nothing
119-
__update_JᵀJ!(Val{iip}(), cache, :JᵀJ, cache.J)
120-
__update_Jᵀf!(Val{iip}(), cache, :Jᵀf, :JᵀJ, cache.J, cache.fu)
119+
__update_JᵀJ!(cache, Val(:JᵀJ))
120+
__update_Jᵀf!(cache, Val(:JᵀJ))
121121
A, b = __maybe_symmetric(cache.JᵀJ), _vec(cache.Jᵀf)
122122
else
123123
A, b = cache.J, _vec(cache.fu)
@@ -148,6 +148,7 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip}
148148
return nothing
149149
end
150150

151+
# FIXME: Reinit `JᵀJ` operator if `p` is changed
151152
function __reinit_internal!(cache::GaussNewtonCache;
152153
termination_condition = get_termination_mode(cache.tc_cache_1), kwargs...)
153154
abstol, reltol, tc_cache_1 = init_termination_cache(cache.abstol, cache.reltol,

src/jacobian.jl

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -209,29 +209,18 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
209209
end
210210

211211
# Generic Handling of Krylov Methods for Normal Form Linear Solves
212-
# FIXME: Use MaybeInplace here for efficient matmuls
213-
function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J)
214-
return __update_JᵀJ!(iip, cache, sym, getproperty(cache, sym), J)
212+
function __update_JᵀJ!(cache::AbstractNonlinearSolveCache)
213+
if !(cache.JᵀJ isa KrylovJᵀJ)
214+
@bb cache.JᵀJ = transpose(cache.J) × cache.J
215+
end
215216
end
216-
__update_JᵀJ!(::Val{false}, cache, sym::Symbol, _, J) = setproperty!(cache, sym, J' * J)
217-
__update_JᵀJ!(::Val{true}, cache, sym::Symbol, _, J) = mul!(getproperty(cache, sym), J', J)
218-
__update_JᵀJ!(::Val{false}, cache, sym::Symbol, H::KrylovJᵀJ, J) = H
219-
__update_JᵀJ!(::Val{true}, cache, sym::Symbol, H::KrylovJᵀJ, J) = H
220217

221-
function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu)
222-
return __update_Jᵀf!(iip, cache, sym1, sym2, getproperty(cache, sym2), J, fu)
223-
end
224-
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
225-
return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), J' * fu))
226-
end
227-
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
228-
return mul!(_vec(getproperty(cache, sym1)), J', fu)
229-
end
230-
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu)
231-
return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), H.Jᵀ * fu))
232-
end
233-
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu)
234-
return mul!(_vec(getproperty(cache, sym1)), H.Jᵀ, fu)
218+
function __update_Jᵀf!(cache::AbstractNonlinearSolveCache)
219+
if cache.JᵀJ isa KrylovJᵀJ
220+
@bb cache.Jᵀf = cache.JᵀJ.Jᵀ × cache.fu
221+
else
222+
@bb cache.Jᵀf = transpose(cache.J) × vec(cache.fu)
223+
end
235224
end
236225

237226
# Left-Right Multiplication

0 commit comments

Comments
 (0)