Skip to content

Commit 0a96e2f

Browse files
committed
check diagonal correctly
1 parent f59140f commit 0a96e2f

File tree

3 files changed

+5
-9
lines changed

3 files changed

+5
-9
lines changed

src/broyden.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,15 +164,15 @@ function perform_step!(cache::GeneralBroydenCache{iip, IJ, UR}) where {iip, IJ,
164164
T = eltype(cache.u)
165165

166166
if IJ === :true_jacobian && cache.stats.nsteps == 0
167-
if __isdiag(cache.J⁻¹) && cache.J⁻¹_cache !== nothing
167+
if UR === :diagonal
168168
cache.J⁻¹_cache = __safe_inv(jacobian!!(cache.J⁻¹_cache, cache))
169169
cache.J⁻¹ = __get_diagonal!!(cache.J⁻¹, cache.J⁻¹_cache)
170170
else
171171
cache.J⁻¹ = __safe_inv(jacobian!!(cache.J⁻¹, cache))
172172
end
173173
end
174174

175-
if __isdiag(cache.J⁻¹)
175+
if UR === :diagonal
176176
@bb @. cache.du = cache.J⁻¹ * cache.fu
177177
else
178178
@bb cache.du = cache.J⁻¹ × vec(cache.fu)
@@ -197,7 +197,7 @@ function perform_step!(cache::GeneralBroydenCache{iip, IJ, UR}) where {iip, IJ,
197197
return nothing
198198
end
199199
if IJ === :true_jacobian
200-
if __isdiag(cache.J⁻¹) && cache.J⁻¹_cache !== nothing
200+
if UR === :diagonal
201201
cache.J⁻¹_cache = __safe_inv(jacobian!!(cache.J⁻¹_cache, cache))
202202
cache.J⁻¹ = __get_diagonal!!(cache.J⁻¹, cache.J⁻¹_cache)
203203
else

src/klement.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ function perform_step!(cache::GeneralKlementCache{iip, IJ}) where {iip, IJ}
198198
cache.resets += 1
199199
end
200200

201-
if __isdiag(cache.J)
201+
if IJ === :true_jacobian_diagonal || IJ === :identity
202202
@bb @. cache.du = cache.fu / cache.J
203203
else
204204
# u = u - J \ fu
@@ -223,7 +223,7 @@ function perform_step!(cache::GeneralKlementCache{iip, IJ}) where {iip, IJ}
223223

224224
# Update the Jacobian
225225
@bb cache.du .*= -1
226-
if __isdiag(cache.J)
226+
if IJ === :true_jacobian_diagonal || IJ === :identity
227227
@bb @. cache.Jdu = (cache.J^2) * (cache.du^2)
228228
@bb @. cache.J += ((cache.fu - cache.fu_cache_2 - cache.J * cache.du) /
229229
ifelse(iszero(cache.Jdu), T(1e-5), cache.Jdu)) * cache.du *

src/utils.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,3 @@ end
556556
@inline __diag(x::AbstractMatrix) = diag(x)
557557
@inline __diag(x::AbstractVector) = x
558558
@inline __diag(x::Number) = x
559-
560-
@inline __isdiag(::AbstractVector) = true
561-
@inline __isdiag(::Number) = true
562-
@inline __isdiag(::AbstractMatrix) = false

0 commit comments

Comments
 (0)