
Commit ae01942

Sam McCallum (sammccallum) authored and committed
AbstractReversibleSolver + ReversibleAdjoint
The squashed commit history:

- add reversible testing
- testing AbstractReversibleSolver + ReversibleAdjoint
- allow arbitrary interpolation
- unpacking over indexing
- jax while loop
- collapse saveat ValueErrors
- remove Stratonovich solver condition
- remove unused returns from AbstractReversibleSolver backward_step
- add test and remove messy benchmark
- add wrapped solver + tests
- made_jump=True for both solver steps
- improve docstrings
- AbstractSolver and docstring note about SDEs
- add AbstractReversibleSolver to public API
- newline in docstrings
- return RESULTS from reversible backward_step
- restrict Reversible to AbstractERK and check result in adjoint
- correct tprev and tnext of solver init
- switch to linear interpolation and y0,y1 dense_info
- name UReversible
- various doc formatting changes
- AbstractReversibleSolver check
- add disable_fsal property to AbstractRungeKutta and use in UReversible
- allow t0 != 0
- Handle StepTo controller t0==t1 branch
1 parent 6694c86 commit ae01942
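
For orientation, a minimal sketch of how the new pieces fit together. The `UReversible(Tsit5())` construction is an assumption inferred from the commit messages above ("add wrapped solver + tests", "restrict Reversible to AbstractERK"), not something shown in this diff; the rest is diffrax's existing `diffeqsolve` API, and the `SaveAt` choice matches the constraints `ReversibleAdjoint` enforces below.

```python
# Hedged usage sketch: O(1)-memory backprop through a solve.
import jax
import jax.numpy as jnp

# The new docstring recommends 64-bit precision for numerical stability.
jax.config.update("jax_enable_x64", True)

import diffrax


def loss(a):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(lambda t, y, args: -args * y),  # dy/dt = -a*y
        diffrax.UReversible(diffrax.Tsit5()),  # assumed wrapper constructor
        t0=0.0,
        t1=1.0,
        dt0=0.01,
        y0=jnp.array(1.0),
        args=a,
        saveat=diffrax.SaveAt(t1=True),  # t0/t1/ts are the supported options
        adjoint=diffrax.ReversibleAdjoint(),
    )
    return jnp.sum(sol.ys**2)


grad_a = jax.grad(loss)(jnp.array(2.0))
```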

File tree

11 files changed: +1040 −13 lines


diffrax/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -7,6 +7,7 @@
     ForwardMode as ForwardMode,
     ImplicitAdjoint as ImplicitAdjoint,
     RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
+    ReversibleAdjoint as ReversibleAdjoint,
 )
 from ._autocitation import citation as citation, citation_rules as citation_rules
 from ._brownian import (
@@ -75,6 +76,7 @@
     AbstractFosterLangevinSRK as AbstractFosterLangevinSRK,
     AbstractImplicitSolver as AbstractImplicitSolver,
     AbstractItoSolver as AbstractItoSolver,
+    AbstractReversibleSolver as AbstractReversibleSolver,
     AbstractRungeKutta as AbstractRungeKutta,
     AbstractSDIRK as AbstractSDIRK,
     AbstractSolver as AbstractSolver,
@@ -117,6 +119,7 @@
     StochasticButcherTableau as StochasticButcherTableau,
     StratonovichMilstein as StratonovichMilstein,
     Tsit5 as Tsit5,
+    UReversible as UReversible,
 )
 from ._step_size_controller import (
     AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController,
```

diffrax/_adjoint.py

Lines changed: 306 additions & 0 deletions
```diff
@@ -16,8 +16,10 @@
 
 from ._heuristics import is_sde, is_unsafe_sde
 from ._saveat import save_y, SaveAt, SubSaveAt
+from ._solution import RESULTS
 from ._solver import (
     AbstractItoSolver,
+    AbstractReversibleSolver,
     AbstractRungeKutta,
     AbstractSRK,
     AbstractStratonovichSolver,
```
```diff
@@ -918,3 +920,307 @@ def loop(
 
 
 ForwardMode.__init__.__doc__ = """**Arguments:** None"""
+
+
+# Reversible Adjoint custom vjp computes gradients w.r.t.
+# - y, corresponding to the initial state;
+# - args, corresponding to explicit parameters;
+# - terms, corresponding to implicit parameters as part of the vector field.
+
+
+@eqx.filter_custom_vjp
+def _loop_reversible(y__args__terms, *, self, throw, max_steps, init_state, **kwargs):
+    del throw
+    y, args, terms = y__args__terms
+    init_state = eqx.tree_at(lambda s: s.y, init_state, y)
+    del y
+    return self._loop(
+        args=args,
+        terms=terms,
+        max_steps=max_steps,
+        init_state=init_state,
+        inner_while_loop=ft.partial(_inner_loop, kind="lax"),
+        outer_while_loop=ft.partial(_outer_loop, kind="lax"),
+        **kwargs,
+    )
+
+
+@_loop_reversible.def_fwd
+def _loop_reversible_fwd(perturbed, y__args__terms, **kwargs):
+    del perturbed
+    final_state, aux_stats = _loop_reversible(y__args__terms, **kwargs)
+    init_ts = final_state.reversible_init_ts
+    ts = final_state.reversible_ts
+    ts_final_index = final_state.reversible_save_index
+    y1 = final_state.y
+    save_state = final_state.save_state
+    solver_state = final_state.solver_state
+    return (final_state, aux_stats), (
+        init_ts,
+        ts,
+        ts_final_index,
+        y1,
+        save_state,
+        solver_state,
+    )
+
+
+@_loop_reversible.def_bwd
+def _loop_reversible_bwd(
+    residuals,
+    grad_final_state__aux_stats,
+    perturbed,
+    y__args__terms,
+    *,
+    self,
+    saveat,
+    init_state,
+    solver,
+    event,
+    **kwargs,
+):
+    assert event is None
+
+    del perturbed, self, init_state, kwargs
+    init_ts, ts, ts_final_index, y1, save_state, solver_state = residuals
+    del residuals
+
+    grad_final_state, _ = grad_final_state__aux_stats
+    saveat_ts = save_state.ts
+    ys = save_state.ys
+    saveat_ts_index = save_state.saveat_ts_index - 1
+    grad_ys = grad_final_state.save_state.ys
+    grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys)
+
+    if saveat.subs.t1:
+        grad_y1 = (ω(grad_ys)[-1]).ω
+    else:
+        grad_y1 = jtu.tree_map(jnp.zeros_like, y1)
+
+    if saveat.subs.t0:
+        saveat_ts_index = saveat_ts_index + 1
+
+    del grad_final_state, grad_final_state__aux_stats
+
+    y, args, terms = y__args__terms
+    del y__args__terms
+
+    diff_state = eqx.filter(solver_state, eqx.is_inexact_array)
+    diff_args = eqx.filter(args, eqx.is_inexact_array)
+    diff_terms = eqx.filter(terms, eqx.is_inexact_array)
+    grad_state = jtu.tree_map(jnp.zeros_like, diff_state)
+    grad_args = jtu.tree_map(jnp.zeros_like, diff_args)
+    grad_terms = jtu.tree_map(jnp.zeros_like, diff_terms)
+    del diff_args, diff_terms
+
+    def grad_step(state):
+        def forward_step(y0, solver_state, args, terms):
+            y1, _, dense_info, new_solver_state, result = solver.step(
+                terms, t0, t1, y0, args, solver_state, False
+            )
+            assert result == RESULTS.successful
+            return y1, dense_info, new_solver_state
+
+        (
+            saveat_ts_index,
+            ts_index,
+            y1,
+            solver_state,
+            grad_y1,
+            grad_state,
+            grad_args,
+            grad_terms,
+        ) = state
+
+        t1 = ts[ts_index]
+        t0 = ts[ts_index - 1]
+
+        # Any ts state required to reverse the forward step
+        # e.g. LeapfrogMidpoint requires tm1
+        tm1_index = ts_index - 2
+        tm1 = ts[tm1_index]
+        tm1 = jnp.where(tm1_index >= 0, tm1, t0)
+        ts_state = (tm1,)
+
+        y0, dense_info, solver_state, result = solver.backward_step(
+            terms, t0, t1, y1, args, ts_state, solver_state, False
+        )
+        assert result == RESULTS.successful
+
+        # Pull gradients back through interpolation
+
+        def interpolate(t, t0, t1, dense_info):
+            interpolator = solver.interpolation_cls(t0=t0, t1=t1, **dense_info)
+            return interpolator.evaluate(t)
+
+        def _cond_fun(inner_state):
+            saveat_ts_index, _ = inner_state
+            return (saveat_ts[saveat_ts_index] >= t0) & (saveat_ts_index >= 0)
+
+        def _body_fun(inner_state):
+            saveat_ts_index, grad_dense_info = inner_state
+            t = saveat_ts[saveat_ts_index]
+            grad_y = (ω(grad_ys)[saveat_ts_index]).ω
+            _, interp_vjp = eqx.filter_vjp(interpolate, t, t0, t1, dense_info)
+            _, _, _, dgrad_dense_info = interp_vjp(grad_y)
+            grad_dense_info = eqx.apply_updates(grad_dense_info, dgrad_dense_info)
+            saveat_ts_index = saveat_ts_index - 1
+            return saveat_ts_index, grad_dense_info
+
+        grad_dense_info = jtu.tree_map(jnp.zeros_like, dense_info)
+        inner_state = (saveat_ts_index, grad_dense_info)
+        inner_state = eqxi.while_loop(_cond_fun, _body_fun, inner_state, kind="lax")
+        saveat_ts_index, grad_dense_info = inner_state
+
+        # Pull gradients back through forward step
+
+        _, vjp_fn = eqx.filter_vjp(forward_step, y0, solver_state, args, terms)
+        grad_y0, grad_state, dgrad_args, dgrad_terms = vjp_fn(
+            (grad_y1, grad_dense_info, grad_state)
+        )
+
+        grad_args = eqx.apply_updates(grad_args, dgrad_args)
+        grad_terms = eqx.apply_updates(grad_terms, dgrad_terms)
+
+        ts_index = ts_index - 1
+
+        return (
+            saveat_ts_index,
+            ts_index,
+            y0,
+            solver_state,
+            grad_y0,
+            grad_state,
+            grad_args,
+            grad_terms,
+        )
+
+    def cond_fun(state):
+        ts_index = state[1]
+        return ts_index > 0
+
+    state = (
+        saveat_ts_index,
+        ts_final_index,
+        y1,
+        solver_state,
+        grad_y1,
+        grad_state,
+        grad_args,
+        grad_terms,
+    )
+
+    state = jax.lax.while_loop(cond_fun, grad_step, state)
+    _, _, y0, _, grad_y0, grad_state, grad_args, grad_terms = state
+
+    # Pull solver_state gradients back onto y0, args, terms.
+
+    init_t0, init_t1 = init_ts
+    _, init_vjp = eqx.filter_vjp(solver.init, terms, init_t0, init_t1, y0, args)
+    dgrad_terms, _, _, dgrad_y0, dgrad_args = init_vjp(grad_state)
+    grad_y0 = eqx.apply_updates(grad_y0, dgrad_y0)
+    grad_terms = eqx.apply_updates(grad_terms, dgrad_terms)
+    grad_args = eqx.apply_updates(grad_args, dgrad_args)
+
+    return grad_y0, grad_args, grad_terms
```
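
Note how the backward loop above never stores the forward trajectory: each iteration reconstructs `y0` from `y1` via `solver.backward_step`, then takes a VJP through a recomputed `forward_step`. As a toy illustration of the kind of algebraically reversible two-state coupling that makes this possible (loosely in the style of McCallum & Foster 2024, cited in the docstring below; the actual `UReversible` update is defined elsewhere in this commit and may differ):

```python
# Toy sketch, not diffrax internals: backward_step exactly inverts
# forward_step in exact arithmetic. `phi(u, h)` is any one-step increment.
def forward_step(phi, y, z, h, lam=0.999):
    y_next = lam * y + (1 - lam) * z + phi(z, h)
    z_next = z - phi(y_next, -h)
    return y_next, z_next


def backward_step(phi, y_next, z_next, h, lam=0.999):
    z = z_next + phi(y_next, -h)  # inverts the z-update
    y = (y_next - (1 - lam) * z - phi(z, h)) / lam  # inverts the y-update
    return y, z


phi = lambda u, h: h * (-u)  # explicit Euler increment for dy/dt = -y
y1, z1 = forward_step(phi, 1.0, 1.0, 0.01)
y0, z0 = backward_step(phi, y1, z1, 0.01)
# (y0, z0) recovers (1.0, 1.0) up to floating point round-off.
```

In this toy version the division by `lam` means each reverse step can slightly amplify round-off, which is consistent with the stability note in the docstring below.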
````diff
+
+
+class ReversibleAdjoint(AbstractAdjoint):
+    """Backpropagate through [`diffrax.diffeqsolve`][] when using a reversible solver
+    [`diffrax.AbstractReversibleSolver`][].
+
+    Gradient calculation is exact (up to floating point errors) and backpropagation
+    becomes linear in time $O(n)$ and constant in memory $O(1)$, for $n$ time steps.
+
+    !!! note
+
+        This adjoint can be less numerically stable than
+        [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax.DirectAdjoint`][].
+        Stability can be largely improved by using
+        [double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision)
+        and [smaller/adaptive step sizes](https://docs.kidger.site/diffrax/api/stepsize_controller/).
+
+    ??? cite "References"
+
+        For an introduction to reversible backpropagation, see these references:
+
+        ```bibtex
+        @article{mccallum2024efficient,
+            title={Efficient, Accurate and Stable Gradients for Neural ODEs},
+            author={McCallum, Sam and Foster, James},
+            journal={arXiv preprint arXiv:2410.11648},
+            year={2024}
+        }
+
+        @phdthesis{kidger2021on,
+            title={{O}n {N}eural {D}ifferential {E}quations},
+            author={Patrick Kidger},
+            year={2021},
+            school={University of Oxford},
+        }
+        ```
+    """
+
+    def loop(
+        self,
+        *,
+        args,
+        terms,
+        solver,
+        saveat,
+        max_steps,
+        init_state,
+        passed_solver_state,
+        passed_controller_state,
+        event,
+        **kwargs,
+    ):
+        if not isinstance(solver, AbstractReversibleSolver):
+            raise ValueError(
+                "`ReversibleAdjoint` can only be used with an "
+                "`AbstractReversibleSolver`"
+            )
+        if max_steps is None:
+            raise ValueError(
+                "`max_steps=None` is incompatible with `ReversibleAdjoint`."
+            )
+
+        if (
+            jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat)
+            != jtu.tree_structure(0)
+            or saveat.dense
+            or saveat.subs.steps
+            or (saveat.subs.fn is not save_y)
+        ):
+            raise ValueError(
+                "`ReversibleAdjoint` is only compatible with the following `SaveAt` "
+                "properties: `t0`, `t1`, `ts`, `fn=save_y` (default)."
+            )
+
+        if event is not None:
+            raise NotImplementedError(
+                "`ReversibleAdjoint` is not compatible with events."
+            )
+
+        if is_unsafe_sde(terms):
+            raise ValueError(
+                "`ReversibleAdjoint` does not support `UnsafeBrownianPath`. "
+                "Consider using `VirtualBrownianTree` instead."
+            )
+
+        y = init_state.y
+        init_state = eqx.tree_at(lambda s: s.y, init_state, object())
+        init_state = _nondiff_solver_controller_state(
+            self, init_state, passed_solver_state, passed_controller_state
+        )
+
+        final_state, aux_stats = _loop_reversible(
+            (y, args, terms),
+            self=self,
+            saveat=saveat,
+            max_steps=max_steps,
+            init_state=init_state,
+            solver=solver,
+            event=event,
+            **kwargs,
+        )
+        final_state = _only_transpose_ys(final_state)
+        return final_state, aux_stats
````
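
Since the docstring states that gradients are exact up to floating point error, a natural end-to-end sanity check is to compare against `RecursiveCheckpointAdjoint`. A sketch, again assuming the hypothetical `UReversible(Tsit5())` constructor from the example near the top (the default `SaveAt(t1=True)` passes the validation in `loop` above):

```python
# Sanity-check sketch: reversible gradients vs. checkpointed gradients.
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

import diffrax


def make_loss(adjoint):
    def loss(a):
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(lambda t, y, args: -args * y),
            diffrax.UReversible(diffrax.Tsit5()),  # assumed constructor
            t0=0.0,
            t1=1.0,
            dt0=0.01,
            y0=jnp.array(1.0),
            args=a,
            adjoint=adjoint,
        )
        return jnp.sum(sol.ys**2)

    return loss


g_rev = jax.grad(make_loss(diffrax.ReversibleAdjoint()))(jnp.array(2.0))
g_ckpt = jax.grad(make_loss(diffrax.RecursiveCheckpointAdjoint()))(jnp.array(2.0))
assert jnp.allclose(g_rev, g_ckpt, rtol=1e-6, atol=1e-8)
```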
