
Commit 1730a0b

t0==t1 branch
1 parent: e0323b6

3 files changed (+79, -22)


diffrax/_adjoint.py

Lines changed: 0 additions & 1 deletion
@@ -1101,7 +1101,6 @@ def cond_fun(state):
 
         state = jax.lax.while_loop(cond_fun, grad_step, state)
         _, _, y0, _, grad_y0, grad_state, grad_args, grad_terms = state
-        jax.debug.print("{}", y0)
 
         # Pull solver_state gradients back onto y0, args, terms.

diffrax/_integrate.py

Lines changed: 28 additions & 0 deletions
@@ -870,6 +870,34 @@ def _make_ys(out, old_outs):
         lambda s: s.save_state, final_state, save_state, is_leaf=_is_none
     )
 
+    # If t0 == t1 and we are using diffrax.ReversibleAdjoint then we need to update
+    # reversible_ts and reversible_save_index to get correct gradients.
+    def _reversible_info_if_t0_equals_t1(reversible_ts, reversible_save_index):
+        reversible_ts = eqxi.buffer_at_set(final_state.reversible_ts, 1, t0)
+        reversible_save_index += 1
+        return reversible_ts, reversible_save_index
+
+    reversible_ts, reversible_save_index = jax.lax.cond(
+        eqxi.unvmap_any(t0 == t1),
+        lambda __ts, __index: jax.lax.cond(
+            t0 == t1,
+            lambda _ts, _index: _reversible_info_if_t0_equals_t1(_ts, _index),
+            lambda _ts, _index: (_ts, _index),
+            __ts,
+            __index,
+        ),
+        lambda __ts, __index: (__ts, __index),
+        final_state.reversible_ts,
+        final_state.reversible_save_index,
+    )
+
+    final_state = eqx.tree_at(
+        lambda s: (s.reversible_ts, s.reversible_save_index),
+        final_state,
+        (reversible_ts, reversible_save_index),
+        is_leaf=_is_none,
+    )
+
     final_state = _handle_static(final_state)
     result = RESULTS.where(cond_fun(final_state), RESULTS.max_steps_reached, result)
     aux_stats = dict()  # TODO: put something in here?
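
A note on the nested jax.lax.cond above: this is the usual idiom for a predicate that may be batched under jax.vmap. The outer cond tests eqxi.unvmap_any(t0 == t1), a single unbatched boolean that is True if any batch element has t0 == t1, so the branch is skipped entirely when no element needs it; under vmap the inner cond (whose predicate is batched) lowers to a per-element select, so only the elements that actually have t0 == t1 receive the update. A minimal standalone sketch of the same pattern; the names maybe_update and _expensive_update are illustrative, not diffrax API:

import jax
import jax.numpy as jnp
import equinox.internal as eqxi


def _expensive_update(x):
    # Stand-in for the real work (above: rewriting reversible_ts).
    return x + 1.0


def maybe_update(pred, x):
    # Outer cond: eqxi.unvmap_any(pred) collapses a possibly-batched predicate
    # to one scalar, so the whole branch is skipped if no batch element needs it.
    return jax.lax.cond(
        eqxi.unvmap_any(pred),
        # Inner cond: under vmap this lowers to a per-element select, so each
        # element keeps the update only where its own predicate is True.
        lambda _x: jax.lax.cond(pred, _expensive_update, lambda __x: __x, _x),
        lambda _x: _x,
        x,
    )


preds = jnp.array([True, False, True])
xs = jnp.array([1.0, 2.0, 3.0])
print(jax.vmap(maybe_update)(preds, xs))  # [2. 2. 4.]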

test/test_reversible.py

Lines changed: 51 additions & 21 deletions
@@ -25,19 +25,32 @@ def __call__(self, t, y, args):
 
 
 @eqx.filter_value_and_grad
-def _loss(y0__args__term, solver, saveat, adjoint, stepsize_controller, dual_y0):
+def _loss(
+    y0__args__term,
+    solver,
+    saveat,
+    adjoint,
+    stepsize_controller,
+    dual_y0,
+    t0_equals_t1,
+):
     y0, args, term = y0__args__term
 
     if isinstance(stepsize_controller, diffrax.StepTo):
         dt0 = None
     else:
         dt0 = 0.01
 
+    if t0_equals_t1:
+        t1 = 0
+    else:
+        t1 = 5
+
     sol = diffrax.diffeqsolve(
         term,
         solver,
         t0=0,
-        t1=5,
+        t1=t1,
         dt0=dt0,
         y0=y0,
         args=args,
@@ -54,7 +67,13 @@ def _loss(y0__args__term, solver, saveat, adjoint, stepsize_controller, dual_y0)
 
 
 def _compare_grads(
-    y0__args__term, base_solver, solver, saveat, stepsize_controller, dual_y0
+    y0__args__term,
+    base_solver,
+    solver,
+    saveat,
+    stepsize_controller,
+    dual_y0=False,
+    t0_equals_t1=False,
 ):
     loss, grads_base = _loss(
         y0__args__term,
@@ -63,6 +82,7 @@ def _compare_grads(
         adjoint=diffrax.RecursiveCheckpointAdjoint(),
         stepsize_controller=stepsize_controller,
         dual_y0=dual_y0,
+        t0_equals_t1=t0_equals_t1,
     )
     loss, grads_reversible = _loss(
         y0__args__term,
@@ -71,6 +91,7 @@ def _compare_grads(
         adjoint=diffrax.ReversibleAdjoint(),
         stepsize_controller=stepsize_controller,
         dual_y0=dual_y0,
+        t0_equals_t1=t0_equals_t1,
     )
     assert tree_allclose(grads_base, grads_reversible, atol=1e-5)
 
@@ -130,9 +151,7 @@ def test_reversible_heun_ode(stepsize_controller, saveat):
     args = jnp.array([0.5])
     solver = diffrax.ReversibleHeun()
 
-    _compare_grads(
-        (y0, args, terms), solver, solver, saveat, stepsize_controller, dual_y0=False
-    )
+    _compare_grads((y0, args, terms), solver, solver, saveat, stepsize_controller)
 
 
 @pytest.mark.parametrize(
@@ -161,9 +180,7 @@ def test_reversible_heun_sde(stepsize_controller, saveat):
     args = jnp.array([0.5])
     solver = diffrax.ReversibleHeun()
 
-    _compare_grads(
-        (y0, args, terms), solver, solver, saveat, stepsize_controller, dual_y0=False
-    )
+    _compare_grads((y0, args, terms), solver, solver, saveat, stepsize_controller)
 
 
 @pytest.mark.parametrize(
@@ -189,9 +206,7 @@ def test_leapfrog_midpoint(stepsize_controller, saveat):
     args = jnp.array([0.5])
     solver = diffrax.LeapfrogMidpoint()
 
-    _compare_grads(
-        (y0, args, terms), solver, solver, saveat, stepsize_controller, dual_y0=False
-    )
+    _compare_grads((y0, args, terms), solver, solver, saveat, stepsize_controller)
 
 
 @pytest.mark.parametrize(
@@ -226,14 +241,7 @@ def test_reversible_explicit(stepsize_controller, saveat):
     if saveat.subs.ts is not None:
         base_solver = solver
 
-    _compare_grads(
-        (y0, args, terms),
-        base_solver,
-        solver,
-        saveat,
-        stepsize_controller,
-        dual_y0=False,
-    )
+    _compare_grads((y0, args, terms), base_solver, solver, saveat, stepsize_controller)
 
 
 @pytest.mark.parametrize(
@@ -270,11 +278,33 @@ def test_reversible_sde(stepsize_controller, saveat):
     if saveat.subs.ts is not None:
         base_solver = solver
 
+    _compare_grads((y0, args, terms), base_solver, solver, saveat, stepsize_controller)
+
+
+@pytest.mark.parametrize(
+    "saveat",
+    [
+        diffrax.SaveAt(t0=True),
+        diffrax.SaveAt(t1=True),
+        diffrax.SaveAt(t0=True, t1=True),
+    ],
+)
+def test_reversible_t0_equals_t1(saveat):
+    n = 10
+    y0 = jnp.linspace(1, 10, num=n)
+    key = jr.PRNGKey(10)
+    f = VectorField(n, n, n, depth=4, key=key)
+    terms = diffrax.ODETerm(f)
+    args = jnp.array([0.5])
+    base_solver = diffrax.Tsit5()
+    solver = diffrax.UReversible(base_solver)
+    stepsize_controller = diffrax.ConstantStepSize()
+
     _compare_grads(
         (y0, args, terms),
         base_solver,
         solver,
         saveat,
         stepsize_controller,
-        dual_y0=False,
+        t0_equals_t1=True,
     )
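
To make the user-facing behaviour concrete, here is a minimal end-to-end sketch of what the new test exercises (illustrative only, not part of the commit): on a degenerate interval with t0 == t1, gradients under diffrax.ReversibleAdjoint should agree with diffrax.RecursiveCheckpointAdjoint. The linear vector field and quadratic loss are hypothetical stand-ins for the test suite's VectorField and _loss, and diffrax.UReversible / diffrax.ReversibleAdjoint are assumed available as on this branch:

import diffrax
import equinox as eqx
import jax.numpy as jnp


def f(t, y, args):
    # Simple linear vector field standing in for the test's VectorField.
    return -args * y


@eqx.filter_grad
def loss(y0, solver, adjoint):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f),
        solver,
        t0=0.0,
        t1=0.0,  # degenerate interval: the solve takes zero steps
        dt0=0.01,
        y0=y0,
        args=jnp.array(0.5),
        saveat=diffrax.SaveAt(t0=True, t1=True),
        adjoint=adjoint,
    )
    return jnp.sum(sol.ys**2)


y0 = jnp.linspace(1.0, 10.0, num=10)
grads_base = loss(y0, diffrax.Tsit5(), diffrax.RecursiveCheckpointAdjoint())
grads_rev = loss(y0, diffrax.UReversible(diffrax.Tsit5()), diffrax.ReversibleAdjoint())
assert jnp.allclose(grads_base, grads_rev, atol=1e-5)

Per the new comment in diffrax/_integrate.py, without the reversible_ts / reversible_save_index update the reversible gradients are incorrect in this t0 == t1 case.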
