Commit 1ab5d23
Mark n_corrector_steps as static for allowing for backwards mode differentiation of predictor corrector.
jax.lax.fori_loop requires a static upper and lower to be backwards mode differentiable.
PiperOrigin-RevId: 7940676401 parent 68d7481 commit 1ab5d23
2 files changed
+4
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
55 | 55 | | |
56 | 56 | | |
57 | 57 | | |
58 | | - | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
59 | 61 | | |
60 | 62 | | |
61 | 63 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
36 | 36 | | |
37 | 37 | | |
38 | 38 | | |
39 | | - | |
| 39 | + | |
0 commit comments