Skip to content

Commit d3ff929

Browse files
Google DeepMindcopybara-github
authored andcommitted
Add ten_length as MJX public field and warp tendon function call to MJX.
PiperOrigin-RevId: 794813523 Change-Id: Ic5684788eca5a3f4fd86193ec5d035d40b3aad35
1 parent d7e925b commit d3ff929

File tree

15 files changed

+28
-258
lines changed

15 files changed

+28
-258
lines changed

doc/changelog.rst

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,6 @@
22
Changelog
33
=========
44

5-
Upcoming version (not yet released)
6-
-----------------------------------
7-
8-
MJX
9-
^^^
10-
- Promote ``ten_length`` to the public MJX API. Add Warp support for ``mjx.tendon``.
11-
12-
.. admonition:: Breaking API changes
13-
:class: attention
14-
15-
- ``ten_length`` was moved from ``mjx.Data._impl.ten_length`` to a public field ``mjx.Data.ten_length``.
16-
17-
185
Version 3.3.5 (August 8, 2025)
196
-----------------------------------
207

mjx/mujoco/mjx/_src/constraint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,8 @@ def rows(
322322

323323
inv1, inv2 = m.tendon_invweight0[obj1id], m.tendon_invweight0[obj2id]
324324
jac1, jac2 = d._impl.ten_J[obj1id], d._impl.ten_J[obj2id]
325-
pos1 = d.ten_length[obj1id] - m.tendon_length0[obj1id]
326-
pos2 = d.ten_length[obj2id] - m.tendon_length0[obj2id]
325+
pos1 = d._impl.ten_length[obj1id] - m.tendon_length0[obj1id]
326+
pos2 = d._impl.ten_length[obj2id] - m.tendon_length0[obj2id]
327327
invweight = inv1 + inv2 * (obj2id > -1)
328328

329329
return rows(
@@ -436,7 +436,7 @@ def _efc_limit_tendon(m: Model, d: Data) -> Optional[_Efc]:
436436
length, j, range_, margin, invweight, solref, solimp = jax.tree_util.tree_map(
437437
lambda x: x[tendon_id],
438438
(
439-
d.ten_length,
439+
d._impl.ten_length,
440440
d._impl.ten_J,
441441
m.tendon_range,
442442
m.tendon_margin,

mjx/mujoco/mjx/_src/io.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,6 @@ def _make_data_public_fields(m: types.Model) -> Dict[str, Any]:
576576
'qfrc_constraint': (m.nv, float_),
577577
'qfrc_inverse': (m.nv, float_),
578578
'cvel': (m.nbody, 6, float_),
579-
'ten_length': (m.ntendon, float_),
580579
}
581580
zero_fields = {
582581
k: np.zeros(v[:-1], dtype=v[-1]) for k, v in zero_fields.items()
@@ -638,6 +637,7 @@ def _make_data_jax(
638637
'ten_wrapadr': (m.ntendon, np.int32),
639638
'ten_wrapnum': (m.ntendon, np.int32),
640639
'ten_J': (m.ntendon, m.nv, float_),
640+
'ten_length': (m.ntendon, float_),
641641
'wrap_obj': (m.nwrap, 2, np.int32),
642642
'wrap_xpos': (m.nwrap, 6, float_),
643643
'actuator_length': (m.nu, float_),
@@ -742,6 +742,7 @@ def get(m, name: str):
742742
'ten_J_rowadr': (m.ntendon, np.int32),
743743
'ten_J_colind': (m.ntendon, m.nv, np.int32),
744744
'ten_J': (m.ntendon, m.nv, float_),
745+
'ten_length': (m.ntendon, float_),
745746
'ten_wrapadr': (m.ntendon, np.int32),
746747
'ten_wrapnum': (m.ntendon, np.int32),
747748
'wrap_obj': (m.nwrap, 2, np.int32),

mjx/mujoco/mjx/_src/io_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def test_put_data(self, impl: str):
478478
np.testing.assert_allclose(dx.site_xmat.reshape((1, 9)), d.site_xmat)
479479

480480
# tendon data is correct
481-
np.testing.assert_allclose(dx.ten_length, d.ten_length)
481+
np.testing.assert_allclose(dx._impl.ten_length, d.ten_length)
482482
np.testing.assert_equal(dx._impl.ten_wrapadr, np.zeros((1,)))
483483
np.testing.assert_equal(dx._impl.ten_wrapnum, np.zeros((1,)))
484484
np.testing.assert_equal(dx._impl.wrap_obj, np.zeros((2, 2)))

mjx/mujoco/mjx/_src/passive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def fn(jnt_typs, stiffness, qpos_spring, qpos):
7575
qfrc -= m.dof_damping * d.qvel
7676

7777
# tendon-level spring-dampers
78-
below, above = m.tendon_lengthspring.T - d.ten_length
78+
below, above = m.tendon_lengthspring.T - d._impl.ten_length
7979
frc_spring = jp.where(below > 0, m.tendon_stiffness * below, 0)
8080
frc_spring = jp.where(above < 0, m.tendon_stiffness * above, frc_spring)
8181
frc_damper = -m.tendon_damping * d._impl.ten_velocity

mjx/mujoco/mjx/_src/sensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from mujoco.mjx._src import ray
2323
from mujoco.mjx._src import smooth
2424
from mujoco.mjx._src import support
25+
from mujoco.mjx._src.types import Impl
2526
from mujoco.mjx._src.types import Data
2627
from mujoco.mjx._src.types import DataJAX
2728
from mujoco.mjx._src.types import DisableBit
@@ -173,7 +174,7 @@ def _cam_project(
173174
elif sensor_type == SensorType.JOINTPOS:
174175
sensor = d.qpos[m.jnt_qposadr[objid]]
175176
elif sensor_type == SensorType.TENDONPOS:
176-
sensor = d.ten_length[objid]
177+
sensor = d._impl.ten_length[objid]
177178
elif sensor_type == SensorType.ACTUATORPOS:
178179
sensor = d._impl.actuator_length[objid]
179180
elif sensor_type == SensorType.BALLQUAT:

mjx/mujoco/mjx/_src/smooth.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,18 @@
2626
from mujoco.mjx._src.types import DataJAX
2727
from mujoco.mjx._src.types import DisableBit
2828
from mujoco.mjx._src.types import EqType
29-
from mujoco.mjx._src.types import Impl
3029
from mujoco.mjx._src.types import JointType
3130
from mujoco.mjx._src.types import Model
3231
from mujoco.mjx._src.types import ModelJAX
3332
from mujoco.mjx._src.types import ObjType
3433
from mujoco.mjx._src.types import TrnType
3534
from mujoco.mjx._src.types import WrapType
3635
# pylint: enable=g-importing-member
37-
import mujoco.mjx.warp as mjxw
3836
import numpy as np
3937

4038

4139
def kinematics(m: Model, d: Data) -> Data:
4240
"""Converts position/velocity from generalized coordinates to maximal."""
43-
if m.impl == Impl.WARP and d.impl == Impl.WARP and mjxw.WARP_INSTALLED:
44-
from mujoco.mjx.warp import smooth as mjxw_smooth # pylint: disable=g-import-not-at-top # pytype: disable=import-error
45-
return mjxw_smooth.kinematics(m, d)
46-
4741
def fn(carry, jnt_typs, jnt_pos, jnt_axis, qpos, qpos0, pos, quat):
4842
# calculate joint anchors, axes, body pos and quat in global frame
4943
# also normalize qpos while we're at it
@@ -850,10 +844,6 @@ def _forward(carry, cfrc_ext, cinert, cvel, body_dofadr, body_dofnum):
850844

851845
def tendon(m: Model, d: Data) -> Data:
852846
"""Computes tendon lengths and moments."""
853-
if m.impl == Impl.WARP and d.impl == Impl.WARP and mjxw.WARP_INSTALLED:
854-
from mujoco.mjx.warp import smooth as mjxw_smooth # pylint: disable=g-import-not-at-top # pytype: disable=import-error
855-
return mjxw_smooth.tendon(m, d)
856-
857847
if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
858848
raise ValueError('tendon requires JAX backend implementation.')
859849

@@ -1101,7 +1091,7 @@ def _distance(p0, p1):
11011091

11021092
# assemble length and moment
11031093
ten_length = (
1104-
jp.zeros_like(d.ten_length).at[tendon_id_jnt].set(length_jnt)
1094+
jp.zeros_like(d._impl.ten_length).at[tendon_id_jnt].set(length_jnt)
11051095
)
11061096
ten_length = ten_length.at[tendon_id_site].add(length_site)
11071097
ten_length = ten_length.at[tendon_id_geom].add(length_geom)
@@ -1171,7 +1161,7 @@ def _distance(p0, p1):
11711161
).reshape((m.nwrap, 2))
11721162

11731163
return d.tree_replace({
1174-
'ten_length': ten_length,
1164+
'_impl.ten_length': ten_length,
11751165
'_impl.ten_J': ten_moment,
11761166
'_impl.ten_wrapadr': jp.array(ten_wrapadr, dtype=int),
11771167
'_impl.ten_wrapnum': jp.array(ten_wrapnum, dtype=int),
@@ -1273,7 +1263,7 @@ def fn(
12731263
wrench = jp.concatenate((frame_xmat @ gear[:3], frame_xmat @ gear[3:]))
12741264
moment = jac @ wrench
12751265
elif trntype == TrnType.TENDON:
1276-
length = d.ten_length[trnid[0]] * gear[:1]
1266+
length = d._impl.ten_length[trnid[0]] * gear[:1]
12771267
moment = d._impl.ten_J[trnid[0]] * gear[0]
12781268
else:
12791269
raise RuntimeError(f'unrecognized trntype: {TrnType(trntype)}')

mjx/mujoco/mjx/_src/smooth_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def test_smooth(self):
122122
# tendon
123123
dx = jax.jit(mjx.tendon)(mx, mjx.put_data(m, d))
124124
_assert_attr_eq(d, dx._impl, 'ten_J')
125-
_assert_attr_eq(d, dx, 'ten_length')
125+
_assert_attr_eq(d, dx._impl, 'ten_length')
126126
# transmission
127127
dx = jax.jit(mjx.transmission)(mx, dx)
128128
_assert_attr_eq(d, dx._impl, 'actuator_length')
@@ -394,7 +394,7 @@ def test_tendon(self, filename):
394394
mujoco.mj_forward(m, d)
395395
dx = jax.jit(mjx.forward)(mx, dx)
396396

397-
_assert_eq(d.ten_length, dx.ten_length, 'ten_length')
397+
_assert_eq(d.ten_length, dx._impl.ten_length, 'ten_length')
398398
_assert_eq(d.ten_J, dx._impl.ten_J, 'ten_J')
399399
_assert_eq(d.ten_wrapnum, dx._impl.ten_wrapnum, 'ten_wrapnum')
400400
_assert_eq(d.ten_wrapadr, dx._impl.ten_wrapadr, 'ten_wrapadr')

mjx/mujoco/mjx/_src/types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,7 @@ class DataC(PyTreeNode):
998998
ten_J_rowadr: jax.Array # pylint:disable=invalid-name
999999
ten_J_colind: jax.Array # pylint:disable=invalid-name
10001000
ten_J: jax.Array # pylint:disable=invalid-name
1001+
ten_length: jax.Array
10011002
wrap_obj: jax.Array
10021003
wrap_xpos: jax.Array
10031004
actuator_length: jax.Array
@@ -1069,6 +1070,7 @@ class DataJAX(PyTreeNode):
10691070
ten_wrapadr: jax.Array
10701071
ten_wrapnum: jax.Array
10711072
ten_J: jax.Array # pylint:disable=invalid-name
1073+
ten_length: jax.Array
10721074
wrap_obj: jax.Array
10731075
wrap_xpos: jax.Array
10741076
actuator_length: jax.Array
@@ -1131,7 +1133,6 @@ class Data(PyTreeNode):
11311133
ximat: jax.Array
11321134
xanchor: jax.Array
11331135
xaxis: jax.Array
1134-
ten_length: jax.Array
11351136
geom_xpos: jax.Array
11361137
geom_xmat: jax.Array
11371138
site_xpos: jax.Array

mjx/mujoco/mjx/warp/forward.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data):
11131113
'ten_Jdot': d._impl.ten_Jdot.shape,
11141114
'ten_actfrc': d._impl.ten_actfrc.shape,
11151115
'ten_bias_coef': d._impl.ten_bias_coef.shape,
1116-
'ten_length': d.ten_length.shape,
1116+
'ten_length': d._impl.ten_length.shape,
11171117
'ten_velocity': d._impl.ten_velocity.shape,
11181118
'ten_wrapadr': d._impl.ten_wrapadr.shape,
11191119
'ten_wrapnum': d._impl.ten_wrapnum.shape,
@@ -1777,7 +1777,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data):
17771777
d._impl.ten_Jdot,
17781778
d._impl.ten_actfrc,
17791779
d._impl.ten_bias_coef,
1780-
d.ten_length,
1780+
d._impl.ten_length,
17811781
d._impl.ten_velocity,
17821782
d._impl.ten_wrapadr,
17831783
d._impl.ten_wrapnum,
@@ -1956,7 +1956,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data):
19561956
'_impl.ten_Jdot': out[99],
19571957
'_impl.ten_actfrc': out[100],
19581958
'_impl.ten_bias_coef': out[101],
1959-
'ten_length': out[102],
1959+
'_impl.ten_length': out[102],
19601960
'_impl.ten_velocity': out[103],
19611961
'_impl.ten_wrapadr': out[104],
19621962
'_impl.ten_wrapnum': out[105],
@@ -3176,7 +3176,7 @@ def _step_jax_impl(m: types.Model, d: types.Data):
31763176
'ten_Jdot': d._impl.ten_Jdot.shape,
31773177
'ten_actfrc': d._impl.ten_actfrc.shape,
31783178
'ten_bias_coef': d._impl.ten_bias_coef.shape,
3179-
'ten_length': d.ten_length.shape,
3179+
'ten_length': d._impl.ten_length.shape,
31803180
'ten_velocity': d._impl.ten_velocity.shape,
31813181
'ten_wrapadr': d._impl.ten_wrapadr.shape,
31823182
'ten_wrapnum': d._impl.ten_wrapnum.shape,
@@ -3866,7 +3866,7 @@ def _step_jax_impl(m: types.Model, d: types.Data):
38663866
d._impl.ten_Jdot,
38673867
d._impl.ten_actfrc,
38683868
d._impl.ten_bias_coef,
3869-
d.ten_length,
3869+
d._impl.ten_length,
38703870
d._impl.ten_velocity,
38713871
d._impl.ten_wrapadr,
38723872
d._impl.ten_wrapnum,
@@ -4057,7 +4057,7 @@ def _step_jax_impl(m: types.Model, d: types.Data):
40574057
'_impl.ten_Jdot': out[111],
40584058
'_impl.ten_actfrc': out[112],
40594059
'_impl.ten_bias_coef': out[113],
4060-
'ten_length': out[114],
4060+
'_impl.ten_length': out[114],
40614061
'_impl.ten_velocity': out[115],
40624062
'_impl.ten_wrapadr': out[116],
40634063
'_impl.ten_wrapnum': out[117],

0 commit comments

Comments
 (0)