
Commit d158c11

improve docstrings
1 parent 39075ea commit d158c11

File tree: 2 files changed, +77 −8 lines


diffrax/_adjoint.py

Lines changed: 28 additions & 1 deletion
@@ -1099,7 +1099,34 @@ class ReversibleAdjoint(AbstractAdjoint):
     [`diffrax.AbstractReversibleSolver`][].

     Gradient calculation is exact (up to floating point errors) and backpropagation
-    is linear in time $O(n)$ and constant in memory $O(1)$, for $n$ time steps.
+    becomes linear in time $O(n)$ and constant in memory $O(1)$, for $n$ time steps.
+
+    !!! note
+
+        This adjoint can be less numerically stable than
+        [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax.DirectAdjoint`][].
+        Stability can be largely improved by using [double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision)
+        and [smaller/adaptive step sizes](https://docs.kidger.site/diffrax/api/stepsize_controller/).
+
+    ??? cite "References"
+
+        For an introduction to reversible backpropagation, see these references:
+
+        ```bibtex
+        @article{mccallum2024efficient,
+            title={Efficient, Accurate and Stable Gradients for Neural ODEs},
+            author={McCallum, Sam and Foster, James},
+            journal={arXiv preprint arXiv:2410.11648},
+            year={2024}
+        }
+
+        @phdthesis{kidger2021on,
+            title={{O}n {N}eural {D}ifferential {E}quations},
+            author={Patrick Kidger},
+            year={2021},
+            school={University of Oxford},
+        }
+        ```
     """

     def loop(
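
For orientation, here is a minimal usage sketch of the adjoint documented above. It is illustrative only: the `Reversible` wrapper and `ReversibleAdjoint` names are taken from this commit's files, but how they are exported at package level, and the choice of `Heun` as a base solver, are assumptions rather than part of the diff.

```python
import jax
import jax.numpy as jnp
import diffrax

# The docstring above recommends double (64-bit) precision for stability.
jax.config.update("jax_enable_x64", True)

# Simple linear ODE: dy/dt = -y.
term = diffrax.ODETerm(lambda t, y, args: -y)

# Assumed setup: wrap a base solver so that it is an AbstractReversibleSolver,
# as required by ReversibleAdjoint.
solver = diffrax.Reversible(diffrax.Heun())

def loss(y0):
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0.0,
        t1=1.0,
        dt0=0.01,
        y0=y0,
        adjoint=diffrax.ReversibleAdjoint(),  # O(n) time, O(1) memory backprop
    )
    return jnp.sum(sol.ys**2)

grads = jax.grad(loss)(jnp.array([1.0, 2.0]))
```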

diffrax/_solver/reversible.py

Lines changed: 49 additions & 7 deletions
@@ -8,7 +8,7 @@
 from .._solution import RESULTS, update_result
 from .._solver.base import (
     AbstractReversibleSolver,
-    AbstractSolver,
+    AbstractStratonovichSolver,
     AbstractWrappedSolver,
 )
 from .._term import AbstractTerm
@@ -26,12 +26,54 @@ class Reversible(
     """
     Reversible solver method.

-    Allows any solver ([`diffrax.AbstractSolver`][]) to be made algebraically
-    reversible.
+    Allows any solver ([`diffrax.AbstractStratonovichSolver`][]) to be made
+    algebraically reversible.
+
+    **Arguments:**
+    - `solver`: base solver to be made reversible
+    - `coupling_parameter`: determines coupling between the two evolving solutions.
+        Must be within the range `0 < coupling_parameter < 1`. Unless you need finer control
+        over stability, the default value of `0.999` should be sufficient.
+
+    ??? cite "References"
+
+        This method was developed in:
+
+        ```bibtex
+        @article{mccallum2024efficient,
+            title={Efficient, Accurate and Stable Gradients for Neural ODEs},
+            author={McCallum, Sam and Foster, James},
+            journal={arXiv preprint arXiv:2410.11648},
+            year={2024}
+        }
+        ```
+
+        And built on previous work by:
+
+        ```bibtex
+        @article{kidger2021efficient,
+            title={Efficient and accurate gradients for neural sdes},
+            author={Kidger, Patrick and Foster, James and Li, Xuechen Chen and Lyons,
+                Terry},
+            journal={Advances in Neural Information Processing Systems},
+            volume={34},
+            pages={18747--18761},
+            year={2021}
+        }
+
+        @article{zhuang2021mali,
+            title={Mali: A memory efficient and reverse accurate integrator for neural
+                odes},
+            author={Zhuang, Juntang and Dvornek, Nicha C and Tatikonda, Sekhar and
+                Duncan, James S},
+            journal={arXiv preprint arXiv:2102.04668},
+            year={2021}
+        }
+        ```
     """

-    solver: AbstractSolver
-    l: float = 0.999
+    solver: AbstractStratonovichSolver
+    coupling_parameter: float = 0.999

     @property
     def interpolation_cls(self):  # pyright: ignore
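
The hunk above renames the coupling field from `l` to `coupling_parameter` and narrows the base-solver type to `AbstractStratonovichSolver`. A short construction sketch, assuming the class is importable as `diffrax.Reversible` (an assumption about packaging, not part of the diff); `Heun` is one solver satisfying the Stratonovich bound:

```python
import diffrax

# Sketch: wrap a Stratonovich-capable base solver to make it algebraically
# reversible. coupling_parameter must satisfy 0 < coupling_parameter < 1;
# the default 0.999 is usually sufficient.
base_solver = diffrax.Heun()
reversible_solver = diffrax.Reversible(base_solver, coupling_parameter=0.999)
```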
@@ -85,7 +127,7 @@ def step(
         step_z0, _, dense_info, original_solver_state, result1 = self.solver.step(
             terms, t0, t1, z0, args, original_solver_state, True
         )
-        y1 = (self.l * (ω(y0) - ω(z0)) + ω(step_z0)).ω
+        y1 = (self.coupling_parameter * (ω(y0) - ω(z0)) + ω(step_z0)).ω

         step_y1, y_error, _, _, result2 = self.solver.step(
             terms, t1, t0, y1, args, original_solver_state, True
@@ -115,7 +157,7 @@ def backward_step(
         step_z0, _, dense_info, _, _ = self.solver.step(
             terms, t0, t1, z0, args, original_solver_state, True
         )
-        y0 = ((1 / self.l) * (ω(y1) - ω(step_z0)) + ω(z0)).ω
+        y0 = ((1 / self.coupling_parameter) * (ω(y1) - ω(step_z0)) + ω(z0)).ω
         solver_state = (original_solver_state, z0)

         return y0, dense_info, solver_state
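
Together, the last two hunks are the forward and backward halves of the coupled update: `step` computes `y1 = coupling_parameter * (y0 - z0) + Φ(z0)` and `backward_step` recovers `y0 = (1 / coupling_parameter) * (y1 - Φ(z0)) + z0`, where `Φ` denotes one step of the base solver. A tiny standalone illustration (plain Python, not diffrax code) of why this reconstruction is exact up to floating point error:

```python
# Toy scalar check that the backward formula algebraically inverts the forward one.
def phi(z):
    # Stand-in for one base-solver step (any deterministic map works here).
    return z + 0.1 * (-z)

coupling_parameter = 0.999
y0, z0 = 1.3, 1.25

y1 = coupling_parameter * (y0 - z0) + phi(z0)                       # as in step
y0_reconstructed = (1 / coupling_parameter) * (y1 - phi(z0)) + z0   # as in backward_step

print(abs(y0 - y0_reconstructed))  # ~0, floating point error only
```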
