@@ -25,19 +25,32 @@ def __call__(self, t, y, args):
2525
2626
2727@eqx .filter_value_and_grad
28- def _loss (y0__args__term , solver , saveat , adjoint , stepsize_controller , dual_y0 ):
28+ def _loss (
29+ y0__args__term ,
30+ solver ,
31+ saveat ,
32+ adjoint ,
33+ stepsize_controller ,
34+ dual_y0 ,
35+ t0_equals_t1 ,
36+ ):
2937 y0 , args , term = y0__args__term
3038
3139 if isinstance (stepsize_controller , diffrax .StepTo ):
3240 dt0 = None
3341 else :
3442 dt0 = 0.01
3543
44+ if t0_equals_t1 :
45+ t1 = 0
46+ else :
47+ t1 = 5
48+
3649 sol = diffrax .diffeqsolve (
3750 term ,
3851 solver ,
3952 t0 = 0 ,
40- t1 = 5 ,
53+ t1 = t1 ,
4154 dt0 = dt0 ,
4255 y0 = y0 ,
4356 args = args ,
@@ -54,7 +67,13 @@ def _loss(y0__args__term, solver, saveat, adjoint, stepsize_controller, dual_y0)
5467
5568
5669def _compare_grads (
57- y0__args__term , base_solver , solver , saveat , stepsize_controller , dual_y0
70+ y0__args__term ,
71+ base_solver ,
72+ solver ,
73+ saveat ,
74+ stepsize_controller ,
75+ dual_y0 = False ,
76+ t0_equals_t1 = False ,
5877):
5978 loss , grads_base = _loss (
6079 y0__args__term ,
@@ -63,6 +82,7 @@ def _compare_grads(
6382 adjoint = diffrax .RecursiveCheckpointAdjoint (),
6483 stepsize_controller = stepsize_controller ,
6584 dual_y0 = dual_y0 ,
85+ t0_equals_t1 = t0_equals_t1 ,
6686 )
6787 loss , grads_reversible = _loss (
6888 y0__args__term ,
@@ -71,6 +91,7 @@ def _compare_grads(
7191 adjoint = diffrax .ReversibleAdjoint (),
7292 stepsize_controller = stepsize_controller ,
7393 dual_y0 = dual_y0 ,
94+ t0_equals_t1 = t0_equals_t1 ,
7495 )
7596 assert tree_allclose (grads_base , grads_reversible , atol = 1e-5 )
7697
@@ -130,9 +151,7 @@ def test_reversible_heun_ode(stepsize_controller, saveat):
130151 args = jnp .array ([0.5 ])
131152 solver = diffrax .ReversibleHeun ()
132153
133- _compare_grads (
134- (y0 , args , terms ), solver , solver , saveat , stepsize_controller , dual_y0 = False
135- )
154+ _compare_grads ((y0 , args , terms ), solver , solver , saveat , stepsize_controller )
136155
137156
138157@pytest .mark .parametrize (
@@ -161,9 +180,7 @@ def test_reversible_heun_sde(stepsize_controller, saveat):
161180 args = jnp .array ([0.5 ])
162181 solver = diffrax .ReversibleHeun ()
163182
164- _compare_grads (
165- (y0 , args , terms ), solver , solver , saveat , stepsize_controller , dual_y0 = False
166- )
183+ _compare_grads ((y0 , args , terms ), solver , solver , saveat , stepsize_controller )
167184
168185
169186@pytest .mark .parametrize (
@@ -189,9 +206,7 @@ def test_leapfrog_midpoint(stepsize_controller, saveat):
189206 args = jnp .array ([0.5 ])
190207 solver = diffrax .LeapfrogMidpoint ()
191208
192- _compare_grads (
193- (y0 , args , terms ), solver , solver , saveat , stepsize_controller , dual_y0 = False
194- )
209+ _compare_grads ((y0 , args , terms ), solver , solver , saveat , stepsize_controller )
195210
196211
197212@pytest .mark .parametrize (
@@ -226,14 +241,7 @@ def test_reversible_explicit(stepsize_controller, saveat):
226241 if saveat .subs .ts is not None :
227242 base_solver = solver
228243
229- _compare_grads (
230- (y0 , args , terms ),
231- base_solver ,
232- solver ,
233- saveat ,
234- stepsize_controller ,
235- dual_y0 = False ,
236- )
244+ _compare_grads ((y0 , args , terms ), base_solver , solver , saveat , stepsize_controller )
237245
238246
239247@pytest .mark .parametrize (
@@ -270,11 +278,33 @@ def test_reversible_sde(stepsize_controller, saveat):
270278 if saveat .subs .ts is not None :
271279 base_solver = solver
272280
281+ _compare_grads ((y0 , args , terms ), base_solver , solver , saveat , stepsize_controller )
282+
283+
284+ @pytest .mark .parametrize (
285+ "saveat" ,
286+ [
287+ diffrax .SaveAt (t0 = True ),
288+ diffrax .SaveAt (t1 = True ),
289+ diffrax .SaveAt (t0 = True , t1 = True ),
290+ ],
291+ )
292+ def test_reversible_t0_equals_t1 (saveat ):
293+ n = 10
294+ y0 = jnp .linspace (1 , 10 , num = n )
295+ key = jr .PRNGKey (10 )
296+ f = VectorField (n , n , n , depth = 4 , key = key )
297+ terms = diffrax .ODETerm (f )
298+ args = jnp .array ([0.5 ])
299+ base_solver = diffrax .Tsit5 ()
300+ solver = diffrax .UReversible (base_solver )
301+ stepsize_controller = diffrax .ConstantStepSize ()
302+
273303 _compare_grads (
274304 (y0 , args , terms ),
275305 base_solver ,
276306 solver ,
277307 saveat ,
278308 stepsize_controller ,
279- dual_y0 = False ,
309+ t0_equals_t1 = True ,
280310 )
0 commit comments