Skip to content

Commit 3c5f725

Browse files
committed
Diagonal Broyden Update Implementation
1 parent 78b0f89 commit 3c5f725

File tree

4 files changed

+76
-33
lines changed

4 files changed

+76
-33
lines changed

src/broyden.jl

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ end
8787
p
8888
uf
8989
J⁻¹
90+
J⁻¹_cache
9091
J⁻¹dfu
9192
inv_alpha
9293
alpha_initial
@@ -123,12 +124,23 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralBroyd
123124
alg = get_concrete_algorithm(alg_, prob)
124125
uf, _, J, fu_cache, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
125126
lininit = Val(false))
126-
J⁻¹ = J
127-
else
127+
if UR === :diagonal
128+
J⁻¹_cache = J
129+
J⁻¹ = __diag(J)
130+
else
131+
J⁻¹_cache = nothing
132+
J⁻¹ = J
133+
end
134+
elseif IJ === :identity
128135
alg = alg_
129136
@bb du = similar(u)
130-
uf, fu_cache, jac_cache = nothing, nothing, nothing
131-
J⁻¹ = __init_identity_jacobian(u, fu, inv_alpha)
137+
uf, fu_cache, jac_cache, J⁻¹_cache = nothing, nothing, nothing, nothing
138+
if UR === :diagonal
139+
J⁻¹ = one.(fu)
140+
@bb J⁻¹ .*= inv_alpha
141+
else
142+
J⁻¹ = __init_identity_jacobian(u, fu, inv_alpha)
143+
end
132144
end
133145

134146
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(real(eltype(u)))) :
@@ -145,9 +157,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralBroyd
145157
uses_jac_inverse = Val(true), kwargs...)
146158

