Skip to content

Commit 54f132f

Browse files
btabacopybara-github
authored andcommitted
Resolve discrepancy in RK4 between MJX and MJ.
PiperOrigin-RevId: 743185866 Change-Id: I478faca1d8733999cc9a0682e931b19ee05e3843
1 parent 40393f4 commit 54f132f

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

mjx/mujoco/mjx/_src/forward.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def euler(m: Model, d: Data) -> Data:
317317
@named_scope
318318
def rungekutta4(m: Model, d: Data) -> Data:
319319
"""Runge-Kutta explicit order 4 integrator."""
320-
d_t0 = d
320+
d0 = d
321321
# pylint: disable=invalid-name
322322
A, B = _RK4_A, _RK4_B
323323
C = jp.tril(A).sum(axis=0) # C(i) = sum_j A(i,j)
@@ -338,9 +338,9 @@ def f(carry, x):
338338
lambda k: a * k, (kqvel, d.qacc, d.act_dot)
339339
)
340340
# get intermediate RK solutions
341-
kqpos = scan.flat(m, integrate_fn, 'jqv', 'q', m.jnt_type, d_t0.qpos, dqvel)
342-
kact = d_t0.act + dact_dot * m.opt.timestep
343-
kqvel = d_t0.qvel + dqacc * m.opt.timestep
341+
kqpos = scan.flat(m, integrate_fn, 'jqv', 'q', m.jnt_type, d0.qpos, dqvel)
342+
kact = d0.act + dact_dot * m.opt.timestep
343+
kqvel = d0.qvel + dqacc * m.opt.timestep
344344
d = d.replace(qpos=kqpos, qvel=kqvel, act=kact, time=t)
345345
d = forward(m, d)
346346

@@ -352,9 +352,10 @@ def f(carry, x):
352352

353353
abt = jp.vstack([jp.diag(A), B[1:4], T]).T
354354
out, _ = jax.lax.scan(f, (qvel, qacc, act_dot, kqvel, d), abt, unroll=3)
355-
qvel, qacc, act_dot, *_ = out
355+
qvel, qacc, act_dot, _, d1 = out
356356

357-
d = _advance(m, d_t0, act_dot, qacc, qvel)
357+
d = d1.replace(qpos=d0.qpos, qvel=d0.qvel, act=d0.act, time=d0.time)
358+
d = _advance(m, d, act_dot, qacc, qvel)
358359
return d
359360

360361

mjx/mujoco/mjx/_src/forward_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def test_rk4(self):
130130
_assert_attr_eq(d, dx, 'qpos')
131131
_assert_attr_eq(d, dx, 'act')
132132
_assert_attr_eq(d, dx, 'time')
133+
_assert_attr_eq(d, dx, 'xpos')
133134

134135
def test_eulerdamp(self):
135136
m = test_util.load_test_file('pendula.xml')

0 commit comments

Comments
 (0)