|
71 | 71 | from mujoco.mjx._src.collision_types import FunctionKey |
72 | 72 | from mujoco.mjx._src.types import Contact |
73 | 73 | from mujoco.mjx._src.types import Data |
| 74 | +from mujoco.mjx._src.types import DataJAX |
74 | 75 | from mujoco.mjx._src.types import DisableBit |
75 | 76 | from mujoco.mjx._src.types import GeomType |
76 | 77 | from mujoco.mjx._src.types import Model |
| 78 | +from mujoco.mjx._src.types import ModelJAX |
77 | 79 | # pylint: enable=g-importing-member |
78 | 80 | import numpy as np |
79 | 81 |
|
@@ -227,7 +229,7 @@ def _geom_groups( |
227 | 229 | if types[0] == mujoco.mjtGeom.mjGEOM_HFIELD: |
228 | 230 | # add static grid bounds to the grouping key for hfield collisions |
229 | 231 | geom_rbound_hfield = ( |
230 | | - m.geom_rbound_hfield if isinstance(m, Model) else m.geom_rbound |
| 232 | + m._impl.geom_rbound_hfield if isinstance(m, Model) else m.geom_rbound # pytype: disable=attribute-error |
231 | 233 | ) |
232 | 234 | nrow, ncol = m.hfield_nrow[data_ids[0]], m.hfield_ncol[data_ids[0]] |
233 | 235 | xsize, ysize = m.hfield_size[data_ids[0]][:2] |
@@ -323,11 +325,11 @@ def _contact_groups(m: Model, d: Data) -> Dict[FunctionKey, Contact]: |
323 | 325 | solref=solref, |
324 | 326 | solreffriction=solreffriction, |
325 | 327 | solimp=solimp, |
326 | | - dim=d.contact.dim, |
| 328 | + dim=d._impl.contact.dim, # pytype: disable=attribute-error |
327 | 329 | geom1=jp.array(geom[:, 0]), |
328 | 330 | geom2=jp.array(geom[:, 1]), |
329 | 331 | geom=jp.array(geom[:, :2]), |
330 | | - efc_address=d.contact.efc_address, |
| 332 | + efc_address=d._impl.contact.efc_address, # pytype: disable=attribute-error |
331 | 333 | ) |
332 | 334 |
|
333 | 335 | return groups |
@@ -374,7 +376,10 @@ def make_condim(m: Union[Model, mujoco.MjModel]) -> np.ndarray: |
374 | 376 |
|
375 | 377 | def collision(m: Model, d: Data) -> Data: |
376 | 378 | """Collides geometries.""" |
377 | | - if d.ncon == 0: |
| 379 | + if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX): |
| 380 | + raise ValueError('collision requires JAX backend implementation.') |
| 381 | + |
| 382 | + if d._impl.ncon == 0: # pytype: disable=attribute-error |
378 | 383 | return d |
379 | 384 |
|
380 | 385 | max_geom_pairs = _numeric(m, 'max_geom_pairs') |
@@ -424,4 +429,4 @@ def collision(m: Model, d: Data) -> Data: |
424 | 429 | contacts = sum([condim_groups[k] for k in sorted(condim_groups)], []) |
425 | 430 | contact = jax.tree_util.tree_map(lambda *x: jp.concatenate(x), *contacts) |
426 | 431 |
|
427 | | - return d.replace(contact=contact) |
| 432 | + return d.replace(_impl=d._impl.replace(contact=contact)) # pytype: disable=attribute-error |
0 commit comments