@@ -3,6 +3,9 @@ using OrdinaryDiffEq
3
3
using SparseArrays
4
4
using LinearAlgebra
5
5
using LinearSolve
6
+ import DifferentiationInterface as DI
7
+ using SparseConnectivityTracer
8
+ using SparseMatrixColorings
6
9
using ADTypes
7
10
using Enzyme
8
11
@@ -82,3 +85,55 @@ for f in [f_oop, f_ip]
82
85
end
83
86
end
84
87
88
+ function sparse_f! (du, u, p, t)
89
+ du[1 ] = u[1 ] + u[2 ]
90
+ du[2 ] = u[3 ]^ 2
91
+ return du[3 ] = u[1 ]^ 2
92
+ end
93
+
94
+ backend_allow = AutoSparse (
95
+ AutoForwardDiff ();
96
+ sparsity_detector = TracerSparsityDetector (),
97
+ coloring_algorithm = GreedyColoringAlgorithm (; allow_denser = true )
98
+ )
99
+
100
+ backend_no_allow = AutoSparse (
101
+ AutoForwardDiff ();
102
+ sparsity_detector = TracerSparsityDetector (),
103
+ coloring_algorithm = GreedyColoringAlgorithm ()
104
+ )
105
+
106
+ u = ones (3 )
107
+ du = zero (u)
108
+ p = t = nothing
109
+
110
+ prep_allow = DI. prepare_jacobian (
111
+ sparse_f!, du, backend_allow, u, DI. Constant (p), DI. Constant (t))
112
+ prep_no_allow = DI. prepare_jacobian (
113
+ sparse_f!, du, backend_no_allow, u, DI. Constant (p), DI. Constant (t))
114
+ # this is what the user may typically provide to the ODE problem
115
+
116
+ function inplace_jac_allow! (J, u, p, t)
117
+ return DI. jacobian! (
118
+ sparse_f!, zeros (3 ), J, prep_allow, backend_allow, u, DI. Constant (p), DI. Constant (t))
119
+ end
120
+
121
+ function inplace_jac_no_allow! (J, u, p, t)
122
+ return DI. jacobian! (
123
+ sparse_f!, zeros (3 ), J, prep_no_allow, backend_no_allow, u, DI. Constant (p), DI. Constant (t))
124
+ end
125
+
126
+ jac_prototype = similar (sparsity_pattern (prep_allow), eltype (u))
127
+
128
+ ode_f_allow = ODEFunction (
129
+ sparse_f!, jac = inplace_jac_allow!, jac_prototype = jac_prototype)
130
+ prob_allow = ODEProblem (ode_f_allow, [1 , 1 , 1 ], (0.0 , 1.0 ))
131
+
132
+ ode_f_no_allow = ODEFunction (
133
+ sparse_f!, jac = inplace_jac_no_allow!, jac_prototype = jac_prototype)
134
+ prob_no_allow = ODEProblem (ode_f_no_allow, [1 , 1 , 1 ], (0.0 , 1.0 ))
135
+
136
+ sol = solve (prob_allow, Rodas5 ())
137
+
138
+ @test_throws DimensionMismatch sol= solve (prob_no_allow, Rodas5 ())
139
+
0 commit comments