Skip to content

Commit 60cb976

Browse files
btabacopybara-github
authored andcommitted
Add assertion for leading dims in MJX-Warp.
PiperOrigin-RevId: 808624659 Change-Id: I259f9cd753bdffbe465ae807ec957f123e8d589e
1 parent eb96dd5 commit 60cb976

File tree

2 files changed

+82
-6
lines changed

2 files changed

+82
-6
lines changed

mjx/mujoco/mjx/warp/ffi.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,8 @@ def wrapper(*args):
183183
return wrapper
184184

185185

186-
def _get_mapping_from_tree_path(
187-
path: jax.tree_util.KeyPath,
188-
mapping: dict[str, int],
189-
) -> Optional[int]:
190-
"""Gets the mapped value from a tree path."""
186+
def _tree_path_to_attr_str(path: jax.tree_util.KeyPath) -> str:
187+
"""Converts a tree path to a dataclass attribute string."""
191188
if not isinstance(path, tuple):
192189
raise NotImplementedError(
193190
f'Parsing for jax tree path {path} not implemented.'
@@ -200,8 +197,15 @@ def _get_mapping_from_tree_path(
200197

201198
assert all(isinstance(p, jax.tree_util.GetAttrKey) for p in path)
202199
path = [p for p in path if p.name != '_impl']
203-
attr = '__'.join(p.name for p in path)
200+
return '__'.join(p.name for p in path)
201+
204202

203+
def _get_mapping_from_tree_path(
204+
path: jax.tree_util.KeyPath,
205+
mapping: dict[str, int],
206+
) -> Optional[int]:
207+
"""Gets the mapped value from a tree path."""
208+
attr = _tree_path_to_attr_str(path)
205209
# None if the MJX public field is not present in the MJX-Warp mapping.
206210
return mapping.get(attr)
207211

@@ -302,6 +306,43 @@ def _maybe_broadcast_to(
302306
return leaf
303307

304308

309+
def _check_leading_dim(
310+
path: jax.tree_util.KeyPath,
311+
leaf: Any,
312+
expected_batch_dim: int,
313+
expected_nconmax: int,
314+
expected_njmax: int,
315+
):
316+
"""Asserts that the batch dimension of a leaf node matches the expected batch dimension."""
317+
has_batch_dim = _get_mapping_from_tree_path(
318+
path, mjx_warp_types._BATCH_DIM['Data']
319+
)
320+
attr = _tree_path_to_attr_str(path)
321+
if has_batch_dim and leaf.shape[0] != expected_batch_dim:
322+
raise ValueError(
323+
f'Leaf node batch size ({leaf.shape[0]}) and expected batch size'
324+
f' ({expected_batch_dim}) do not match for field {attr}.'
325+
)
326+
if (
327+
not has_batch_dim
328+
and attr.startswith('contact__')
329+
and leaf.shape[0] != expected_nconmax
330+
):
331+
raise ValueError(
332+
f'Leaf node leading dim ({leaf.shape[0]}) does not match nconmax'
333+
f' ({expected_nconmax}) for field {attr}.'
334+
)
335+
if (
336+
not has_batch_dim
337+
and attr.startswith('efc__')
338+
and leaf.shape[0] != expected_njmax
339+
):
340+
raise ValueError(
341+
f'Leaf node leading dim ({leaf.shape[0]}) does not match njmax'
342+
f' ({expected_njmax}) for field {attr}.'
343+
)
344+
345+
305346
def marshal_custom_vmap(vmap_func):
306347
"""Marshal fields for a custom vmap into an MuJoCo Warp function."""
307348

@@ -316,6 +357,13 @@ def wrapper(axis_size, is_batched, m, d):
316357
),
317358
d, is_batched[1], # fmt: skip
318359
)
360+
# Check leading dimensions.
361+
jax.tree.map_with_path(
362+
lambda path, x: _check_leading_dim(
363+
path, x, d_broadcast.qpos.shape[0], d._impl.nconmax, d._impl.njmax # pylint: disable=protected-access
364+
),
365+
d_broadcast,
366+
)
319367
# Flatten batch dims into the first axis if the vmap was nested.
320368
m_flat = jax.tree.map_with_path(
321369
lambda path, x: _flatten_batch_dim(

mjx/mujoco/mjx/warp/forward_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,34 @@ def test_step(self, xml: str, batch_size: int):
261261
tu.assert_attr_eq(dx, d, 'mocap_quat')
262262
tu.assert_attr_eq(dx, d, 'sensordata')
263263

264+
def test_step_leading_dim_mismatch(self):
265+
if not _FORCE_TEST:
266+
if not mjxw.WARP_INSTALLED:
267+
self.skipTest('Warp not installed.')
268+
if not io.has_cuda_gpu_device():
269+
self.skipTest('No CUDA GPU device available.')
270+
271+
xml = 'humanoid/humanoid.xml'
272+
batch_size = 7
273+
274+
m = test_util.load_test_file(xml)
275+
mx = mjx.put_model(m, impl='warp')
276+
277+
worldids = jp.arange(batch_size)
278+
dx_batch = jax.vmap(functools.partial(tu.make_data, m))(worldids)
279+
dx_batch_orig = dx_batch
280+
281+
with self.assertRaises(ValueError):
282+
dx_batch = dx_batch.replace(qpos=dx_batch.qpos[1:])
283+
_ = jax.jit(jax.vmap(forward.step, in_axes=(None, 0)))(mx, dx_batch)
284+
285+
dx_batch = dx_batch_orig
286+
with self.assertRaises(ValueError):
287+
dx_batch = dx_batch.tree_replace(
288+
{'_impl.contact__pos': dx_batch._impl.contact__pos[1:]}
289+
)
290+
_ = jax.jit(jax.vmap(forward.step, in_axes=(None, 0)))(mx, dx_batch)
291+
264292

265293
if __name__ == '__main__':
266294
absltest.main()

0 commit comments

Comments
 (0)