Skip to content

Commit a1278ad

Browse files
committed
add fixes for jac_prototype mismatch sparsity
1 parent 63bc6fd commit a1278ad

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ function prepare_user_sparsity(ad_alg, prob)
106106
sparsity = prob.f.sparsity
107107

108108
if !isnothing(sparsity) && !(ad_alg isa AutoSparse)
109-
if sparsity isa SparseMatrixCSC
109+
if sparsity isa SparseMatrixCSC && !DiffEqBase.has_jac(prob.f)
110110
if prob.f.mass_matrix isa UniformScaling
111111
idxs = diagind(sparsity)
112112
@. @view(sparsity[idxs]) = 1

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,18 @@ function calc_J!(J, integrator, cache, next_step::Bool = false)
172172
if DiffEqBase.has_jac(f)
173173
duprev = integrator.duprev
174174
uf = cache.uf
175-
f.jac(J, duprev, uprev, p, uf.α * uf.invγdt, t)
175+
# need to do some jank here to account for sparsity pattern of W
176+
# https://github.com/SciML/OrdinaryDiffEq.jl/issues/2653
177+
178+
# we need to set all nzval to a non-zero number
179+
# otherwise in the following line any zero gets interpreted as a structural zero
180+
integrator.f.jac_prototype.nzval .= 1.0
181+
J .= 1.0 .* integrator.f.jac_prototype
182+
J.nzval .= 0.0
183+
f.jac(J, uprev, p, t)
184+
MM = integrator.f.mass_matrix isa UniformScaling ?
185+
integrator.f.mass_matrix(length(integrator.u)) : integrator.f.mass_matrix
186+
J .= J .+ MM
176187
else
177188
@unpack du1, uf, jac_config = cache
178189
# using `dz` as temporary array
@@ -183,7 +194,17 @@ function calc_J!(J, integrator, cache, next_step::Bool = false)
183194
end
184195
else
185196
if DiffEqBase.has_jac(f)
197+
# need to do some jank here to account for sparsity pattern of W
198+
# https://github.com/SciML/OrdinaryDiffEq.jl/issues/2653
199+
200+
# we need to set all nzval to a non-zero number
201+
# otherwise in the following line any zero gets interpreted as a structural zero
202+
integrator.f.jac_prototype.nzval .= 1.0
203+
J .= 1.0 .* integrator.f.jac_prototype
204+
J.nzval .= 0.0
186205
f.jac(J, uprev, p, t)
206+
MM = integrator.f.mass_matrix isa UniformScaling ? integrator.f.mass_matrix(length(integrator.u)) : integrator.f.mass_matrix
207+
J .= J .+ MM
187208
else
188209
@unpack du1, uf, jac_config = cache
189210
uf.f = nlsolve_f(f, alg)

0 commit comments

Comments
 (0)