66import jax .numpy as jnp
77import lineax as lx
88from equinox .internal import ω
9+ import equinox .internal as eqxi
910
1011from .._custom_types import Args , BoolScalarLike , DenseInfo , RealScalarLike , VF , Y
1112from .._local_interpolation import ThirdOrderHermitePolynomialInterpolation
@@ -116,8 +117,6 @@ def step(
116117 solver_state : _SolverState ,
117118 made_jump : BoolScalarLike ,
118119 ) -> tuple [Y , Y , DenseInfo , _SolverState , RESULTS ]:
119- del made_jump
120-
121120 time_derivative = jax .jacfwd (lambda t : terms .vf (t , y0 , args ))(t0 )
122121 control = terms .contr (t0 , t1 )
123122
@@ -150,22 +149,34 @@ def embed_lower(x):
150149 u = jnp .zeros (
151150 (len (time_derivative ), self .tableau .num_stages ), dtype = jnp .float64
152151 )
153-
154- def stage_vf (stage ):
155- return terms .vf (
156- (t0 ** ω + α [stage ] ** ω * control ** ω ).ω ,
157- (
158- y0 ** ω
159- + (a_lower [stage ][0 ] ** ω * u [:, 0 ] ** ω )
160- + (a_lower [stage ][1 ] ** ω * u [:, 1 ] ** ω )
161- ).ω ,
162- args ,
163- )
152+
153+ start_stage = [0 ]
154+
155+ def use_saved_vf ():
156+ stage_0_vf = solver_state
157+ stage_0_b = (
158+ stage_0_vf ** ω + (control ** ω * γ [0 ] ** ω * time_derivative ** ω )
159+ ).ω
160+ stage_0_u = lx .linear_solve (A , stage_0_b ).value
161+ u .at [:, 0 ].set (stage_0_u )
162+ start_stage [0 ] = 1
163+
164+ if made_jump is False :
165+ use_saved_vf ()
166+ else :
167+ lax .cond (eqxi .unvmap_any (made_jump ), use_saved_vf , lambda : None )
164168
165169 def body (_carry , stage ):
166- lax .cond (stage == 0 , lambda _ : solver_state , stage_vf , stage )
167170 b = (
168- stage_vf (stage )
171+ terms .vf (
172+ (t0 ** ω + α [stage ] ** ω * control ** ω ).ω ,
173+ (
174+ y0 ** ω
175+ + (a_lower [stage ][0 ] ** ω * u [:, 0 ] ** ω )
176+ + (a_lower [stage ][1 ] ** ω * u [:, 1 ] ** ω )
177+ ).ω ,
178+ args ,
179+ )
169180 ** ω
170181 + ((c_lower [stage ][0 ] ** ω / control ** ω ) * u [:, 0 ] ** ω )
171182 + ((c_lower [stage ][1 ] ** ω / control ** ω ) * u [:, 1 ] ** ω )
@@ -175,7 +186,7 @@ def body(_carry, stage):
175186 u .at [:, stage ].set (stage_u )
176187 return _carry , stage
177188
178- lax .scan (f = body , init = 0 , xs = jnp .arange (self .tableau .num_stages ))
189+ lax .scan (f = body , init = 0 , xs = jnp .arange (start_stage [ 0 ], self .tableau .num_stages ))
179190
180191 y1 = (
181192 y0 ** ω
0 commit comments