diffrax/_adjoint.py (+28 −1: 28 additions & 1 deletion)
@@ -1099,7 +1099,34 @@ class ReversibleAdjoint(AbstractAdjoint):
     [`diffrax.AbstractReversibleSolver`][].

     Gradient calculation is exact (up to floating point errors) and backpropagation
-    is linear in time $O(n)$ and constant in memory $O(1)$, for $n$ time steps.
+    becomes linear in time $O(n)$ and constant in memory $O(1)$, for $n$ time steps.
+
+    !!! note
+
+        This adjoint can be less numerically stable than
+        [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax.DirectAdjoint`][].
+        Stability can be largely improved by using [double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision)
+        and [smaller/adaptive step sizes](https://docs.kidger.site/diffrax/api/stepsize_controller/).
+
+    ??? cite "References"
+
+        For an introduction to reversible backpropagation, see these references:
+
+        ```bibtex
+        @article{mccallum2024efficient,
+            title={Efficient, Accurate and Stable Gradients for Neural ODEs},
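The note above recommends 64-bit precision and smaller/adaptive step sizes to improve the stability of this adjoint. A minimal sketch of applying those mitigations is below; it is not part of this diff. It assumes `diffrax.ReversibleHeun` counts as a reversible solver compatible with `ReversibleAdjoint` at this commit, and the vector field, tolerances, and time span are purely illustrative.

```python
# Hedged sketch: ReversibleAdjoint with the stability mitigations from the note
# (64-bit precision + an adaptive step-size controller). Solver choice and
# tolerances are illustrative assumptions, not taken from this change.
import jax
import jax.numpy as jnp
import diffrax

jax.config.update("jax_enable_x64", True)  # double (64-bit) precision


def vector_field(t, y, args):
    # Stand-in linear ODE dy/dt = -y, used only to make the example runnable.
    return -y


term = diffrax.ODETerm(vector_field)
solver = diffrax.ReversibleHeun()  # assumed reversible solver for this adjoint
controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)  # adaptive step sizes


def loss(y0):
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0.0,
        t1=1.0,
        dt0=0.1,
        y0=y0,
        stepsize_controller=controller,
        adjoint=diffrax.ReversibleAdjoint(),  # O(n) time, O(1) memory backprop
    )
    return jnp.sum(sol.ys**2)


grads = jax.grad(loss)(jnp.array([1.0]))
```

With `jax_enable_x64` enabled and tighter tolerances on the controller, the reversible reconstruction accumulates less floating-point error per step, which is the stability improvement the docstring note is pointing at.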
0 commit comments