Skip to content

Commit 51c489f

Browse files
thowellcopybara-github
authored andcommitted
Add inverse dynamics to MJX.
PiperOrigin-RevId: 745998273 Change-Id: I203af89332ace5d7a60fff3ec6ff4cc88c02340d
1 parent d3664b5 commit 51c489f

File tree

13 files changed

+357
-44
lines changed

13 files changed

+357
-44
lines changed

doc/changelog.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
=========
22
Changelog
33
=========
4+
5+
Upcoming version (not yet release)
6+
----------------------------------
7+
8+
MJX
9+
^^^
10+
- Added inverse dynamics.
11+
412
Version 3.3.1 (Apr 9, 2025)
513
----------------------------
614

doc/mjx.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,8 @@ The following features are **fully supported** in MJX:
235235
- 1, 3, 4, 6 (1 is not supported with ``ELLIPTIC``)
236236
* - :ref:`Solver <mjtSolver>`
237237
- ``CG``, ``NEWTON``
238+
* - Dynamics
239+
- :ref:`Inverse <mj_inverse>`
238240
* - Fluid Model
239241
- :ref:`flInertia`
240242
* - :ref:`Tendons <tendon>`
@@ -262,8 +264,6 @@ The following features are **in development** and coming soon:
262264
(``BOX``, ``MESH``, ``HFIELD``) and ``ELLIPSOID``.
263265
* - :ref:`Integrator <mjtIntegrator>`
264266
- ``IMPLICIT``
265-
* - Dynamics
266-
- :ref:`Inverse <mj_inverse>`
267267
* - Fluid Model
268268
- :ref:`flEllipsoid`
269269
* - :ref:`Sensors <mjtSensor>`

mjx/mujoco/mjx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# pylint:disable=g-importing-member
1818
from mujoco.mjx._src.collision_driver import collision
1919
from mujoco.mjx._src.constraint import make_constraint
20+
from mujoco.mjx._src.derivative import deriv_smooth_vel
2021
from mujoco.mjx._src.forward import euler
2122
from mujoco.mjx._src.forward import forward
2223
from mujoco.mjx._src.forward import fwd_acceleration
@@ -26,6 +27,7 @@
2627
from mujoco.mjx._src.forward import implicit
2728
from mujoco.mjx._src.forward import rungekutta4
2829
from mujoco.mjx._src.forward import step
30+
from mujoco.mjx._src.inverse import inverse
2931
from mujoco.mjx._src.io import get_data
3032
from mujoco.mjx._src.io import get_data_into
3133
from mujoco.mjx._src.io import make_data

