Skip to content

Commit 63bc6fd

Browse files
committed
add sparsity mismatch test
1 parent da8d130 commit 63bc6fd

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

test/interface/sparsediff_tests.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ using OrdinaryDiffEq
33
using SparseArrays
44
using LinearAlgebra
55
using LinearSolve
6+
import DifferentiationInterface as DI
7+
using SparseConnectivityTracer
8+
using SparseMatrixColorings
69
using ADTypes
710
using Enzyme
811

@@ -82,3 +85,55 @@ for f in [f_oop, f_ip]
8285
end
8386
end
8487

88+
function sparse_f!(du, u, p, t)
89+
du[1] = u[1] + u[2]
90+
du[2] = u[3]^2
91+
return du[3] = u[1]^2
92+
end
93+
94+
backend_allow = AutoSparse(
95+
AutoForwardDiff();
96+
sparsity_detector = TracerSparsityDetector(),
97+
coloring_algorithm = GreedyColoringAlgorithm(; allow_denser = true)
98+
)
99+
100+
backend_no_allow = AutoSparse(
101+
AutoForwardDiff();
102+
sparsity_detector = TracerSparsityDetector(),
103+
coloring_algorithm = GreedyColoringAlgorithm()
104+
)
105+
106+
u = ones(3)
107+
du = zero(u)
108+
p = t = nothing
109+
110+
prep_allow = DI.prepare_jacobian(
111+
sparse_f!, du, backend_allow, u, DI.Constant(p), DI.Constant(t))
112+
prep_no_allow = DI.prepare_jacobian(
113+
sparse_f!, du, backend_no_allow, u, DI.Constant(p), DI.Constant(t))
114+
# this is what the user may typically provide to the ODE problem
115+
116+
function inplace_jac_allow!(J, u, p, t)
117+
return DI.jacobian!(
118+
sparse_f!, zeros(3), J, prep_allow, backend_allow, u, DI.Constant(p), DI.Constant(t))
119+
end
120+
121+
function inplace_jac_no_allow!(J, u, p, t)
122+
return DI.jacobian!(
123+
sparse_f!, zeros(3), J, prep_no_allow, backend_no_allow, u, DI.Constant(p), DI.Constant(t))
124+
end
125+
126+
jac_prototype = similar(sparsity_pattern(prep_allow), eltype(u))
127+
128+
ode_f_allow = ODEFunction(
129+
sparse_f!, jac = inplace_jac_allow!, jac_prototype = jac_prototype)
130+
prob_allow = ODEProblem(ode_f_allow, [1, 1, 1], (0.0, 1.0))
131+
132+
ode_f_no_allow = ODEFunction(
133+
sparse_f!, jac = inplace_jac_no_allow!, jac_prototype = jac_prototype)
134+
prob_no_allow = ODEProblem(ode_f_no_allow, [1, 1, 1], (0.0, 1.0))
135+
136+
sol = solve(prob_allow, Rodas5())
137+
138+
@test_throws DimensionMismatch sol=solve(prob_no_allow, Rodas5())
139+

0 commit comments

Comments
 (0)