Skip to content

Commit 2ce5caa

Browse files
thowellcopybara-github
authored andcommitted
Update MJX rne_postconstraint with contributions from equality connect and equality weld constraints.
PiperOrigin-RevId: 777491801 Change-Id: Ib32dda4178def7a96d31e91787f3d0866a56cd76
1 parent 57b3d6d commit 2ce5caa

File tree

4 files changed

+208
-58
lines changed

4 files changed

+208
-58
lines changed

mjx/mujoco/mjx/_src/io.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -248,23 +248,6 @@ def _put_model_jax(
248248
if t == mujoco.mjtGeom.mjGEOM_MESH:
249249
mesh_geomid.add(g)
250250

251-
# check for unsupported sensor and equality constraint combinations
252-
sensor_rne_postconstraint = (
253-
np.any(m.sensor_type == types.SensorType.ACCELEROMETER)
254-
| np.any(m.sensor_type == types.SensorType.FORCE)
255-
| np.any(m.sensor_type == types.SensorType.TORQUE)
256-
| np.any(m.sensor_type == types.SensorType.FRAMELINACC)
257-
| np.any(m.sensor_type == types.SensorType.FRAMEANGACC)
258-
)
259-
eq_connect_weld = np.any(m.eq_type == types.EqType.CONNECT) | np.any(
260-
m.eq_type == types.EqType.WELD
261-
)
262-
if sensor_rne_postconstraint and eq_connect_weld:
263-
raise NotImplementedError(
264-
'rne_postconstraint not implemented with equality constraints:'
265-
' connect, weld.'
266-
)
267-
268251
for enum_field, enum_type, mj_type in (
269252
(m.actuator_biastype, types.BiasType, mujoco.mjtBias),
270253
(m.actuator_dyntype, types.DynType, mujoco.mjtDyn),

mjx/mujoco/mjx/_src/io_test.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -625,38 +625,6 @@ def test_contact_elliptic_condim1(self):
625625
with self.assertRaises(NotImplementedError):
626626
mjx.make_data(m)
627627

628-
@parameterized.product(
629-
sensor=['accelerometer', 'force', 'torque'], equality=['connect', 'weld']
630-
)
631-
def test_sensor_constraint_compatibility(self, sensor, equality):
632-
"""Test unsupported sensor and equality constraint combinations."""
633-
equality_constraint = f'{equality} body1="body1" body2="body2"'
634-
if equality == 'connect':
635-
equality_constraint += ' anchor="0 0 0"'
636-
m = mujoco.MjModel.from_xml_string(f"""
637-
<mujoco>
638-
<worldbody>
639-
<body name="body1">
640-
<freejoint/>
641-
<geom size="0.1"/>
642-
<site name="site1"/>
643-
</body>
644-
<body name="body2">
645-
<freejoint/>
646-
<geom size="0.1"/>
647-
</body>
648-
</worldbody>
649-
<equality>
650-
<{equality_constraint}/>
651-
</equality>
652-
<sensor>
653-
<{sensor} site="site1"/>
654-
</sensor>
655-
</mujoco>
656-
""")
657-
with self.assertRaises(NotImplementedError):
658-
mjx.put_model(m, impl='jax')
659-
660628
@parameterized.parameters(JacobianType.DENSE, JacobianType.SPARSE)
661629
def test_qm_mapm2m(self, jacobian):
662630
"""Test that qM is mapped to M."""

mjx/mujoco/mjx/_src/smooth.py

Lines changed: 121 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from mujoco.mjx._src.types import JointType
3030
from mujoco.mjx._src.types import Model
3131
from mujoco.mjx._src.types import ModelJAX
32+
from mujoco.mjx._src.types import ObjType
3233
from mujoco.mjx._src.types import TrnType
3334
from mujoco.mjx._src.types import WrapType
3435
# pylint: enable=g-importing-member
@@ -659,11 +660,126 @@ def _contact_force_to_cfrc_ext(force, pos, frame, id1, id2, com1, com2):
659660
cfrc_contact.reshape((-1, 6))
660661
)
661662

662-
# TODO(taylorhowell): connect and weld constraints
663-
if np.any(m.eq_type == EqType.CONNECT):
664-
raise NotImplementedError('Connect constraints are not implemented.')
665-
if np.any(m.eq_type == EqType.WELD):
666-
raise NotImplementedError('Weld constraints are not implemented.')
663+
# cfrc_ext += connect, weld
664+
cfrc_ext_equality = []
665+
cfrc_ext_equality_adr = []
666+
667+
connect_id = m.eq_type == EqType.CONNECT
668+
nconnect = connect_id.sum()
669+
670+
if nconnect:
671+
cfrc_connect_force = d._impl.efc_force[: 3 * nconnect].reshape(
672+
(nconnect, 3)
673+
)
674+
675+
is_site = m.eq_objtype == ObjType.SITE
676+
body1id = np.copy(m.eq_obj1id)
677+
body2id = np.copy(m.eq_obj2id)
678+
pos1 = m.eq_data[:, :3]
679+
pos2 = m.eq_data[:, 3:6]
680+
681+
if m.nsite:
682+
body1id[is_site] = m.site_bodyid[m.eq_obj1id[is_site]]
683+
body2id[is_site] = m.site_bodyid[m.eq_obj2id[is_site]]
684+
pos1 = jp.where(is_site[:, None], m.site_pos[m.eq_obj1id], pos1)
685+
pos2 = jp.where(is_site[:, None], m.site_pos[m.eq_obj2id], pos2)
686+
687+
# body 1
688+
k1_connect = body1id[connect_id]
689+
k1_connect_mask = k1_connect != 0
690+
offset1_connect = pos1[connect_id]
691+
692+
pos1_connect = jax.vmap(lambda pnt, mat, vec: mat @ pnt + vec)(
693+
offset1_connect, d.xmat[k1_connect], d.xpos[k1_connect]
694+
)
695+
subtree_com1_connect = d.subtree_com[jp.array(m.body_rootid)[k1_connect]]
696+
cfrc_com1_connect = jax.vmap(
697+
lambda dif, frc, mask: mask * jp.concatenate([-jp.cross(dif, frc), frc])
698+
)(subtree_com1_connect - pos1_connect, cfrc_connect_force, k1_connect_mask)
699+
700+
# body 2
701+
k2_connect = body2id[connect_id]
702+
k2_connect_mask = -1 * (k2_connect != 0)
703+
offset2_connect = pos2[connect_id]
704+
705+
pos2_connect = jax.vmap(lambda pnt, mat, vec: mat @ pnt + vec)(
706+
offset2_connect, d.xmat[k2_connect], d.xpos[k2_connect]
707+
)
708+
subtree_com2_connect = d.subtree_com[jp.array(m.body_rootid)[k2_connect]]
709+
cfrc_com2_connect = jax.vmap(
710+
lambda dif, frc, mask: mask * jp.concatenate([-jp.cross(dif, frc), frc])
711+
)(subtree_com2_connect - pos2_connect, cfrc_connect_force, k2_connect_mask)
712+
713+
cfrc_ext_equality.append(jp.vstack([cfrc_com1_connect, cfrc_com2_connect]))
714+
cfrc_ext_equality_adr.append(jp.concatenate([k1_connect, k2_connect]))
715+
716+
weld_id = m.eq_type == EqType.WELD
717+
nweld = weld_id.sum()
718+
719+
if nweld:
720+
cfrc_weld = d._impl.efc_force[
721+
3 * nconnect : 3 * nconnect + 6 * nweld
722+
].reshape((nweld, 6))
723+
cfrc_weld_force = cfrc_weld[:, :3]
724+
cfrc_weld_torque = cfrc_weld[:, 3:]
725+
726+
is_site = m.eq_objtype == ObjType.SITE
727+
body1id = np.copy(m.eq_obj1id)
728+
body2id = np.copy(m.eq_obj2id)
729+
pos1 = m.eq_data[:, 3:6]
730+
pos2 = m.eq_data[:, :3]
731+
732+
if m.nsite:
733+
body1id[is_site] = m.site_bodyid[m.eq_obj1id[is_site]]
734+
body2id[is_site] = m.site_bodyid[m.eq_obj2id[is_site]]
735+
pos1 = jp.where(is_site[:, None], m.site_pos[m.eq_obj1id], pos1)
736+
pos2 = jp.where(is_site[:, None], m.site_pos[m.eq_obj2id], pos2)
737+
738+
# body 1
739+
k1_weld = body1id[weld_id]
740+
k1_weld_mask = k1_weld != 0
741+
offset1_weld = pos1[weld_id]
742+
743+
pos1_weld = jax.vmap(lambda pnt, mat, vec: mat @ pnt + vec)(
744+
offset1_weld, d.xmat[k1_weld], d.xpos[k1_weld]
745+
)
746+
subtree_com1_weld = d.subtree_com[jp.array(m.body_rootid)[k1_weld]]
747+
cfrc_com1_weld = jax.vmap(
748+
lambda dif, frc, trq, mask: mask
749+
* jp.concatenate([trq - jp.cross(dif, frc), frc])
750+
)(
751+
subtree_com1_weld - pos1_weld,
752+
cfrc_weld_force,
753+
cfrc_weld_torque,
754+
k1_weld_mask,
755+
)
756+
757+
# body 2
758+
k2_weld = body2id[weld_id]
759+
k2_weld_mask = -1 * (k2_weld != 0)
760+
offset2_weld = pos2[weld_id]
761+
762+
pos2_weld = jax.vmap(lambda pnt, mat, vec: mat @ pnt + vec)(
763+
offset2_weld, d.xmat[k2_weld], d.xpos[k2_weld]
764+
)
765+
subtree_com2_weld = d.subtree_com[jp.array(m.body_rootid)[k2_weld]]
766+
cfrc_com2_weld = jax.vmap(
767+
lambda dif, frc, trq, mask: mask
768+
* jp.concatenate([trq - jp.cross(dif, frc), frc])
769+
)(
770+
subtree_com2_weld - pos2_weld,
771+
cfrc_weld_force,
772+
cfrc_weld_torque,
773+
k2_weld_mask,
774+
)
775+
776+
cfrc_ext_equality.append(jp.vstack([cfrc_com1_weld, cfrc_com2_weld]))
777+
cfrc_ext_equality_adr.append(jp.concatenate([k1_weld, k2_weld]))
778+
779+
if nconnect or nweld:
780+
cfrc_ext = cfrc_ext.at[jp.concatenate(cfrc_ext_equality_adr)].add(
781+
jp.vstack(cfrc_ext_equality)
782+
)
667783

668784
# forward pass over bodies: compute cacc, cfrc_int
669785
def _forward(carry, cfrc_ext, cinert, cvel, body_dofadr, body_dofnum):

mjx/mujoco/mjx/_src/smooth_test.py

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,21 +230,98 @@ def test_subtree_vel(self):
230230

231231

232232
class RnePostConstraintTest(parameterized.TestCase):
233+
_CONNECT_SITE = """
234+
<equality>
235+
<connect site1="site1" site2="site2"/>
236+
</equality>
237+
"""
238+
_CONNECT_BODY = """
239+
<equality>
240+
<connect body1="body1" body2="body2" anchor="1 2 3"/>
241+
</equality>
242+
"""
243+
_WELD_SITE = """
244+
<equality>
245+
<weld site1="site1" site2="site2"/>
246+
</equality>
247+
"""
248+
_WELD_BODY = """
249+
<equality>
250+
<weld body1="body1" body2="body2"/>
251+
</equality>
252+
"""
253+
_CONNECT_SITE_WELD_SITE = """
254+
<equality>
255+
<connect site1="site1" site2="site2"/>
256+
<weld site1="site1" site2="site2"/>
257+
</equality>
258+
"""
259+
_WELD_SITE_CONNECT_SITE = """
260+
<equality>
261+
<weld site1="site1" site2="site2"/>
262+
<connect site1="site1" site2="site2"/>
263+
</equality>
264+
"""
265+
_WELD_SITE_CONNECT_SITE_WELD_BODY = """
266+
<equality>
267+
<weld site1="site1" site2="site2"/>
268+
<connect site1="site1" site2="site2"/>
269+
<weld body1="body1" body2="body2"/>
270+
</equality>
271+
"""
272+
_CONNECT_SITE_WELD_SITE_WELD_BODY = """
273+
<equality>
274+
<connect site1="site1" site2="site2"/>
275+
<weld site1="site1" site2="site2"/>
276+
<weld body1="body1" body2="body2"/>
277+
</equality>
278+
"""
279+
_CONNECT_SITE_CONNECT_BODY_CONNECT_WELD = """
280+
<equality>
281+
<connect site1="site1" site2="site2"/>
282+
<connect body1="body1" body2="body2" anchor="1 2 3"/>
283+
<weld body1="body1" body2="body2"/>
284+
</equality>
285+
"""
233286

234-
@parameterized.parameters(ConeType)
235-
def test_rnepostconstraint(self, cone_type):
287+
@parameterized.parameters(
288+
('', ConeType.PYRAMIDAL, None),
289+
('', ConeType.ELLIPTIC, None),
290+
(_CONNECT_SITE, ConeType.PYRAMIDAL, None),
291+
(_CONNECT_BODY, ConeType.PYRAMIDAL, None),
292+
(_WELD_SITE, ConeType.PYRAMIDAL, None),
293+
(_WELD_BODY, ConeType.PYRAMIDAL, None),
294+
(_CONNECT_SITE_WELD_SITE, ConeType.PYRAMIDAL, None),
295+
(
296+
_WELD_SITE_CONNECT_SITE,
297+
ConeType.PYRAMIDAL,
298+
np.array([6, 7, 8, 0, 1, 2, 3, 4, 5]),
299+
),
300+
(
301+
_WELD_SITE_CONNECT_SITE_WELD_BODY,
302+
ConeType.PYRAMIDAL,
303+
np.array([6, 7, 8, 0, 1, 2, 3, 4, 5]),
304+
),
305+
(_CONNECT_SITE_WELD_SITE_WELD_BODY, ConeType.PYRAMIDAL, None),
306+
(_CONNECT_SITE_CONNECT_BODY_CONNECT_WELD, ConeType.PYRAMIDAL, None),
307+
)
308+
def test_rnepostconstraint(self, equality, cone_type, efc_map):
236309
"""Tests MJX rne_postconstraint function to match MuJoCo mj_rnePostConstraint."""
237310

238-
m = mujoco.MjModel.from_xml_string("""
311+
m = mujoco.MjModel.from_xml_string(f"""
239312
<mujoco>
240313
<worldbody>
241314
<geom name="floor" size="10 10 .05" type="plane"/>
242-
<body pos="0 0 1">
315+
<site name="site1"/>
316+
<body name="body1">
317+
</body>
318+
<body pos="0 0 1" name="body2">
243319
<joint type="ball" damping="1"/>
244320
<geom type="capsule" size="0.1 0.5" fromto="0 0 0 0.5 0 0" condim="1"/>
245321
<body pos="0.5 0 0">
246322
<joint type="ball" damping="1"/>
247323
<geom type="capsule" size="0.1 0.5" fromto="0 0 0 0.5 0 0" condim="3"/>
324+
<site name="site2"/>
248325
</body>
249326
</body>
250327
<body pos="0 1 1">
@@ -256,6 +333,7 @@ def test_rnepostconstraint(self, cone_type):
256333
</body>
257334
</body>
258335
</worldbody>
336+
{equality}
259337
<keyframe>
260338
<key qpos='0.424577 0.450592 0.451703 -0.642391 0.729379 0.545151 0.407756 0.0674697 0.424577 1.450592 0.451703 -0.642391 0.729379 0.545151 0.407756 0.0674697'/>
261339
</keyframe>
@@ -273,6 +351,11 @@ def test_rnepostconstraint(self, cone_type):
273351
mx = mjx.put_model(m)
274352
dx = mjx.put_data(m, d)
275353

354+
if efc_map is not None:
355+
efc_force = d.efc_force.copy()
356+
efc_force[: len(efc_map)] = d.efc_force[efc_map]
357+
dx = dx.tree_replace({'_impl.efc_force': jp.array(efc_force)})
358+
276359
# rne postconstraint
277360
mujoco.mj_rnePostConstraint(m, d)
278361
dx = jax.jit(mjx.rne_postconstraint)(mx, dx)

0 commit comments

Comments
 (0)