mjx/mujoco/mjx/_src/derivative.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2025 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Derivative functions."""
16+
17+
from typing import Optional
18+
19+
import jax
20+
from jax import numpy as jp
21+
# pylint: disable=g-importing-member
22+
from mujoco.mjx._src.types import BiasType
23+
from mujoco.mjx._src.types import Data
24+
from mujoco.mjx._src.types import DisableBit
25+
from mujoco.mjx._src.types import DynType
26+
from mujoco.mjx._src.types import GainType
27+
from mujoco.mjx._src.types import Model
28+
29+
30+
def deriv_smooth_vel(m: Model, d: Data) -> Optional[jax.Array]:
31+
"""Analytical derivative of smooth forces w.r.t velocities."""
32+
33+
qderiv = None
34+
35+
# qDeriv += d qfrc_actuator / d qvel
36+
if not m.opt.disableflags & DisableBit.ACTUATION:
37+
affine_bias = m.actuator_biastype == BiasType.AFFINE
38+
bias_vel = m.actuator_biasprm[:, 2] * affine_bias
39+
affine_gain = m.actuator_gaintype == GainType.AFFINE
40+
gain_vel = m.actuator_gainprm[:, 2] * affine_gain
41+
ctrl = d.ctrl.at[m.actuator_dyntype != DynType.NONE].set(d.act)
42+
vel = bias_vel + gain_vel * ctrl
43+
qderiv = d.actuator_moment.T @ jax.vmap(jp.multiply)(d.actuator_moment, vel)
44+
45+
# qDeriv += d qfrc_passive / d qvel
46+
if not m.opt.disableflags & DisableBit.PASSIVE:
47+
if qderiv is None:
48+
qderiv = -jp.diag(m.dof_damping)
49+
else:
50+
qderiv -= jp.diag(m.dof_damping)
51+
if m.ntendon:
52+
qderiv -= d.ten_J.T @ jp.diag(m.tendon_damping) @ d.ten_J
53+
# TODO(robotics-simulation): fluid drag model
54+
if m.opt.has_fluid_params:
55+
raise NotImplementedError('fluid drag not supported for implicitfast')
56+
57+
# TODO(team): rne derivative
58+
59+
return qderiv

mjx/mujoco/mjx/_src/forward.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import mujoco
2323
from mujoco.mjx._src import collision_driver
2424
from mujoco.mjx._src import constraint
25+
from mujoco.mjx._src import derivative
2526
from mujoco.mjx._src import math
2627
from mujoco.mjx._src import passive
2728
from mujoco.mjx._src import scan
@@ -392,29 +393,7 @@ def f(carry, x):
392393
def implicit(m: Model, d: Data) -> Data:
393394
"""Integrates fully implicit in velocity."""
394395

395-
qderiv = None
396-
397-
# qDeriv += d qfrc_actuator / d qvel
398-
if not m.opt.disableflags & DisableBit.ACTUATION:
399-
affine_bias = m.actuator_biastype == BiasType.AFFINE
400-
bias_vel = m.actuator_biasprm[:, 2] * affine_bias
401-
affine_gain = m.actuator_gaintype == GainType.AFFINE
402-
gain_vel = m.actuator_gainprm[:, 2] * affine_gain
403-
ctrl = d.ctrl.at[m.actuator_dyntype != DynType.NONE].set(d.act)
404-
vel = bias_vel + gain_vel * ctrl
405-
qderiv = d.actuator_moment.T @ jp.diag(vel) @ d.actuator_moment
406-
407-
# qDeriv += d qfrc_passive / d qvel
408-
if not m.opt.disableflags & DisableBit.PASSIVE:
409-
if qderiv is None:
410-
qderiv = -jp.diag(m.dof_damping)
411-
else:
412-
qderiv -= jp.diag(m.dof_damping)
413-
if m.ntendon:
414-
qderiv -= d.ten_J.T @ jp.diag(m.tendon_damping) @ d.ten_J
415-
# TODO(robotics-simulation): fluid drag model
416-
if m.opt.has_fluid_params:
417-
raise NotImplementedError('fluid drag not supported for implicitfast')
396+
qderiv = derivative.deriv_smooth_vel(m, d)
418397

419398
qacc = d.qacc
420399
if qderiv is not None:

mjx/mujoco/mjx/_src/inverse.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright 2025 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Inverse dynamics functions."""
16+
17+
from jax import numpy as jp
18+
from mujoco.mjx._src import derivative
19+
from mujoco.mjx._src import forward
20+
from mujoco.mjx._src import sensor
21+
from mujoco.mjx._src import smooth
22+
from mujoco.mjx._src import solver
23+
from mujoco.mjx._src import support
24+
# pylint: disable=g-importing-member
25+
from mujoco.mjx._src.types import Data
26+
from mujoco.mjx._src.types import DisableBit
27+
from mujoco.mjx._src.types import EnableBit
28+
from mujoco.mjx._src.types import IntegratorType
29+
from mujoco.mjx._src.types import Model
30+
31+
32+
def discrete_acc(m: Model, d: Data) -> Data:
33+
"""Convert discrete-time qacc to continuous-time qacc."""
34+
35+
if m.opt.integrator == IntegratorType.RK4:
36+
raise RuntimeError(
37+
'discrete inverse dynamics is not supported by RK4 integrator'
38+
)
39+
elif m.opt.integrator == IntegratorType.EULER:
40+
dsbl_eulerdamp = m.opt.disableflags & DisableBit.EULERDAMP
41+
no_dof_damping = (m.dof_damping == 0).all()
42+
if dsbl_eulerdamp or no_dof_damping:
43+
return d
44+
45+
# set qfrc = (M + h*diag(B)) * qacc
46+
qfrc = support.mul_m(m, d, d.qacc)
47+
qfrc += m.opt.timestep * m.dof_damping * d.qacc
48+
elif m.opt.integrator == IntegratorType.IMPLICITFAST:
49+
qm = support.full_m(m, d)
50+
51+
# compute analytical derivative qDeriv; skip rne derivative
52+
qderiv = derivative.deriv_smooth_vel(m, d)
53+
if qderiv is not None:
54+
# M = M - dt*qDeriv
55+
qm -= m.opt.timestep * qderiv
56+
57+
# set qfrc = (M - dt*qDeriv) * qacc
58+
qfrc = qm @ d.qacc
59+
else:
60+
raise NotImplementedError(f'integrator {m.opt.integrator} not implemented.')
61+
62+
# solve for qacc: qfrc = M * qacc
63+
qacc = smooth.solve_m(m, d, qfrc)
64+
65+
return d.replace(qacc=qacc)
66+
67+
68+
def inv_constraint(m: Model, d: Data) -> Data:
69+
"""Inverse constraint solver."""
70+
71+
# no constraints
72+
if d.efc_J.size == 0:
73+
return d.replace(qfrc_constraint=jp.zeros(m.nv))
74+
75+
# update
76+
ctx = solver.Context.create(m, d, grad=False)
77+
78+
return d.replace(
79+
qfrc_constraint=ctx.qfrc_constraint,
80+
efc_force=ctx.efc_force,
81+
)
82+
83+
84+
def inverse(m: Model, d: Data) -> Data:
85+
"""Inverse dynamics."""
86+
d = forward.fwd_position(m, d)
87+
d = sensor.sensor_pos(m, d)
88+
d = forward.fwd_velocity(m, d)
89+
d = sensor.sensor_vel(m, d)
90+
91+
qacc = d.qacc
92+
if m.opt.enableflags & EnableBit.INVDISCRETE:
93+
d = discrete_acc(m, d)
94+
95+
d = inv_constraint(m, d)
96+
d = smooth.rne(m, d, flg_acc=True)
97+
d = sensor.sensor_acc(m, d)
98+
99+
qfrc_inverse = (
100+
d.qfrc_bias + m.dof_armature * d.qacc - d.qfrc_passive - d.qfrc_constraint
101+
)
102+
103+
if m.opt.enableflags & EnableBit.INVDISCRETE:
104+
return d.replace(qfrc_inverse=qfrc_inverse, qacc=qacc)
105+
else:
106+
return d.replace(qfrc_inverse=qfrc_inverse)
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright 2023 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for inverse dynamics functions."""
16+
from absl.testing import absltest
17+
from absl.testing import parameterized
18+
import jax
19+
from jax import numpy as jp
20+
import mujoco
21+
from mujoco import mjx
22+
from mujoco.mjx._src import support
23+
from mujoco.mjx._src import test_util
24+
import numpy as np
25+
26+
# tolerance for difference between MuJoCo and MJX calculations - mostly
27+
# due to float precision
28+
_TOLERANCE = 1e-5
29+
30+
31+
def _assert_eq(a, b, name, tol=_TOLERANCE):
32+
tol = tol * 10 # avoid test noise
33+
err_msg = f'mismatch: {name}'
34+
np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)
35+
36+
37+
class InverseTest(parameterized.TestCase):
38+
39+
@parameterized.parameters(
40+
(mujoco.mjtIntegrator.mjINT_EULER, False, False),
41+
(mujoco.mjtIntegrator.mjINT_EULER, False, True),
42+
(mujoco.mjtIntegrator.mjINT_EULER, True, False),
43+
(mujoco.mjtIntegrator.mjINT_EULER, True, True),
44+
(mujoco.mjtIntegrator.mjINT_IMPLICITFAST, False, False),
45+
(mujoco.mjtIntegrator.mjINT_IMPLICITFAST, True, False),
46+
)
47+
def test_forward_inverse_match(self, integrator, invdiscrete, eulerdamp):
48+
m = mujoco.MjModel.from_xml_string("""
49+
<mujoco>
50+
<option timestep=".005" gravity="-1 -1 -10"/>
51+
<worldbody>
52+
<geom type="plane" size="10 10 .001"/>
53+
<body pos="0 0 1">
54+
<geom type="sphere" size=".1" pos=".1 .2 .3"/>
55+
<joint name="jnt1" type="hinge" axis="0 1 0" stiffness=".25" damping=".125"/>
56+
<body pos="0 0 1">
57+
<geom type="sphere" size=".1" pos=".1 .2 .3"/>
58+
<joint name="jnt2" type="hinge" axis="0 1 0" stiffness=".6" damping=".3"/>
59+
</body>
60+
</body>
61+
</worldbody>
62+
<actuator>
63+
<motor joint="jnt1"/>
64+
</actuator>
65+
<equality>
66+
<joint joint1="jnt1" joint2="jnt2"/>
67+
</equality>
68+
</mujoco>
69+
""")
70+
m.opt.integrator = integrator
71+
if invdiscrete:
72+
m.opt.enableflags |= mujoco.mjtEnableBit.mjENBL_INVDISCRETE
73+
if not eulerdamp:
74+
m.opt.disableflags |= mujoco.mjtDisableBit.mjDSBL_EULERDAMP
75+
76+
d = mujoco.MjData(m)
77+
d.qvel = np.random.uniform(low=-0.01, high=0.01, size=d.qvel.shape)
78+
d.ctrl = np.random.uniform(low=-0.01, high=0.01, size=d.ctrl.shape)
79+
d.qfrc_applied = np.random.uniform(
80+
low=-0.01, high=0.01, size=d.qfrc_applied.shape
81+
)
82+
d.xfrc_applied = np.random.uniform(
83+
low=-0.01, high=0.01, size=d.xfrc_applied.shape
84+
)
85+
mujoco.mj_step(m, d, 100)
86+
87+
mx = mjx.put_model(m)
88+
dx = mjx.put_data(m, d)
89+
dx_next = mjx.step(mx, dx)
90+
qacc_fd = (dx_next.qvel - dx.qvel) / mx.opt.timestep
91+
92+
dx = mjx.forward(mx, dx)
93+
94+
if invdiscrete:
95+
dx = dx.replace(qacc=qacc_fd)
96+
97+
dxinv = mjx.inverse(mx, dx)
98+
99+
fwdinv0 = jp.linalg.norm(
100+
dxinv.qfrc_constraint - dx.qfrc_constraint, ord=np.inf
101+
)
102+
fwdinv1 = jp.linalg.norm(
103+
dxinv.qfrc_inverse
104+
- (
105+
dx.qfrc_applied + dx.qfrc_actuator + support.xfrc_accumulate(mx, dx)
106+
),
107+
ord=np.inf,
108+
)
109+
110+
self.assertLess(fwdinv0, 1.0e-3)
111+
self.assertLess(fwdinv1, 1.0e-3)
112+
_assert_eq(dxinv.qacc, dx.qacc, 'qacc')
113+
114+
def test_tendon_force_clamp(self):
115+
m = test_util.load_test_file('actuator/tendon_force_clamp.xml')
116+
d = mujoco.MjData(m)
117+
mx = mjx.put_model(m)
118+
dx = mjx.put_data(m, d)
119+
120+
dx = dx.replace(ctrl=jp.array([1.0, 1.0, 1.0, -1.0, 1.0, -20.0, 5.0, -5.0]))
121+
dx = mjx.forward(mx, dx)
122+
123+
_assert_eq(
124+
dx.actuator_force,
125+
jp.array([1.0, 1.0, 1.0, -1.0, 1.0, -10.0, 5.0, -5.0]),
126+
'actuator_force',
127+
)
128+
129+
130+
if __name__ == '__main__':
131+
absltest.main()

0 commit comments

Comments
 (0)