147159
return GeneralBroydenCache{iip, IJ, UR}(f, alg, u, u_cache, du, fu, fu_cache, dfu, p,
148-
uf, J⁻¹, J⁻¹dfu, inv_alpha, alg.alpha, false, 0, alg.max_resets, maxiters,
149-
internalnorm, ReturnCode.Default, abstol, reltol, reset_tolerance, reset_check,
150-
jac_cache, prob, NLStats(1, 0, 0, 0, 0),
160+
uf, J⁻¹, J⁻¹_cache, J⁻¹dfu, inv_alpha, alg.alpha, false, 0, alg.max_resets,
161+
maxiters, internalnorm, ReturnCode.Default, abstol, reltol, reset_tolerance,
162+
reset_check, jac_cache, prob, NLStats(1, 0, 0, 0, 0),
151163
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
152164
end
153165

@@ -158,7 +170,11 @@ function perform_step!(cache::GeneralBroydenCache{iip, IJ, UR}) where {iip, IJ,
158170
cache.J⁻¹ = __safe_inv(jacobian!!(cache.J⁻¹, cache)) # This allocates
159171
end
160172

161-
@bb cache.du = cache.J⁻¹ × vec(cache.fu)
173+
if __isdiag(cache.J⁻¹)
174+
@bb @. cache.du = cache.J⁻¹ * cache.fu
175+
else
176+
@bb cache.du = cache.J⁻¹ × vec(cache.fu)
177+
end
162178
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
163179
@bb axpy!(-α, cache.du, cache.u)
164180

@@ -179,7 +195,12 @@ function perform_step!(cache::GeneralBroydenCache{iip, IJ, UR}) where {iip, IJ,
179195
return nothing
180196
end
181197
if IJ === :true_jacobian
182-
cache.J⁻¹ = __safe_inv(jacobian!!(cache.J⁻¹, cache))
198+
if __isdiag(cache.J⁻¹)
199+
cache.J⁻¹_cache = __safe_inv(jacobian!!(cache.J⁻¹_cache, cache))
200+
cache.J⁻¹ = __get_diagonal!!(cache.J⁻¹, cache.J⁻¹_cache)
201+
else
202+
cache.J⁻¹ = __safe_inv(jacobian!!(cache.J⁻¹, cache))
203+
end
183204
else
184205
cache.inv_alpha = __initial_inv_alpha(cache.inv_alpha, cache.alpha_initial,
185206
cache.u, cache.fu, cache.internalnorm)
@@ -188,18 +209,26 @@ function perform_step!(cache::GeneralBroydenCache{iip, IJ, UR}) where {iip, IJ,
188209
cache.resets += 1
189210
else
190211
@bb cache.du .*= -1
191-
@bb cache.J⁻¹dfu = cache.J⁻¹ × vec(cache.dfu)
192212
if UR === :good_broyden
213+
@bb cache.J⁻¹dfu = cache.J⁻¹ × vec(cache.dfu)
193214
@bb cache.u_cache = transpose(cache.J⁻¹) × vec(cache.du)
194215
denom = dot(cache.du, cache.J⁻¹dfu)
195216
@bb @. cache.du = (cache.du - cache.J⁻¹dfu) /
196217
ifelse(iszero(denom), T(1e-5), denom)
197218
@bb cache.J⁻¹ += vec(cache.du) × transpose(_vec(cache.u_cache))
198219
elseif UR === :bad_broyden
220+
@bb cache.J⁻¹dfu = cache.J⁻¹ × vec(cache.dfu)
199221
dfu_norm = cache.internalnorm(cache.dfu)^2
200222
@bb @. cache.du = (cache.du - cache.J⁻¹dfu) /
201223
ifelse(iszero(dfu_norm), T(1e-5), dfu_norm)
202224
@bb cache.J⁻¹ += vec(cache.du) × transpose(_vec(cache.dfu))
225+
elseif UR === :diagonal
226+
@bb @. cache.J⁻¹dfu = cache.du * cache.J⁻¹ * cache.dfu
227+
denom = sum(cache.J⁻¹dfu)
228+
@bb @. cache.J⁻¹ += (cache.du - cache.J⁻¹ * cache.dfu) * cache.du * cache.J⁻¹ /
229+
ifelse(iszero(denom), T(1e-5), denom)
230+
else
231+
error("update_rule = Val(:$(UR)) is not implemented for Broyden.")
203232
end
204233
end
205234

src/klement.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ function perform_step!(cache::GeneralKlementCache{iip, IJ}) where {iip, IJ}
198198
cache.resets += 1
199199
end
200200

201-
if cache.J isa AbstractVector || cache.J isa Number
201+
if __isdiag(cache.J)
202202
@bb @. cache.du = cache.fu / cache.J
203203
else
204204
# u = u - J \ fu
@@ -223,7 +223,7 @@ function perform_step!(cache::GeneralKlementCache{iip, IJ}) where {iip, IJ}
223223

224224
# Update the Jacobian
225225
@bb cache.du .*= -1
226-
if cache.J isa AbstractVector || cache.J isa Number
226+
if __isdiag(cache.J)
227227
@bb @. cache.Jdu = (cache.J^2) * (cache.du^2)
228228
@bb @. cache.J += ((cache.fu - cache.fu_cache_2 - cache.J * cache.du) /
229229
ifelse(iszero(cache.Jdu), T(1e-5), cache.Jdu)) * cache.du *

src/trace.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,9 @@ function NonlinearSolveTraceEntry(iteration, fu, δu, J, u)
134134
end
135135

136136
__cond(J::AbstractMatrix) = cond(J)
137+
__cond(J::SVector) = __cond(Diagonal(MVector(J)))
137138
__cond(J::AbstractVector) = __cond(Diagonal(J))
139+
__cond(J::ApplyArray) = __cond(J.f(J.args...))
138140
__cond(J) = -1 # Covers cases where `J` is a Operator, nothing, etc.
139141

140142
__copy(x::AbstractArray) = copy(x)

src/utils.jl

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,36 @@ LazyArrays.applied_ndims(::typeof(__zero), x) = ndims(x)
399399
LazyArrays.applied_size(::typeof(__zero), x) = size(x)
400400
LazyArrays.applied_axes(::typeof(__zero), x) = axes(x)
401401

402+
# Safe Inverse: Try to use `inv` but if lu fails use `pinv`
403+
@inline __safe_inv(A::Number) = pinv(A)
404+
@inline __safe_inv(A::AbstractMatrix) = pinv(A)
405+
@inline __safe_inv(A::AbstractVector) = __safe_inv(Diagonal(A)).diag
406+
@inline __safe_inv(A::ApplyArray) = __safe_inv(A.f(A.args...))
407+
@inline function __safe_inv(A::StridedMatrix{T}) where {T}
408+
LinearAlgebra.checksquare(A)
409+
if istriu(A)
410+
A_ = UpperTriangular(A)
411+
issingular = any(iszero, @view(A_[diagind(A_)]))
412+
!issingular && return triu!(parent(inv(A_)))
413+
elseif istril(A)
414+
A_ = LowerTriangular(A)
415+
issingular = any(iszero, @view(A_[diagind(A_)]))
416+
!issingular && return tril!(parent(inv(A_)))
417+
else
418+
F = lu(A; check = false)
419+
if issuccess(F)
420+
Ai = LinearAlgebra.inv!(F)
421+
return convert(typeof(parent(Ai)), Ai)
422+
end
423+
end
424+
return pinv(A)
425+
end
426+
427+
LazyArrays.applied_eltype(::typeof(__safe_inv), x) = eltype(x)
428+
LazyArrays.applied_ndims(::typeof(__safe_inv), x) = ndims(x)
429+
LazyArrays.applied_size(::typeof(__safe_inv), x) = size(x)
430+
LazyArrays.applied_axes(::typeof(__safe_inv), x) = axes(x)
431+
402432
# SparseAD --> NonSparseAD
403433
@inline __get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff()
404434
@inline __get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff()
@@ -527,24 +557,6 @@ end
527557
@inline __diag(x::AbstractVector) = x
528558
@inline __diag(x::Number) = x
529559

530-
# Safe Inverse: Try to use `inv` but if lu fails use `pinv`
531-
__safe_inv(A::AbstractMatrix) = pinv(A)
532-
function __safe_inv(A::StridedMatrix{T}) where {T}
533-
LinearAlgebra.checksquare(A)
534-
if istriu(A)
535-
A_ = UpperTriangular(A)
536-
issingular = any(iszero, @view(A_[diagind(A_)]))
537-
!issingular && return triu!(parent(inv(A_)))
538-
elseif istril(A)
539-
A_ = LowerTriangular(A)
540-
issingular = any(iszero, @view(A_[diagind(A_)]))
541-
!issingular && return tril!(parent(inv(A_)))
542-
else
543-
F = lu(A; check = false)
544-
if issuccess(F)
545-
Ai = LinearAlgebra.inv!(F)
546-
return convert(typeof(parent(Ai)), Ai)
547-
end
548-
end
549-
return pinv(A)
550-
end
560+
@inline __isdiag(::AbstractVector) = true
561+
@inline __isdiag(::Number) = true
562+
@inline __isdiag(::AbstractMatrix) = false

0 commit comments

Comments
 (0)