@@ -317,7 +317,7 @@ def euler(m: Model, d: Data) -> Data:
317317@named_scope
318318def rungekutta4 (m : Model , d : Data ) -> Data :
319319 """Runge-Kutta explicit order 4 integrator."""
320- d_t0 = d
320+ d0 = d
321321 # pylint: disable=invalid-name
322322 A , B = _RK4_A , _RK4_B
323323 C = jp .tril (A ).sum (axis = 0 ) # C(i) = sum_j A(i,j)
@@ -338,9 +338,9 @@ def f(carry, x):
338338 lambda k : a * k , (kqvel , d .qacc , d .act_dot )
339339 )
340340 # get intermediate RK solutions
341- kqpos = scan .flat (m , integrate_fn , 'jqv' , 'q' , m .jnt_type , d_t0 .qpos , dqvel )
342- kact = d_t0 .act + dact_dot * m .opt .timestep
343- kqvel = d_t0 .qvel + dqacc * m .opt .timestep
341+ kqpos = scan .flat (m , integrate_fn , 'jqv' , 'q' , m .jnt_type , d0 .qpos , dqvel )
342+ kact = d0 .act + dact_dot * m .opt .timestep
343+ kqvel = d0 .qvel + dqacc * m .opt .timestep
344344 d = d .replace (qpos = kqpos , qvel = kqvel , act = kact , time = t )
345345 d = forward (m , d )
346346
@@ -352,9 +352,10 @@ def f(carry, x):
352352
353353 abt = jp .vstack ([jp .diag (A ), B [1 :4 ], T ]).T
354354 out , _ = jax .lax .scan (f , (qvel , qacc , act_dot , kqvel , d ), abt , unroll = 3 )
355- qvel , qacc , act_dot , * _ = out
355+ qvel , qacc , act_dot , _ , d1 = out
356356
357- d = _advance (m , d_t0 , act_dot , qacc , qvel )
357+ d = d1 .replace (qpos = d0 .qpos , qvel = d0 .qvel , act = d0 .act , time = d0 .time )
358+ d = _advance (m , d , act_dot , qacc , qvel )
358359 return d
359360
360361
0 commit comments