Skip to content

Commit 1763fa5

Browse files
btabacopybara-github
authored andcommitted
Add ten_length as MJX public field and warp tendon function call to MJX.
PiperOrigin-RevId: 794869371 Change-Id: I733da169b0092e3165478324697d7aa3569eca8b
1 parent d3ff929 commit 1763fa5

File tree

15 files changed

+258
-28
lines changed

15 files changed

+258
-28
lines changed

doc/changelog.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,19 @@
22
Changelog
33
=========
44

5+
Upcoming version (not yet released)
6+
-----------------------------------
7+
8+
MJX
9+
^^^
10+
- Promote ``ten_length`` to the public MJX API. Add Warp support for ``mjx.tendon``.
11+
12+
.. admonition:: Breaking API changes
13+
:class: attention
14+
15+
- ``ten_length`` was moved from ``mjx.Data._impl.ten_length`` to a public field ``mjx.Data.ten_length``.
16+
17+
518
Version 3.3.5 (August 8, 2025)
619
-----------------------------------
720

mjx/mujoco/mjx/_src/constraint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,8 @@ def rows(
322322

323323
inv1, inv2 = m.tendon_invweight0[obj1id], m.tendon_invweight0[obj2id]
324324
jac1, jac2 = d._impl.ten_J[obj1id], d._impl.ten_J[obj2id]
325-
pos1 = d._impl.ten_length[obj1id] - m.tendon_length0[obj1id]
326-
pos2 = d._impl.ten_length[obj2id] - m.tendon_length0[obj2id]
325+
pos1 = d.ten_length[obj1id] - m.tendon_length0[obj1id]
326+
pos2 = d.ten_length[obj2id] - m.tendon_length0[obj2id]
327327
invweight = inv1 + inv2 * (obj2id > -1)
328328

329329
return rows(
@@ -436,7 +436,7 @@ def _efc_limit_tendon(m: Model, d: Data) -> Optional[_Efc]:
436436
length, j, range_, margin, invweight, solref, solimp = jax.tree_util.tree_map(
437437
lambda x: x[tendon_id],
438438
(
439-
d._impl.ten_length,
439+
d.ten_length,
440440
d._impl.ten_J,
441441
m.tendon_range,
442442
m.tendon_margin,

mjx/mujoco/mjx/_src/io.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ def _make_data_public_fields(m: types.Model) -> Dict[str, Any]:
576576
'qfrc_constraint': (m.nv, float_),
577577
'qfrc_inverse': (m.nv, float_),
578578
'cvel': (m.nbody, 6, float_),
579+
'ten_length': (m.ntendon, float_),
579580
}
580581
zero_fields = {
581582
k: np.zeros(v[:-1], dtype=v[-1]) for k, v in zero_fields.items()
@@ -637,7 +638,6 @@ def _make_data_jax(
637638
'ten_wrapadr': (m.ntendon, np.int32),
638639
'ten_wrapnum': (m.ntendon, np.int32),
639640
'ten_J': (m.ntendon, m.nv, float_),
640-
'ten_length': (m.ntendon, float_),
641641
'wrap_obj': (m.nwrap, 2, np.int32),
642642
'wrap_xpos': (m.nwrap, 6, float_),
643643
'actuator_length': (m.nu, float_),
@@ -742,7 +742,6 @@ def get(m, name: str):
742742
'ten_J_rowadr': (m.ntendon, np.int32),
743743
'ten_J_colind': (m.ntendon, m.nv, np.int32),
744744
'ten_J': (m.ntendon, m.nv, float_),
745-
'ten_length': (m.ntendon, float_),
746745
'ten_wrapadr': (m.ntendon, np.int32),
747746
'ten_wrapnum': (m.ntendon, np.int32),
748747
'wrap_obj': (m.nwrap, 2, np.int32),

mjx/mujoco/mjx/_src/io_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def test_put_data(self, impl: str):
478478
np.testing.assert_allclose(dx.site_xmat.reshape((1, 9)), d.site_xmat)
479479

480480
# tendon data is correct
481-
np.testing.assert_allclose(dx._impl.ten_length, d.ten_length)
481+
np.testing.assert_allclose(dx.ten_length, d.ten_length)
482482
np.testing.assert_equal(dx._impl.ten_wrapadr, np.zeros((1,)))
483483
np.testing.assert_equal(dx._impl.ten_wrapnum, np.zeros((1,)))
484484
np.testing.assert_equal(dx._impl.wrap_obj, np.zeros((2, 2)))

mjx/mujoco/mjx/_src/passive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def fn(jnt_typs, stiffness, qpos_spring, qpos):
7575
qfrc -= m.dof_damping * d.qvel
7676

7777
# tendon-level spring-dampers
78-
below, above = m.tendon_lengthspring.T - d._impl.ten_length
78+
below, above = m.tendon_lengthspring.T - d.ten_length
7979
frc_spring = jp.where(below > 0, m.tendon_stiffness * below, 0)
8080
frc_spring = jp.where(above < 0, m.tendon_stiffness * above, frc_spring)
8181
frc_damper = -m.tendon_damping * d._impl.ten_velocity

mjx/mujoco/mjx/_src/sensor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from mujoco.mjx._src import ray
2323
from mujoco.mjx._src import smooth
2424
from mujoco.mjx._src import support
25-
from mujoco.mjx._src.types import Impl
2625
from mujoco.mjx._src.types import Data
2726
from mujoco.mjx._src.types import DataJAX
2827
from mujoco.mjx._src.types import DisableBit
@@ -174,7 +173,7 @@ def _cam_project(
174173
elif sensor_type == SensorType.JOINTPOS:
175174
sensor = d.qpos[m.jnt_qposadr[objid]]
176175
elif sensor_type == SensorType.TENDONPOS:
177-
sensor = d._impl.ten_length[objid]
176+
sensor = d.ten_length[objid]
178177
elif sensor_type == SensorType.ACTUATORPOS:
179178
sensor = d._impl.actuator_length[objid]
180179
elif sensor_type == SensorType.BALLQUAT:

mjx/mujoco/mjx/_src/smooth.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,24 @@
2626
from mujoco.mjx._src.types import DataJAX
2727
from mujoco.mjx._src.types import DisableBit
2828
from mujoco.mjx._src.types import EqType
29+
from mujoco.mjx._src.types import Impl
2930
from mujoco.mjx._src.types import JointType
3031
from mujoco.mjx._src.types import Model
3132
from mujoco.mjx._src.types import ModelJAX
3233
from mujoco.mjx._src.types import ObjType
3334
from mujoco.mjx._src.types import TrnType
3435
from mujoco.mjx._src.types import WrapType
3536
# pylint: enable=g-importing-member
37+
import mujoco.mjx.warp as mjxw
3638
import numpy as np
3739

3840

3941
def kinematics(m: Model, d: Data) -> Data:
4042
"""Converts position/velocity from generalized coordinates to maximal."""
43+
if m.impl == Impl.WARP and d.impl == Impl.WARP and mjxw.WARP_INSTALLED:
44+
from mujoco.mjx.warp import smooth as mjxw_smooth # pylint: disable=g-import-not-at-top # pytype: disable=import-error
45+
return mjxw_smooth.kinematics(m, d)
46+
4147
def fn(carry, jnt_typs, jnt_pos, jnt_axis, qpos, qpos0, pos, quat):
4248
# calculate joint anchors, axes, body pos and quat in global frame
4349
# also normalize qpos while we're at it
@@ -844,6 +850,10 @@ def _forward(carry, cfrc_ext, cinert, cvel, body_dofadr, body_dofnum):
844850

845851
def tendon(m: Model, d: Data) -> Data:
846852
"""Computes tendon lengths and moments."""
853+
if m.impl == Impl.WARP and d.impl == Impl.WARP and mjxw.WARP_INSTALLED:
854+
from mujoco.mjx.warp import smooth as mjxw_smooth # pylint: disable=g-import-not-at-top # pytype: disable=import-error
855+
return mjxw_smooth.tendon(m, d)
856+
847857
if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
848858
raise ValueError('tendon requires JAX backend implementation.')
849859

@@ -1091,7 +1101,7 @@ def _distance(p0, p1):
10911101

10921102
# assemble length and moment
10931103
ten_length = (
1094-
jp.zeros_like(d._impl.ten_length).at[tendon_id_jnt].set(length_jnt)
1104+
jp.zeros_like(d.ten_length).at[tendon_id_jnt].set(length_jnt)
10951105
)
10961106
ten_length = ten_length.at[tendon_id_site].add(length_site)
10971107
ten_length = ten_length.at[tendon_id_geom].add(length_geom)
@@ -1161,7 +1171,7 @@ def _distance(p0, p1):
11611171
).reshape((m.nwrap, 2))
11621172

11631173
return d.tree_replace({
1164-
'_impl.ten_length': ten_length,
1174+
'ten_length': ten_length,
11651175
'_impl.ten_J': ten_moment,
11661176
'_impl.ten_wrapadr': jp.array(ten_wrapadr, dtype=int),
11671177
'_impl.ten_wrapnum': jp.array(ten_wrapnum, dtype=int),
@@ -1263,7 +1273,7 @@ def fn(
12631273
wrench = jp.concatenate((frame_xmat @ gear[:3], frame_xmat @ gear[3:]))
12641274
moment = jac @ wrench
12651275
elif trntype == TrnType.TENDON:
1266-
length = d._impl.ten_length[trnid[0]] * gear[:1]
1276+
length = d.ten_length[trnid[0]] * gear[:1]
12671277
moment = d._impl.ten_J[trnid[0]] * gear[0]
12681278
else:
12691279
raise RuntimeError(f'unrecognized trntype: {TrnType(trntype)}')

mjx/mujoco/mjx/_src/smooth_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def test_smooth(self):
122122
# tendon
123123
dx = jax.jit(mjx.tendon)(mx, mjx.put_data(m, d))
124124
_assert_attr_eq(d, dx._impl, 'ten_J')
125-
_assert_attr_eq(d, dx._impl, 'ten_length')
125+
_assert_attr_eq(d, dx, 'ten_length')
126126
# transmission
127127
dx = jax.jit(mjx.transmission)(mx, dx)
128128
_assert_attr_eq(d, dx._impl, 'actuator_length')
@@ -394,7 +394,7 @@ def test_tendon(self, filename):
394394
mujoco.mj_forward(m, d)
395395
dx = jax.jit(mjx.forward)(mx, dx)
396396

397-
_assert_eq(d.ten_length, dx._impl.ten_length, 'ten_length')
397+
_assert_eq(d.ten_length, dx.ten_length, 'ten_length')
398398
_assert_eq(d.ten_J, dx._impl.ten_J, 'ten_J')
399399
_assert_eq(d.ten_wrapnum, dx._impl.ten_wrapnum, 'ten_wrapnum')
400400
_assert_eq(d.ten_wrapadr, dx._impl.ten_wrapadr, 'ten_wrapadr')

mjx/mujoco/mjx/_src/types.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -998,7 +998,6 @@ class DataC(PyTreeNode):
998998
ten_J_rowadr: jax.Array # pylint:disable=invalid-name
999999
ten_J_colind: jax.Array # pylint:disable=invalid-name
10001000
ten_J: jax.Array # pylint:disable=invalid-name
1001-
ten_length: jax.Array
10021001
wrap_obj: jax.Array
10031002
wrap_xpos: jax.Array
10041003
actuator_length: jax.Array
@@ -1070,7 +1069,6 @@ class DataJAX(PyTreeNode):
10701069
ten_wrapadr: jax.Array
10711070
ten_wrapnum: jax.Array
10721071
ten_J: jax.Array # pylint:disable=invalid-name
1073-
ten_length: jax.Array
10741072
wrap_obj: jax.Array
10751073
wrap_xpos: jax.Array
10761074
actuator_length: jax.Array
@@ -1133,6 +1131,7 @@ class Data(PyTreeNode):
11331131
ximat: jax.Array
11341132
xanchor: jax.Array
11351133
xaxis: jax.Array
1134+
ten_length: jax.Array
11361135
geom_xpos: jax.Array
11371136
geom_xmat: jax.Array
11381137
site_xpos: jax.Array

mjx/mujoco/mjx/warp/forward.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data):
11131113
'ten_Jdot': d._impl.ten_Jdot.shape,
11141114
'ten_actfrc': d._impl.ten_actfrc.shape,
11151115
'ten_bias_coef': d._impl.ten_bias_coef.shape,
1116-
'ten_length': d._impl.ten_length.shape,
1116+
'ten_length': d.ten_length.shape,
11171117
'ten_velocity': d._impl.ten_velocity.shape,
11181118
'ten_wrapadr': d._impl.ten_wrapadr.shape,
11191119
'ten_wrapnum': d._impl.ten_wrapnum.shape,
@@ -1777,7 +1777,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data):
17771777
d._impl.ten_Jdot,
17781778
d._impl.ten_actfrc,
17791779
d._impl.ten_bias_coef,
1780-
d._impl.ten_length,
1780+
d.ten_length,
17811781
d._impl.ten_velocity,
17821782
d._impl.ten_wrapadr,
17831783
d._impl.ten_wrapnum,
@@ -1956,7 +1956,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data):
19561956
'_impl.ten_Jdot': out[99],
19571957
'_impl.ten_actfrc': out[100],
19581958
'_impl.ten_bias_coef': out[101],
1959-
'_impl.ten_length': out[102],
1959+
'ten_length': out[102],
19601960
'_impl.ten_velocity': out[103],
19611961
'_impl.ten_wrapadr': out[104],
19621962
'_impl.ten_wrapnum': out[105],
@@ -3176,7 +3176,7 @@ def _step_jax_impl(m: types.Model, d: types.Data):
31763176
'ten_Jdot': d._impl.ten_Jdot.shape,
31773177
'ten_actfrc': d._impl.ten_actfrc.shape,
31783178
'ten_bias_coef': d._impl.ten_bias_coef.shape,
3179-
'ten_length': d._impl.ten_length.shape,
3179+
'ten_length': d.ten_length.shape,
31803180
'ten_velocity': d._impl.ten_velocity.shape,
31813181
'ten_wrapadr': d._impl.ten_wrapadr.shape,
31823182
'ten_wrapnum': d._impl.ten_wrapnum.shape,
@@ -3866,7 +3866,7 @@ def _step_jax_impl(m: types.Model, d: types.Data):
38663866
d._impl.ten_Jdot,
38673867
d._impl.ten_actfrc,
38683868
d._impl.ten_bias_coef,
3869-
d._impl.ten_length,
3869+
d.ten_length,
38703870
d._impl.ten_velocity,
38713871
d._impl.ten_wrapadr,
38723872
d._impl.ten_wrapnum,
@@ -4057,7 +4057,7 @@ def _step_jax_impl(m: types.Model, d: types.Data):
40574057
'_impl.ten_Jdot': out[111],
40584058
'_impl.ten_actfrc': out[112],
40594059
'_impl.ten_bias_coef': out[113],
4060-
'_impl.ten_length': out[114],
4060+
'ten_length': out[114],
40614061
'_impl.ten_velocity': out[115],
40624062
'_impl.ten_wrapadr': out[116],
40634063
'_impl.ten_wrapnum': out[117],

0 commit comments

Comments
 (0)