Skip to content

Commit fe65f02

Browse files
committed
use made jump
1 parent 587764c commit fe65f02

File tree

1 file changed

+27
-16
lines changed

1 file changed

+27
-16
lines changed

diffrax/_solver/ros3p.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import jax.numpy as jnp
77
import lineax as lx
88
from equinox.internal import ω
9+
import equinox.internal as eqxi
910

1011
from .._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y
1112
from .._local_interpolation import ThirdOrderHermitePolynomialInterpolation
@@ -116,8 +117,6 @@ def step(
116117
solver_state: _SolverState,
117118
made_jump: BoolScalarLike,
118119
) -> tuple[Y, Y, DenseInfo, _SolverState, RESULTS]:
119-
del made_jump
120-
121120
time_derivative = jax.jacfwd(lambda t: terms.vf(t, y0, args))(t0)
122121
control = terms.contr(t0, t1)
123122

@@ -150,22 +149,34 @@ def embed_lower(x):
150149
u = jnp.zeros(
151150
(len(time_derivative), self.tableau.num_stages), dtype=jnp.float64
152151
)
153-
154-
def stage_vf(stage):
155-
return terms.vf(
156-
(t0**ω + α[stage] ** ω * control**ω).ω,
157-
(
158-
y0**ω
159-
+ (a_lower[stage][0] ** ω * u[:, 0] ** ω)
160-
+ (a_lower[stage][1] ** ω * u[:, 1] ** ω)
161-
).ω,
162-
args,
163-
)
152+
153+
start_stage = [0]
154+
155+
def use_saved_vf():
156+
stage_0_vf = solver_state
157+
stage_0_b = (
158+
stage_0_vf**ω + (control**ω * γ[0] ** ω * time_derivative**ω)
159+
).ω
160+
stage_0_u = lx.linear_solve(A, stage_0_b).value
161+
u.at[:, 0].set(stage_0_u)
162+
start_stage[0] = 1
163+
164+
if made_jump is False:
165+
use_saved_vf()
166+
else:
167+
lax.cond(eqxi.unvmap_any(made_jump), use_saved_vf, lambda: None)
164168

165169
def body(_carry, stage):
166-
lax.cond(stage == 0, lambda _: solver_state, stage_vf, stage)
167170
b = (
168-
stage_vf(stage)
171+
terms.vf(
172+
(t0**ω + α[stage] ** ω * control**ω).ω,
173+
(
174+
y0**ω
175+
+ (a_lower[stage][0] ** ω * u[:, 0] ** ω)
176+
+ (a_lower[stage][1] ** ω * u[:, 1] ** ω)
177+
).ω,
178+
args,
179+
)
169180
** ω
170181
+ ((c_lower[stage][0] ** ω / control**ω) * u[:, 0] ** ω)
171182
+ ((c_lower[stage][1] ** ω / control**ω) * u[:, 1] ** ω)
@@ -175,7 +186,7 @@ def body(_carry, stage):
175186
u.at[:, stage].set(stage_u)
176187
return _carry, stage
177188

178-
lax.scan(f=body, init=0, xs=jnp.arange(self.tableau.num_stages))
189+
lax.scan(f=body, init=0, xs=jnp.arange(start_stage[0], self.tableau.num_stages))
179190

180191
y1 = (
181192
y0**ω

0 commit comments

Comments
 (0)