Skip to content

Commit 9024d2f

Browse files
committed
return RESULTS from reversible backward_step
1 parent e983b46 commit 9024d2f

File tree

6 files changed

+23
-21
lines changed

6 files changed

+23
-21
lines changed

diffrax/_adjoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1013,7 +1013,7 @@ def forward_step(y0, solver_state, args, terms):
10131013
t1 = ts[ts_index]
10141014
t0 = ts[ts_index - 1]
10151015

1016-
y0, dense_info, solver_state = solver.backward_step(
1016+
y0, dense_info, solver_state, result = solver.backward_step(
10171017
terms, t0, t1, y1, args, solver_state, False
10181018
)
10191019

diffrax/_solver/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def backward_step(
371371
args: Args,
372372
solver_state: _SolverState,
373373
made_jump: BoolScalarLike,
374-
) -> tuple[Y, DenseInfo, _SolverState]:
374+
) -> tuple[Y, DenseInfo, _SolverState, RESULTS]:
375375
"""
376376
Make a single backward step with the reversible solver.
377377

diffrax/_solver/leapfrog_midpoint.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ class LeapfrogMidpoint(AbstractReversibleSolver):
5252
"""
5353

5454
term_structure: ClassVar = AbstractTerm
55-
interpolation_cls: ClassVar[
56-
Callable[..., LocalLinearInterpolation]
57-
] = LocalLinearInterpolation
55+
interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = (
56+
LocalLinearInterpolation
57+
)
5858

5959
def order(self, terms):
6060
return 2
@@ -101,7 +101,7 @@ def backward_step(
101101
args: Args,
102102
solver_state: _SolverState,
103103
made_jump: BoolScalarLike,
104-
) -> tuple[Y, DenseInfo, _SolverState]:
104+
) -> tuple[Y, DenseInfo, _SolverState, RESULTS]:
105105
del made_jump
106106
t0, y0, dt = solver_state
107107
tm1 = t0 - dt
@@ -114,7 +114,7 @@ def backward_step(
114114
solver_state = jax.lax.cond(
115115
tm1 > 0, lambda _: (tm1, ym1, dt), lambda _: (t0, y0, dt), None
116116
)
117-
return y0, dense_info, solver_state
117+
return y0, dense_info, solver_state, RESULTS.successful
118118

119119
def func(self, terms: AbstractTerm, t0: RealScalarLike, y0: Y, args: Args) -> VF:
120120
return terms.vf(t0, y0, args)

diffrax/_solver/reversible.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,19 +153,21 @@ def backward_step(
153153
args: Args,
154154
solver_state: _SolverState,
155155
made_jump: BoolScalarLike,
156-
) -> tuple[Y, DenseInfo, _SolverState]:
156+
) -> tuple[Y, DenseInfo, _SolverState, RESULTS]:
157157
original_solver_state, z1 = solver_state
158-
step_y1, _, _, original_solver_state, _ = self.solver.step(
158+
step_y1, _, _, original_solver_state, result1 = self.solver.step(
159159
terms, t1, t0, y1, args, original_solver_state, True
160160
)
161161
z0 = (ω(z1) - ω(y1) + ω(step_y1)).ω
162-
step_z0, _, dense_info, _, _ = self.solver.step(
162+
step_z0, _, dense_info, _, result2 = self.solver.step(
163163
terms, t0, t1, z0, args, original_solver_state, True
164164
)
165165
y0 = ((1 / self.coupling_parameter) * (ω(y1) - ω(step_z0)) + ω(z0)).ω
166+
166167
solver_state = (original_solver_state, z0)
168+
result = update_result(result1, result2)
167169

168-
return y0, dense_info, solver_state
170+
return y0, dense_info, solver_state, result
169171

170172
def func(
171173
self, terms: PyTree[AbstractTerm], t0: RealScalarLike, y0: Y, args: Args

diffrax/_solver/reversible_heun.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ class ReversibleHeun(
4949
"""
5050

5151
term_structure: ClassVar = AbstractTerm
52-
interpolation_cls: ClassVar[
53-
Callable[..., LocalLinearInterpolation]
54-
] = LocalLinearInterpolation # TODO use something better than this?
52+
interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = (
53+
LocalLinearInterpolation # TODO use something better than this?
54+
)
5555

5656
def order(self, terms):
5757
return 2
@@ -104,7 +104,7 @@ def backward_step(
104104
args: Args,
105105
solver_state: _SolverState,
106106
made_jump: BoolScalarLike,
107-
) -> tuple[Y, DenseInfo, _SolverState]:
107+
) -> tuple[Y, DenseInfo, _SolverState, RESULTS]:
108108
yhat1, vf1 = solver_state
109109

110110
control = terms.contr(t0, t1)
@@ -114,7 +114,7 @@ def backward_step(
114114

115115
dense_info = dict(y0=y0, y1=y1)
116116
solver_state = (yhat0, vf0)
117-
return y0, dense_info, solver_state
117+
return y0, dense_info, solver_state, RESULTS.successful
118118

119119
def func(self, terms: AbstractTerm, t0: RealScalarLike, y0: Y, args: Args) -> VF:
120120
return terms.vf(t0, y0, args)

diffrax/_solver/semi_implicit_euler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ class SemiImplicitEuler(AbstractReversibleSolver):
3333
"""
3434

3535
term_structure: ClassVar = (AbstractTerm, AbstractTerm)
36-
interpolation_cls: ClassVar[
37-
Callable[..., LocalLinearInterpolation]
38-
] = LocalLinearInterpolation
36+
interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = (
37+
LocalLinearInterpolation
38+
)
3939

4040
def order(self, terms):
4141
return 1
@@ -83,7 +83,7 @@ def backward_step(
8383
args: Args,
8484
solver_state: _SolverState,
8585
made_jump: BoolScalarLike,
86-
) -> tuple[tuple[Ya, Yb], DenseInfo, _SolverState]:
86+
) -> tuple[tuple[Ya, Yb], DenseInfo, _SolverState, RESULTS]:
8787
del solver_state, made_jump
8888

8989
term_1, term_2 = terms
@@ -96,7 +96,7 @@ def backward_step(
9696

9797
y0 = (y0_1, y0_2)
9898
dense_info = dict(y0=y0, y1=y1)
99-
return y0, dense_info, None
99+
return y0, dense_info, None, RESULTS.successful
100100

101101
def func(
102102
self,

0 commit comments

Comments
 (0)