@@ -16,8 +16,10 @@
 
 from ._heuristics import is_sde, is_unsafe_sde
 from ._saveat import save_y, SaveAt, SubSaveAt
+from ._solution import RESULTS
 from ._solver import (
     AbstractItoSolver,
+    AbstractReversibleSolver,
     AbstractRungeKutta,
     AbstractSRK,
     AbstractStratonovichSolver,
@@ -918,3 +920,307 @@ def loop( |
 
 
 ForwardMode.__init__.__doc__ = """**Arguments:** None"""
+
+# The ReversibleAdjoint custom vjp computes gradients w.r.t.
+# - y, corresponding to the initial state;
+# - args, corresponding to explicit parameters;
+# - terms, corresponding to implicit parameters that are part of the vector field.
+
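As background for the comment above, and standard vector-Jacobian calculus rather than anything specific to this PR: write one solver step as $y_{n+1} = \Phi_\theta(t_n, t_{n+1}, y_n)$, where $\theta$ collects the differentiable leaves of `args` and `terms`. Because the solver is algebraically reversible, the backward pass reconstructs $y_n$ exactly from $y_{n+1}$ and accumulates

$$
\bar{y}_n = \left(\frac{\partial \Phi_\theta}{\partial y_n}\right)^\top \bar{y}_{n+1},
\qquad
\bar{\theta} \mathrel{+}= \left(\frac{\partial \Phi_\theta}{\partial \theta}\right)^\top \bar{y}_{n+1},
$$

so the forward trajectory never needs to be stored and memory use is $O(1)$ in the number of steps.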
+
+@eqx.filter_custom_vjp
+def _loop_reversible(y__args__terms, *, self, throw, max_steps, init_state, **kwargs):
+    del throw
+    y, args, terms = y__args__terms
+    init_state = eqx.tree_at(lambda s: s.y, init_state, y)
+    del y
+    return self._loop(
+        args=args,
+        terms=terms,
+        max_steps=max_steps,
+        init_state=init_state,
+        inner_while_loop=ft.partial(_inner_loop, kind="lax"),
+        outer_while_loop=ft.partial(_outer_loop, kind="lax"),
+        **kwargs,
+    )
+
+
+@_loop_reversible.def_fwd
+def _loop_reversible_fwd(perturbed, y__args__terms, **kwargs):
+    del perturbed
+    final_state, aux_stats = _loop_reversible(y__args__terms, **kwargs)
+    init_ts = final_state.reversible_init_ts
+    ts = final_state.reversible_ts
+    ts_final_index = final_state.reversible_save_index
+    y1 = final_state.y
+    save_state = final_state.save_state
+    solver_state = final_state.solver_state
+    return (final_state, aux_stats), (
+        init_ts,
+        ts,
+        ts_final_index,
+        y1,
+        save_state,
+        solver_state,
+    )
+
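For readers unfamiliar with `eqx.filter_custom_vjp`: the forward rule receives a `perturbed` pytree (marking which inputs are being differentiated) and returns `(output, residuals)`; the backward rule receives those residuals plus the output cotangent and returns the cotangent for the first positional argument. A minimal self-contained sketch of the same pattern, using a hypothetical `_square` (not part of this PR):

```python
import equinox as eqx
import jax.numpy as jnp


@eqx.filter_custom_vjp
def _square(vjp_arg):
    (x,) = vjp_arg
    return x**2


@_square.def_fwd
def _square_fwd(perturbed, vjp_arg):
    del perturbed
    (x,) = vjp_arg
    # Return the primal output together with residuals for the backward pass.
    return _square(vjp_arg), x


@_square.def_bwd
def _square_bwd(residuals, grad_out, perturbed, vjp_arg):
    del perturbed, vjp_arg
    x = residuals
    # Cotangent for `vjp_arg`, matching its structure (a 1-tuple).
    return (2 * x * grad_out,)


grad = eqx.filter_grad(lambda x: _square((x,)))(jnp.array(3.0))  # Array(6.)
```

Packing `(y, args, terms)` into the single differentiated argument `y__args__terms` above follows the same convention: everything passed by keyword is treated as non-differentiable.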
+
+@_loop_reversible.def_bwd
+def _loop_reversible_bwd(
+    residuals,
+    grad_final_state__aux_stats,
+    perturbed,
+    y__args__terms,
+    *,
+    self,
+    saveat,
+    init_state,
+    solver,
+    event,
+    **kwargs,
+):
+    assert event is None
+
+    del perturbed, self, init_state, kwargs
+    init_ts, ts, ts_final_index, y1, save_state, solver_state = residuals
+    del residuals
+
+    grad_final_state, _ = grad_final_state__aux_stats
+    saveat_ts = save_state.ts
+    ys = save_state.ys
+    saveat_ts_index = save_state.saveat_ts_index - 1
+    grad_ys = grad_final_state.save_state.ys
+    grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys)
+
+    if saveat.subs.t1:
+        grad_y1 = (ω(grad_ys)[-1]).ω
+    else:
+        grad_y1 = jtu.tree_map(jnp.zeros_like, y1)
+
+    if saveat.subs.t0:
+        saveat_ts_index = saveat_ts_index + 1
+
+    del grad_final_state, grad_final_state__aux_stats
+
+    y, args, terms = y__args__terms
+    del y__args__terms
+
+    diff_state = eqx.filter(solver_state, eqx.is_inexact_array)
+    diff_args = eqx.filter(args, eqx.is_inexact_array)
+    diff_terms = eqx.filter(terms, eqx.is_inexact_array)
+    grad_state = jtu.tree_map(jnp.zeros_like, diff_state)
+    grad_args = jtu.tree_map(jnp.zeros_like, diff_args)
+    grad_terms = jtu.tree_map(jnp.zeros_like, diff_terms)
+    del diff_args, diff_terms
+
+    def grad_step(state):
+        def forward_step(y0, solver_state, args, terms):
+            y1, _, dense_info, new_solver_state, result = solver.step(
+                terms, t0, t1, y0, args, solver_state, False
+            )
+            assert result == RESULTS.successful
+            return y1, dense_info, new_solver_state
+
+        (
+            saveat_ts_index,
+            ts_index,
+            y1,
+            solver_state,
+            grad_y1,
+            grad_state,
+            grad_args,
+            grad_terms,
+        ) = state
+
+        t1 = ts[ts_index]
+        t0 = ts[ts_index - 1]
+
+        # Any ts state required to reverse the forward step,
+        # e.g. LeapfrogMidpoint requires tm1.
+        tm1_index = ts_index - 2
+        tm1 = ts[tm1_index]
+        tm1 = jnp.where(tm1_index >= 0, tm1, t0)
+        ts_state = (tm1,)
+
+        y0, dense_info, solver_state, result = solver.backward_step(
+            terms, t0, t1, y1, args, ts_state, solver_state, False
+        )
+        assert result == RESULTS.successful
+
+        # Pull gradients back through the interpolation.
+
+        def interpolate(t, t0, t1, dense_info):
+            interpolator = solver.interpolation_cls(t0=t0, t1=t1, **dense_info)
+            return interpolator.evaluate(t)
+
+        def _cond_fun(inner_state):
+            saveat_ts_index, _ = inner_state
+            return (saveat_ts[saveat_ts_index] >= t0) & (saveat_ts_index >= 0)
+
+        def _body_fun(inner_state):
+            saveat_ts_index, grad_dense_info = inner_state
+            t = saveat_ts[saveat_ts_index]
+            grad_y = (ω(grad_ys)[saveat_ts_index]).ω
+            _, interp_vjp = eqx.filter_vjp(interpolate, t, t0, t1, dense_info)
+            _, _, _, dgrad_dense_info = interp_vjp(grad_y)
+            grad_dense_info = eqx.apply_updates(grad_dense_info, dgrad_dense_info)
+            saveat_ts_index = saveat_ts_index - 1
+            return saveat_ts_index, grad_dense_info
+
+        grad_dense_info = jtu.tree_map(jnp.zeros_like, dense_info)
+        inner_state = (saveat_ts_index, grad_dense_info)
+        inner_state = eqxi.while_loop(_cond_fun, _body_fun, inner_state, kind="lax")
+        saveat_ts_index, grad_dense_info = inner_state
+
+        # Pull gradients back through the forward step.
+
+        _, vjp_fn = eqx.filter_vjp(forward_step, y0, solver_state, args, terms)
+        grad_y0, grad_state, dgrad_args, dgrad_terms = vjp_fn(
+            (grad_y1, grad_dense_info, grad_state)
+        )
+
+        grad_args = eqx.apply_updates(grad_args, dgrad_args)
+        grad_terms = eqx.apply_updates(grad_terms, dgrad_terms)
+
+        ts_index = ts_index - 1
+
+        return (
+            saveat_ts_index,
+            ts_index,
+            y0,
+            solver_state,
+            grad_y0,
+            grad_state,
+            grad_args,
+            grad_terms,
+        )
+
+    def cond_fun(state):
+        ts_index = state[1]
+        return ts_index > 0
+
+    state = (
+        saveat_ts_index,
+        ts_final_index,
+        y1,
+        solver_state,
+        grad_y1,
+        grad_state,
+        grad_args,
+        grad_terms,
+    )
+
+    state = jax.lax.while_loop(cond_fun, grad_step, state)
+    _, _, y0, _, grad_y0, grad_state, grad_args, grad_terms = state
+
+    # Pull solver_state gradients back onto y0, args, terms.
+
+    init_t0, init_t1 = init_ts
+    _, init_vjp = eqx.filter_vjp(solver.init, terms, init_t0, init_t1, y0, args)
+    dgrad_terms, _, _, dgrad_y0, dgrad_args = init_vjp(grad_state)
+    grad_y0 = eqx.apply_updates(grad_y0, dgrad_y0)
+    grad_terms = eqx.apply_updates(grad_terms, dgrad_terms)
+    grad_args = eqx.apply_updates(grad_args, dgrad_args)
+
+    return grad_y0, grad_args, grad_terms
+
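A recurring idiom in the backward pass above: zero-initialise running cotangents over the differentiable leaves, take a VJP through one function with `eqx.filter_vjp`, then fold the new cotangents into the running totals with `eqx.apply_updates` (which treats `None` updates as zero). A minimal sketch, with a hypothetical `f` and `params` not taken from this PR:

```python
import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu


def f(x, params):
    return params["scale"] * x + params["shift"]


params = {"scale": jnp.array(2.0), "shift": jnp.array(1.0)}
x = jnp.array(3.0)

# Zero-initialised running totals over the differentiable leaves only.
grads = jtu.tree_map(jnp.zeros_like, eqx.filter(params, eqx.is_inexact_array))

# One accumulation step: VJP through f, then add the new cotangents.
_, vjp_fn = eqx.filter_vjp(f, x, params)
_, dgrads = vjp_fn(jnp.array(1.0))  # cotangent of f's output
grads = eqx.apply_updates(grads, dgrads)  # {"scale": 3.0, "shift": 1.0}
```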
+
+class ReversibleAdjoint(AbstractAdjoint):
+    """Backpropagate through [`diffrax.diffeqsolve`][] when using a reversible
+    solver, i.e. a subclass of [`diffrax.AbstractReversibleSolver`][].
+
+    Gradient calculation is exact (up to floating point errors), and
+    backpropagation is $O(n)$ in time and $O(1)$ in memory, for $n$ time steps.
+
+    !!! note
+
+        This adjoint can be less numerically stable than
+        [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax.DirectAdjoint`][].
+        Stability can be improved considerably by using [double (64-bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision)
+        and [smaller/adaptive step sizes](https://docs.kidger.site/diffrax/api/stepsize_controller/).
+
+    ??? cite "References"
+
+        For an introduction to reversible backpropagation, see these references:
+
+        ```bibtex
+        @article{mccallum2024efficient,
+            title={Efficient, Accurate and Stable Gradients for Neural ODEs},
+            author={McCallum, Sam and Foster, James},
+            journal={arXiv preprint arXiv:2410.11648},
+            year={2024}
+        }
+
+        @phdthesis{kidger2021on,
+            title={{O}n {N}eural {D}ifferential {E}quations},
+            author={Patrick Kidger},
+            year={2021},
+            school={University of Oxford},
+        }
+        ```
+    """
+
+    def loop(
+        self,
+        *,
+        args,
+        terms,
+        solver,
+        saveat,
+        max_steps,
+        init_state,
+        passed_solver_state,
+        passed_controller_state,
+        event,
+        **kwargs,
+    ):
+        if not isinstance(solver, AbstractReversibleSolver):
+            raise ValueError(
+                "`ReversibleAdjoint` can only be used with an "
+                "`AbstractReversibleSolver`."
+            )
+        if max_steps is None:
+            raise ValueError(
+                "`max_steps=None` is incompatible with `ReversibleAdjoint`."
+            )
+
+        if (
+            jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat)
+            != jtu.tree_structure(0)
+            or saveat.dense
+            or saveat.subs.steps
+            or (saveat.subs.fn is not save_y)
+        ):
+            raise ValueError(
+                "`ReversibleAdjoint` is only compatible with the following `SaveAt` "
+                "properties: `t0`, `t1`, `ts`, `fn=save_y` (the default)."
+            )
+
+        if event is not None:
+            raise NotImplementedError(
+                "`ReversibleAdjoint` is not compatible with events."
+            )
+
+        if is_unsafe_sde(terms):
+            raise ValueError(
+                "`ReversibleAdjoint` does not support `UnsafeBrownianPath`. "
+                "Consider using `VirtualBrownianTree` instead."
+            )
+
+        y = init_state.y
+        init_state = eqx.tree_at(lambda s: s.y, init_state, object())
+        init_state = _nondiff_solver_controller_state(
+            self, init_state, passed_solver_state, passed_controller_state
+        )
+
+        final_state, aux_stats = _loop_reversible(
+            (y, args, terms),
+            self=self,
+            saveat=saveat,
+            max_steps=max_steps,
+            init_state=init_state,
+            solver=solver,
+            event=event,
+            **kwargs,
+        )
+        final_state = _only_transpose_ys(final_state)
+        return final_state, aux_stats
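To close, a hedged usage sketch of the new adjoint. It assumes `diffrax.ReversibleHeun` implements the `AbstractReversibleSolver` interface on this branch; substitute any concrete reversible solver. Note the `SaveAt` is restricted to `t0`/`t1`/`ts` with the default `fn=save_y`, as validated in `loop` above:

```python
import diffrax
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: -y)
solver = diffrax.ReversibleHeun()  # assumed here to be an AbstractReversibleSolver
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=0.0,
    t1=1.0,
    dt0=0.01,
    y0=jnp.array(1.0),
    saveat=diffrax.SaveAt(t0=True, ts=jnp.linspace(0.0, 1.0, 5), t1=True),
    adjoint=diffrax.ReversibleAdjoint(),  # exact gradients, O(1) memory in steps
)
```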