Skip to content

Commit 6cfea71

Browse files
btabacopybara-github
authored andcommitted
Allow for different MJX backend implementations.
PiperOrigin-RevId: 755935704 Change-Id: Ic135cd00137c2857c73c683ed9fdc5ac4418715d
1 parent 421c487 commit 6cfea71

24 files changed

+2183
-1505
lines changed

mjx/mujoco/mjx/_src/collision_convex.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@
2828
from mujoco.mjx._src.collision_types import GeomInfo
2929
from mujoco.mjx._src.collision_types import HFieldInfo
3030
from mujoco.mjx._src.types import Data
31+
from mujoco.mjx._src.types import DataJAX
3132
from mujoco.mjx._src.types import GeomType
3233
from mujoco.mjx._src.types import Model
34+
from mujoco.mjx._src.types import ModelJAX
3335
# pylint: enable=g-importing-member
3436

3537
_GeomInfo = Union[GeomInfo, ConvexInfo]
@@ -42,6 +44,9 @@ def wrapper(collision_fn):
4244
def collide(
4345
m: Model, d: Data, key: FunctionKey, geom: jax.Array
4446
) -> Collision:
47+
if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
48+
raise ValueError('collider requires JAX backend implementation.')
49+
4550
g1, g2 = geom.T
4651
infos = [
4752
GeomInfo(d.geom_xpos[g1], d.geom_xmat[g1], m.geom_size[g1]),
@@ -56,7 +61,7 @@ def collide(
5661
pos=0, mat=0, size=0, face=0, vert=0
5762
)
5863
elif key.types[i] == GeomType.MESH:
59-
c, cm = infos[i], m.mesh_convex[key.data_ids[i]]
64+
c, cm = infos[i], m._impl.mesh_convex[key.data_ids[i]]
6065
infos[i] = ConvexInfo(**vars(c), **vars(cm))
6166
in_axes[i] = jax.tree_util.tree_map(lambda x: None, infos[i]).replace(
6267
pos=0, mat=0, size=0

mjx/mujoco/mjx/_src/collision_driver.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,11 @@
7171
from mujoco.mjx._src.collision_types import FunctionKey
7272
from mujoco.mjx._src.types import Contact
7373
from mujoco.mjx._src.types import Data
74+
from mujoco.mjx._src.types import DataJAX
7475
from mujoco.mjx._src.types import DisableBit
7576
from mujoco.mjx._src.types import GeomType
7677
from mujoco.mjx._src.types import Model
78+
from mujoco.mjx._src.types import ModelJAX
7779
# pylint: enable=g-importing-member
7880
import numpy as np
7981

@@ -227,7 +229,7 @@ def _geom_groups(
227229
if types[0] == mujoco.mjtGeom.mjGEOM_HFIELD:
228230
# add static grid bounds to the grouping key for hfield collisions
229231
geom_rbound_hfield = (
230-
m.geom_rbound_hfield if isinstance(m, Model) else m.geom_rbound
232+
m._impl.geom_rbound_hfield if isinstance(m, Model) else m.geom_rbound # pytype: disable=attribute-error
231233
)
232234
nrow, ncol = m.hfield_nrow[data_ids[0]], m.hfield_ncol[data_ids[0]]
233235
xsize, ysize = m.hfield_size[data_ids[0]][:2]
@@ -323,11 +325,11 @@ def _contact_groups(m: Model, d: Data) -> Dict[FunctionKey, Contact]:
323325
solref=solref,
324326
solreffriction=solreffriction,
325327
solimp=solimp,
326-
dim=d.contact.dim,
328+
dim=d._impl.contact.dim, # pytype: disable=attribute-error
327329
geom1=jp.array(geom[:, 0]),
328330
geom2=jp.array(geom[:, 1]),
329331
geom=jp.array(geom[:, :2]),
330-
efc_address=d.contact.efc_address,
332+
efc_address=d._impl.contact.efc_address, # pytype: disable=attribute-error
331333
)
332334

333335
return groups
@@ -374,7 +376,10 @@ def make_condim(m: Union[Model, mujoco.MjModel]) -> np.ndarray:
374376

375377
def collision(m: Model, d: Data) -> Data:
376378
"""Collides geometries."""
377-
if d.ncon == 0:
379+
if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
380+
raise ValueError('collision requires JAX backend implementation.')
381+
382+
if d._impl.ncon == 0: # pytype: disable=attribute-error
378383
return d
379384

380385
max_geom_pairs = _numeric(m, 'max_geom_pairs')
@@ -424,4 +429,4 @@ def collision(m: Model, d: Data) -> Data:
424429
contacts = sum([condim_groups[k] for k in sorted(condim_groups)], [])
425430
contact = jax.tree_util.tree_map(lambda *x: jp.concatenate(x), *contacts)
426431

427-
return d.replace(contact=contact)
432+
return d.replace(_impl=d._impl.replace(contact=contact)) # pytype: disable=attribute-error

0 commit comments

Comments
 (0)