8787 p
8888 uf
8989 J⁻¹
90+ J⁻¹_cache
9091 J⁻¹dfu
9192 inv_alpha
9293 alpha_initial
@@ -123,12 +124,23 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralBroyd
123124 alg = get_concrete_algorithm (alg_, prob)
124125 uf, _, J, fu_cache, jac_cache, du = jacobian_caches (alg, f, u, p, Val (iip);
125126 lininit = Val (false ))
126- J⁻¹ = J
127- else
127+ if UR === :diagonal
128+ J⁻¹_cache = J
129+ J⁻¹ = __diag (J)
130+ else
131+ J⁻¹_cache = nothing
132+ J⁻¹ = J
133+ end
134+ elseif IJ === :identity
128135 alg = alg_
129136 @bb du = similar (u)
130- uf, fu_cache, jac_cache = nothing , nothing , nothing
131- J⁻¹ = __init_identity_jacobian (u, fu, inv_alpha)
137+ uf, fu_cache, jac_cache, J⁻¹_cache = nothing , nothing , nothing , nothing
138+ if UR === :diagonal
139+ J⁻¹ = one .(fu)
140+ @bb J⁻¹ .*= inv_alpha
141+ else
142+ J⁻¹ = __init_identity_jacobian (u, fu, inv_alpha)
143+ end
132144 end
133145
134146 reset_tolerance = alg. reset_tolerance === nothing ? sqrt (eps (real (eltype (u)))) :
@@ -145,9 +157,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralBroyd
145157 uses_jac_inverse = Val (true ), kwargs... )
146158
147159 return GeneralBroydenCache {iip, IJ, UR} (f, alg, u, u_cache, du, fu, fu_cache, dfu, p,
148- uf, J⁻¹, J⁻¹dfu, inv_alpha, alg. alpha, false , 0 , alg. max_resets, maxiters ,
149- internalnorm, ReturnCode. Default, abstol, reltol, reset_tolerance, reset_check ,
150- jac_cache, prob, NLStats (1 , 0 , 0 , 0 , 0 ),
160+ uf, J⁻¹, J⁻¹_cache, J⁻¹ dfu, inv_alpha, alg. alpha, false , 0 , alg. max_resets,
161+ maxiters, internalnorm, ReturnCode. Default, abstol, reltol, reset_tolerance,
162+ reset_check, jac_cache, prob, NLStats (1 , 0 , 0 , 0 , 0 ),
151163 init_linesearch_cache (alg. linesearch, f, u, p, fu, Val (iip)), tc_cache, trace)
152164end
153165
@@ -158,7 +170,11 @@ function perform_step!(cache::GeneralBroydenCache{iip, IJ, UR}) where {iip, IJ,
158170 cache. J⁻¹ = __safe_inv (jacobian!! (cache. J⁻¹, cache)) # This allocates
159171 end
160172
161- @bb cache. du = cache. J⁻¹ × vec (cache. fu)
173+ if __isdiag (cache. J⁻¹)
174+ @bb @. cache. du = cache. J⁻¹ * cache. fu
175+ else
176+ @bb cache. du = cache. J⁻¹ × vec (cache. fu)
177+ end
162178 α = perform_linesearch! (cache. ls_cache, cache. u, cache. du)
163179 @bb axpy! (- α, cache. du, cache. u)
164180
@@ -179,7 +195,12 @@ function perform_step!(cache::GeneralBroydenCache{iip, IJ, UR}) where {iip, IJ,
179195 return nothing
180196 end
181197 if IJ === :true_jacobian
182- cache. J⁻¹ = __safe_inv (jacobian!! (cache. J⁻¹, cache))
198+ if __isdiag (cache. J⁻¹)
199+ cache. J⁻¹_cache = __safe_inv (jacobian!! (cache. J⁻¹_cache, cache))
200+ cache. J⁻¹ = __get_diagonal!! (cache. J⁻¹, cache. J⁻¹_cache)
201+ else
202+ cache. J⁻¹ = __safe_inv (jacobian!! (cache. J⁻¹, cache))
203+ end
183204 else
184205 cache. inv_alpha = __initial_inv_alpha (cache. inv_alpha, cache. alpha_initial,
185206 cache. u, cache. fu, cache. internalnorm)
@@ -188,18 +209,26 @@ function perform_step!(cache::GeneralBroydenCache{iip, IJ, UR}) where {iip, IJ,
188209 cache. resets += 1
189210 else
190211 @bb cache. du .*= - 1
191- @bb cache. J⁻¹dfu = cache. J⁻¹ × vec (cache. dfu)
192212 if UR === :good_broyden
213+ @bb cache. J⁻¹dfu = cache. J⁻¹ × vec (cache. dfu)
193214 @bb cache. u_cache = transpose (cache. J⁻¹) × vec (cache. du)
194215 denom = dot (cache. du, cache. J⁻¹dfu)
195216 @bb @. cache. du = (cache. du - cache. J⁻¹dfu) /
196217 ifelse (iszero (denom), T (1e-5 ), denom)
197218 @bb cache. J⁻¹ += vec (cache. du) × transpose (_vec (cache. u_cache))
198219 elseif UR === :bad_broyden
220+ @bb cache. J⁻¹dfu = cache. J⁻¹ × vec (cache. dfu)
199221 dfu_norm = cache. internalnorm (cache. dfu)^ 2
200222 @bb @. cache. du = (cache. du - cache. J⁻¹dfu) /
201223 ifelse (iszero (dfu_norm), T (1e-5 ), dfu_norm)
202224 @bb cache. J⁻¹ += vec (cache. du) × transpose (_vec (cache. dfu))
225+ elseif UR === :diagonal
226+ @bb @. cache. J⁻¹dfu = cache. du * cache. J⁻¹ * cache. dfu
227+ denom = sum (cache. J⁻¹dfu)
228+ @bb @. cache. J⁻¹ += (cache. du - cache. J⁻¹ * cache. dfu) * cache. du * cache. J⁻¹ /
229+ ifelse (iszero (denom), T (1e-5 ), denom)
230+ else
231+ error (" update_rule = Val(:$(UR) ) is not implemented for Broyden." )
203232 end
204233 end
205234
0 commit comments