Skip to content

Commit 3c107b3

Browse files
btabacopybara-github
authored andcommitted
Add stage_in and stage_out for MJX-Warp.
PiperOrigin-RevId: 868180029 Change-Id: I0d1b17b6023d32498766b5d1d9e5624bd31cda3d
1 parent 7ef3ae6 commit 3c107b3

File tree

7 files changed

+730
-583
lines changed

7 files changed

+730
-583
lines changed

mjx/mujoco/mjx/warp/collision_driver.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,6 @@ def _collision_jax_impl(m: types.Model, d: types.Data):
233233
'collision_pair': d._impl.collision_pair.shape,
234234
'collision_pairid': d._impl.collision_pairid.shape,
235235
'collision_worldid': d._impl.collision_worldid.shape,
236-
'geom_xmat': d.geom_xmat.shape,
237-
'geom_xpos': d.geom_xpos.shape,
238236
'nacon': d._impl.nacon.shape,
239237
'ncollision': d._impl.ncollision.shape,
240238
'contact__dim': d._impl.contact__dim.shape,
@@ -253,15 +251,13 @@ def _collision_jax_impl(m: types.Model, d: types.Data):
253251
}
254252
jf = ffi.jax_callable_variadic_tuple(
255253
_collision_shim,
256-
num_outputs=20,
254+
num_outputs=18,
257255
output_dims=output_dims,
258256
vmap_method=None,
259257
in_out_argnames={
260258
'collision_pair',
261259
'collision_pairid',
262260
'collision_worldid',
263-
'geom_xmat',
264-
'geom_xpos',
265261
'nacon',
266262
'ncollision',
267263
'contact__dim',
@@ -278,6 +274,27 @@ def _collision_jax_impl(m: types.Model, d: types.Data):
278274
'contact__type',
279275
'contact__worldid',
280276
},
277+
stage_in_argnames={
278+
'geom_aabb',
279+
'geom_friction',
280+
'geom_gap',
281+
'geom_margin',
282+
'geom_rbound',
283+
'geom_size',
284+
'geom_solimp',
285+
'geom_solmix',
286+
'geom_solref',
287+
'geom_xmat',
288+
'geom_xpos',
289+
'hfield_data',
290+
'pair_friction',
291+
'pair_gap',
292+
'pair_margin',
293+
'pair_solimp',
294+
'pair_solref',
295+
'pair_solreffriction',
296+
},
297+
stage_out_argnames={},
281298
graph_mode=m.opt._impl.graph_mode,
282299
)
283300
out = jf(
@@ -373,23 +390,21 @@ def _collision_jax_impl(m: types.Model, d: types.Data):
373390
'_impl.collision_pair': out[0],
374391
'_impl.collision_pairid': out[1],
375392
'_impl.collision_worldid': out[2],
376-
'geom_xmat': out[3],
377-
'geom_xpos': out[4],
378-
'_impl.nacon': out[5],
379-
'_impl.ncollision': out[6],
380-
'_impl.contact__dim': out[7],
381-
'_impl.contact__dist': out[8],
382-
'_impl.contact__frame': out[9],
383-
'_impl.contact__friction': out[10],
384-
'_impl.contact__geom': out[11],
385-
'_impl.contact__geomcollisionid': out[12],
386-
'_impl.contact__includemargin': out[13],
387-
'_impl.contact__pos': out[14],
388-
'_impl.contact__solimp': out[15],
389-
'_impl.contact__solref': out[16],
390-
'_impl.contact__solreffriction': out[17],
391-
'_impl.contact__type': out[18],
392-
'_impl.contact__worldid': out[19],
393+
'_impl.nacon': out[3],
394+
'_impl.ncollision': out[4],
395+
'_impl.contact__dim': out[5],
396+
'_impl.contact__dist': out[6],
397+
'_impl.contact__frame': out[7],
398+
'_impl.contact__friction': out[8],
399+
'_impl.contact__geom': out[9],
400+
'_impl.contact__geomcollisionid': out[10],
401+
'_impl.contact__includemargin': out[11],
402+
'_impl.contact__pos': out[12],
403+
'_impl.contact__solimp': out[13],
404+
'_impl.contact__solref': out[14],
405+
'_impl.contact__solreffriction': out[15],
406+
'_impl.contact__type': out[16],
407+
'_impl.contact__worldid': out[17],
393408
})
394409
return d
395410

mjx/mujoco/mjx/warp/ffi.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ def jax_callable_variadic_tuple(
101101
vmap_method: Optional[str] = None,
102102
output_dims: Optional[dict[str, tuple[int, ...]]] = None,
103103
in_out_argnames: Optional[Sequence[str]] = None,
104+
stage_in_argnames: Optional[Sequence[str]] = None,
105+
stage_out_argnames: Optional[Sequence[str]] = None,
104106
):
105107
"""Wraps a JAX callable to support variadic tuples and dataclasses."""
106108

@@ -130,6 +132,8 @@ def func_wrapper(*flat_args, **kwargs):
130132
vmap_method=vmap_method,
131133
output_dims=output_dims,
132134
in_out_argnames=in_out_argnames,
135+
stage_in_argnames=stage_in_argnames,
136+
stage_out_argnames=stage_out_argnames,
133137
)
134138

135139
flat_args, in_tree = jax.tree.flatten(args)

0 commit comments

Comments
 (0)