1818import jax .lax as lax
1919
2020
21- _SolverState : TypeAlias = None
21+ _SolverState : TypeAlias = VF
2222
2323
2424@dataclass (frozen = True )
@@ -100,8 +100,8 @@ class Ros3p(AbstractAdaptiveSolver):
100100 tableau : ClassVar [_RosenbrockTableau ] = _tableau
101101
102102 def init (self , terms , t0 , t1 , y0 , args ) -> _SolverState :
103- del terms , t0 , t1 , y0 , args
104- return None
103+ del t1
104+ return terms . vf ( t0 , y0 , args )
105105
106106 def order (self , terms ):
107107 return 3
@@ -116,7 +116,7 @@ def step(
116116 solver_state : _SolverState ,
117117 made_jump : BoolScalarLike ,
118118 ) -> tuple [Y , Y , DenseInfo , _SolverState , RESULTS ]:
119- del made_jump , solver_state
119+ del made_jump
120120
121121 time_derivative = jax .jacfwd (lambda t : terms .vf (t , y0 , args ))(t0 )
122122 control = terms .contr (t0 , t1 )
@@ -150,18 +150,22 @@ def embed_lower(x):
150150 u = jnp .zeros (
151151 (len (time_derivative ), self .tableau .num_stages ), dtype = jnp .float64
152152 )
153+
154+ def stage_vf (stage ):
155+ return terms .vf (
156+ (t0 ** ω + α [stage ] ** ω * control ** ω ).ω ,
157+ (
158+ y0 ** ω
159+ + (a_lower [stage ][0 ] ** ω * u [:, 0 ] ** ω )
160+ + (a_lower [stage ][1 ] ** ω * u [:, 1 ] ** ω )
161+ ).ω ,
162+ args ,
163+ )
153164
154165 def body (_carry , stage ):
166+ lax .cond (stage == 0 , lambda _ : solver_state , stage_vf , stage )
155167 b = (
156- terms .vf (
157- (t0 ** ω + α [stage ] ** ω * control ** ω ).ω ,
158- (
159- y0 ** ω
160- + (a_lower [stage ][0 ] ** ω * u [:, 0 ] ** ω )
161- + (a_lower [stage ][1 ] ** ω * u [:, 1 ] ** ω )
162- ).ω ,
163- args ,
164- )
168+ stage_vf (stage )
165169 ** ω
166170 + ((c_lower [stage ][0 ] ** ω / control ** ω ) * u [:, 0 ] ** ω )
167171 + ((c_lower [stage ][1 ] ** ω / control ** ω ) * u [:, 1 ] ** ω )
@@ -192,7 +196,7 @@ def body(_carry, stage):
192196 k = jnp .stack ((k1 , k2 ))
193197
194198 dense_info = dict (y0 = y0 , y1 = y1 , k = k )
195- return y1 , y1_error , dense_info , None , RESULTS .successful
199+ return y1 , y1_error , dense_info , k2 , RESULTS .successful
196200
197201 def func (
198202 self ,
0 commit comments