Skip to content

Commit bf032d3

Browse files
committed
don't apply MM twice, simpler test
1 parent e8233a5 commit bf032d3

File tree

2 files changed

+31
-39
lines changed

2 files changed

+31
-39
lines changed

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,14 @@ function calc_J!(J, integrator, cache, next_step::Bool = false)
177177

178178
# we need to set all nzval to a non-zero number
179179
# otherwise in the following line any zero gets interpreted as a structural zero
180-
integrator.f.jac_prototype.nzval .= 1.0
181-
J .= 1.0 .* integrator.f.jac_prototype
182-
J.nzval .= 0.0
183-
f.jac(J, uprev, p, t)
184-
MM = integrator.f.mass_matrix isa UniformScaling ?
185-
integrator.f.mass_matrix(length(integrator.u)) : integrator.f.mass_matrix
186-
J .= J .+ MM
180+
if !isnothing(integrator.f.jac_prototype)
181+
integrator.f.jac_prototype.nzval .= 1.0
182+
J .= 1.0 .* integrator.f.jac_prototype
183+
J.nzval .= 0.0
184+
f.jac(J, duprev, uprev, p, uf.α * uf.invγdt, t)
185+
else
186+
f.jac(J, duprev, uprev, p, uf.α * uf.invγdt, t)
187+
end
187188
else
188189
@unpack du1, uf, jac_config = cache
189190
# using `dz` as temporary array
@@ -199,12 +200,14 @@ function calc_J!(J, integrator, cache, next_step::Bool = false)
199200

200201
# we need to set all nzval to a non-zero number
201202
# otherwise in the following line any zero gets interpreted as a structural zero
202-
integrator.f.jac_prototype.nzval .= 1.0
203-
J .= 1.0 .* integrator.f.jac_prototype
204-
J.nzval .= 0.0
205-
f.jac(J, uprev, p, t)
206-
MM = integrator.f.mass_matrix isa UniformScaling ? integrator.f.mass_matrix(length(integrator.u)) : integrator.f.mass_matrix
207-
J .= J .+ MM
203+
if !isnothing(integrator.f.jac_prototype)
204+
integrator.f.jac_prototype.nzval .= 1.0
205+
J .= 1.0 .* integrator.f.jac_prototype
206+
J.nzval .= 0.0
207+
f.jac(J, uprev, p, t)
208+
else
209+
f.jac(J, uprev, p, t)
210+
end
208211
else
209212
@unpack du1, uf, jac_config = cache
210213
uf.f = nlsolve_f(f, alg)

test/interface/sparsediff_tests.jl

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -87,35 +87,24 @@ end
8787

8888
# test for https://github.com/SciML/OrdinaryDiffEq.jl/issues/2653#issuecomment-2778430025
8989

90-
function sparse_f!(du, u, p, t)
91-
du[1] = u[1] + u[2]
92-
du[2] = u[3]^2
93-
return du[3] = u[1]^2
94-
end
95-
96-
backend = AutoSparse(
97-
AutoForwardDiff();
98-
sparsity_detector = TracerSparsityDetector(),
99-
coloring_algorithm = GreedyColoringAlgorithm()
100-
)
101-
102-
u = ones(3)
103-
du = zero(u)
104-
p = t = nothing
105-
106-
prep = DI.prepare_jacobian(
107-
sparse_f!, du, backend, u, DI.Constant(p), DI.Constant(t))
108-
# this is what the user may typically provide to the ODE problem
90+
using LinearAlgebra, SparseArrays
91+
using OrdinaryDiffEq
10992

110-
function inplace_jac!(J, u, p, t)
111-
return DI.jacobian!(
112-
sparse_f!, zeros(3), J, prep, backend, u, DI.Constant(p), DI.Constant(t))
93+
function f(du, u, p, t)
94+
du[1] = u[1]
95+
return du
11396
end
11497

115-
jac_prototype = similar(sparsity_pattern(prep), eltype(u))
98+
function jac(J::SparseMatrixCSC, u, p, t)
99+
@assert nnz(J) == 1 # mirrors the strict behavior of SparseMatrixColorings
100+
nonzeros(J)[1] = 1
101+
return J
102+
end
116103

117-
ode_f = ODEFunction(sparse_f!, jac = inplace_jac!, jac_prototype = jac_prototype)
118-
prob = ODEProblem(ode_f, [1, 1, 1], (0.0, 1.0))
104+
u0 = ones(10)
105+
jac_prototype = sparse(Diagonal(vcat(1, zeros(9))))
119106

120-
@test_no_warn sol = solve(prob, Rodas5())
107+
fun = ODEFunction(f; jac, jac_prototype)
108+
prob = ODEProblem(fun, u0, (0.0, 1.0))
109+
@test_nowarn sol = solve(prob, Rodas4(); reltol = 1e-8, abstol = 1e-8)
121110

0 commit comments

Comments
 (0)