@@ -183,11 +183,8 @@ def wrapper(*args):
183
183
return wrapper
184
184
185
185
186
- def _get_mapping_from_tree_path (
187
- path : jax .tree_util .KeyPath ,
188
- mapping : dict [str , int ],
189
- ) -> Optional [int ]:
190
- """Gets the mapped value from a tree path."""
186
+ def _tree_path_to_attr_str (path : jax .tree_util .KeyPath ) -> str :
187
+ """Converts a tree path to a dataclass attribute string."""
191
188
if not isinstance (path , tuple ):
192
189
raise NotImplementedError (
193
190
f'Parsing for jax tree path { path } not implemented.'
@@ -200,8 +197,15 @@ def _get_mapping_from_tree_path(
200
197
201
198
assert all (isinstance (p , jax .tree_util .GetAttrKey ) for p in path )
202
199
path = [p for p in path if p .name != '_impl' ]
203
- attr = '__' .join (p .name for p in path )
200
+ return '__' .join (p .name for p in path )
201
+
204
202
203
+ def _get_mapping_from_tree_path (
204
+ path : jax .tree_util .KeyPath ,
205
+ mapping : dict [str , int ],
206
+ ) -> Optional [int ]:
207
+ """Gets the mapped value from a tree path."""
208
+ attr = _tree_path_to_attr_str (path )
205
209
# None if the MJX public field is not present in the MJX-Warp mapping.
206
210
return mapping .get (attr )
207
211
@@ -302,6 +306,43 @@ def _maybe_broadcast_to(
302
306
return leaf
303
307
304
308
309
+ def _check_leading_dim (
310
+ path : jax .tree_util .KeyPath ,
311
+ leaf : Any ,
312
+ expected_batch_dim : int ,
313
+ expected_nconmax : int ,
314
+ expected_njmax : int ,
315
+ ):
316
+ """Asserts that the batch dimension of a leaf node matches the expected batch dimension."""
317
+ has_batch_dim = _get_mapping_from_tree_path (
318
+ path , mjx_warp_types ._BATCH_DIM ['Data' ]
319
+ )
320
+ attr = _tree_path_to_attr_str (path )
321
+ if has_batch_dim and leaf .shape [0 ] != expected_batch_dim :
322
+ raise ValueError (
323
+ f'Leaf node batch size ({ leaf .shape [0 ]} ) and expected batch size'
324
+ f' ({ expected_batch_dim } ) do not match for field { attr } .'
325
+ )
326
+ if (
327
+ not has_batch_dim
328
+ and attr .startswith ('contact__' )
329
+ and leaf .shape [0 ] != expected_nconmax
330
+ ):
331
+ raise ValueError (
332
+ f'Leaf node leading dim ({ leaf .shape [0 ]} ) does not match nconmax'
333
+ f' ({ expected_nconmax } ) for field { attr } .'
334
+ )
335
+ if (
336
+ not has_batch_dim
337
+ and attr .startswith ('efc__' )
338
+ and leaf .shape [0 ] != expected_njmax
339
+ ):
340
+ raise ValueError (
341
+ f'Leaf node leading dim ({ leaf .shape [0 ]} ) does not match njmax'
342
+ f' ({ expected_njmax } ) for field { attr } .'
343
+ )
344
+
345
+
305
346
def marshal_custom_vmap (vmap_func ):
306
347
"""Marshal fields for a custom vmap into an MuJoCo Warp function."""
307
348
@@ -316,6 +357,13 @@ def wrapper(axis_size, is_batched, m, d):
316
357
),
317
358
d , is_batched [1 ], # fmt: skip
318
359
)
360
+ # Check leading dimensions.
361
+ jax .tree .map_with_path (
362
+ lambda path , x : _check_leading_dim (
363
+ path , x , d_broadcast .qpos .shape [0 ], d ._impl .nconmax , d ._impl .njmax # pylint: disable=protected-access
364
+ ),
365
+ d_broadcast ,
366
+ )
319
367
# Flatten batch dims into the first axis if the vmap was nested.
320
368
m_flat = jax .tree .map_with_path (
321
369
lambda path , x : _flatten_batch_dim (
0 commit comments