Skip to content

Commit 0e3efd7

Browse files
committed
Fix GN
1 parent 21e9ed4 commit 0e3efd7

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

src/gaussnewton.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
108108
return GaussNewtonCache{iip}(f, alg, u, u_cache, fu, fu_cache, du, dfu, p, uf,
109109
linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default,
110110
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2,
111-
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)), trace)
111+
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), trace)
112112
end
113113

114114
function perform_step!(cache::GaussNewtonCache{iip}) where {iip}
@@ -117,14 +117,14 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip}
117117
# Use normal form to solve the Linear Problem
118118
if cache.JᵀJ !== nothing
119119
__update_JᵀJ!(Val{iip}(), cache, :JᵀJ, cache.J)
120-
__update_Jᵀf!(Val{iip}(), cache, :Jᵀf, :JᵀJ, cache.J, cache.fu1)
120+
__update_Jᵀf!(Val{iip}(), cache, :Jᵀf, :JᵀJ, cache.J, cache.fu)
121121
A, b = __maybe_symmetric(cache.JᵀJ), _vec(cache.Jᵀf)
122122
else
123123
A, b = cache.J, _vec(cache.fu)
124124
end
125125

126-
linres = dolinsolve(alg.precs, linsolve; A, b, linu = _vec(du), cache.p,
127-
reltol = cache.abstol)
126+
linres = dolinsolve(cache.alg.precs, cache.linsolve; A, b, linu = _vec(cache.du),
127+
cache.p, reltol = cache.abstol)
128128
cache.linsolve = linres.cache
129129
cache.du = _restructure(cache.du, linres.u)
130130

@@ -136,7 +136,7 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip}
136136
check_and_update!(cache.tc_cache_1, cache, cache.fu, cache.u, cache.u_cache)
137137
if !cache.force_stop
138138
@bb @. cache.dfu = cache.fu .- cache.dfu
139-
check_and_update!(cache.tc_cache_2, cache, cache.dfu, cache.u, cache.u_prev)
139+
check_and_update!(cache.tc_cache_2, cache, cache.dfu, cache.u, cache.u_cache)
140140
end
141141

142142
@bb copyto!(cache.u_cache, cache.u)

src/utils.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,15 @@ function evaluate_f(prob::Union{NonlinearProblem{uType, iip},
188188
return fu
189189
end
190190

191+
function evaluate_f(f, u, p, ::Val{iip}; fu = nothing) where {iip}
192+
if iip
193+
f(fu, u, p)
194+
return fu
195+
else
196+
return f(u, p)
197+
end
198+
end
199+
191200
function evaluate_f(cache, u, p)
192201
if isinplace(cache)
193202
cache.prob.f(get_fu(cache), u, p)

0 commit comments

Comments
 (0)