Skip to content

Commit 2c3d0be

Browse files
erikfreycopybara-github
authored andcommitted
Empty out Model fields restricted to MuJoCo in the same way we do for Data fields.
PiperOrigin-RevId: 719705508 Change-Id: I4b07433481eafaea1d95a5bdd8fd83786939188f
1 parent 2546fce commit 2c3d0be

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

mjx/mujoco/mjx/_src/io.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def f(leaf):
3939

4040

4141
def _make_option(
42-
o: mujoco.MjOption, _full_compat: bool = False
42+
o: mujoco.MjOption, _full_compat: bool = False # pylint: disable=invalid-name
4343
) -> types.Option:
4444
"""Returns mjx.Option given mujoco.MjOption."""
4545
if not _full_compat:
@@ -183,6 +183,15 @@ def put_model(
183183
if f.metadata.get('restricted_to') != 'mjx'
184184
}
185185
fields = {f: getattr(m, f) for f in mj_field_names}
186+
187+
# zero out fields restricted to MuJoCo
188+
if not _full_compat:
189+
for f in types.Model.fields():
190+
if f.metadata.get('restricted_to') == 'mujoco' and isinstance(
191+
fields[f.name], np.ndarray
192+
):
193+
fields[f.name] = np.zeros((0,), dtype=fields[f.name].dtype)
194+
186195
fields['dof_hasfrictionloss'] = fields['dof_frictionloss'] > 0
187196
fields['tendon_hasfrictionloss'] = fields['tendon_frictionloss'] > 0
188197
fields['geom_rbound_hfield'] = fields['geom_rbound']
@@ -522,7 +531,7 @@ def _make_contact(
522531
# if we have fewer Contacts for a condim range, pad the range with zeros
523532

524533
# build a map for where to find a dim-matching contact, or -1 if none
525-
contact_map = np.zeros_like(dim) - 1
534+
contact_map = -np.ones_like(dim)
526535
for i, di in enumerate(fields['dim']):
527536
space = [j for j, dj in enumerate(dim) if di == dj and contact_map[j] == -1]
528537
if not space:
@@ -672,7 +681,9 @@ def put_data(
672681
fields['_qLDiagInv_sparse'] = jp.zeros(0, dtype=float)
673682
# otherwise clear out unused arrays
674683
for f in types.Data.fields():
675-
if f.metadata.get('restricted_to') == 'mujoco':
684+
if f.metadata.get('restricted_to') == 'mujoco' and isinstance(
685+
fields[f.name], np.ndarray
686+
):
676687
fields[f.name] = np.zeros(0, dtype=fields[f.name].dtype)
677688

678689
fields['contact'] = contact

mjx/mujoco/mjx/_src/io_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
from jax import numpy as jp
2121
import mujoco
2222
from mujoco import mjx
23+
# pylint: disable=g-importing-member
2324
from mujoco.mjx._src.types import ConeType
25+
# pylint: enable=g-importing-member
2426
import numpy as np
2527

2628

@@ -117,6 +119,9 @@ def assert_not_weak_type(x):
117119
self.assertEqual(mx.nM, m.nM)
118120
self.assertAlmostEqual(mx.opt.timestep, m.opt.timestep)
119121

122+
# fields restricted to MuJoCo should not be populated
123+
self.assertEqual(mx.bvh_aabb.shape, (0,))
124+
120125
np.testing.assert_allclose(mx.body_parentid, m.body_parentid)
121126
np.testing.assert_allclose(mx.geom_type, m.geom_type)
122127
np.testing.assert_allclose(mx.geom_bodyid, m.geom_bodyid)

0 commit comments

Comments
 (0)