Skip to content

Commit 587764c

Browse files
committed
use solver state
1 parent a5916b5 commit 587764c

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

diffrax/_solver/ros3p.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import jax.lax as lax
1919

2020

21-
_SolverState: TypeAlias = None
21+
_SolverState: TypeAlias = VF
2222

2323

2424
@dataclass(frozen=True)
@@ -100,8 +100,8 @@ class Ros3p(AbstractAdaptiveSolver):
100100
tableau: ClassVar[_RosenbrockTableau] = _tableau
101101

102102
def init(self, terms, t0, t1, y0, args) -> _SolverState:
103-
del terms, t0, t1, y0, args
104-
return None
103+
del t1
104+
return terms.vf(t0, y0, args)
105105

106106
def order(self, terms):
107107
return 3
@@ -116,7 +116,7 @@ def step(
116116
solver_state: _SolverState,
117117
made_jump: BoolScalarLike,
118118
) -> tuple[Y, Y, DenseInfo, _SolverState, RESULTS]:
119-
del made_jump, solver_state
119+
del made_jump
120120

121121
time_derivative = jax.jacfwd(lambda t: terms.vf(t, y0, args))(t0)
122122
control = terms.contr(t0, t1)
@@ -150,18 +150,22 @@ def embed_lower(x):
150150
u = jnp.zeros(
151151
(len(time_derivative), self.tableau.num_stages), dtype=jnp.float64
152152
)
153+
154+
def stage_vf(stage):
155+
return terms.vf(
156+
(t0**ω + α[stage] ** ω * control**ω).ω,
157+
(
158+
y0**ω
159+
+ (a_lower[stage][0] ** ω * u[:, 0] ** ω)
160+
+ (a_lower[stage][1] ** ω * u[:, 1] ** ω)
161+
).ω,
162+
args,
163+
)
153164

154165
def body(_carry, stage):
166+
lax.cond(stage == 0, lambda _: solver_state, stage_vf, stage)
155167
b = (
156-
terms.vf(
157-
(t0**ω + α[stage] ** ω * control**ω).ω,
158-
(
159-
y0**ω
160-
+ (a_lower[stage][0] ** ω * u[:, 0] ** ω)
161-
+ (a_lower[stage][1] ** ω * u[:, 1] ** ω)
162-
).ω,
163-
args,
164-
)
168+
stage_vf(stage)
165169
** ω
166170
+ ((c_lower[stage][0] ** ω / control**ω) * u[:, 0] ** ω)
167171
+ ((c_lower[stage][1] ** ω / control**ω) * u[:, 1] ** ω)
@@ -192,7 +196,7 @@ def body(_carry, stage):
192196
k = jnp.stack((k1, k2))
193197

194198
dense_info = dict(y0=y0, y1=y1, k=k)
195-
return y1, y1_error, dense_info, None, RESULTS.successful
199+
return y1, y1_error, dense_info, k2, RESULTS.successful
196200

197201
def func(
198202
self,

0 commit comments

Comments
 (0)