Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 72 additions & 12 deletions mjx/mujoco/mjx/_src/support_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
from absl.testing import parameterized
import jax
from jax import numpy as jp
import numpy as np

import mujoco
from mujoco import mjx
from mujoco.mjx._src import support
from mujoco.mjx._src import test_util
import numpy as np


class SupportTest(parameterized.TestCase):
Expand Down Expand Up @@ -210,9 +211,7 @@ def test_bind(self):
np.testing.assert_array_equal(m.bind(s.bodies[i]).pos, m.body_pos[i, :])
np.testing.assert_array_equal(mx.bind(s.bodies[i]).pos, m.body_pos[i, :])
np.testing.assert_array_equal(d.bind(s.bodies[i]).xpos, d.xpos[i, :])
np.testing.assert_array_equal(
dx.bind(mx, s.bodies[i]).xpos, d.xpos[i, :]
)
np.testing.assert_array_equal(dx.bind(mx, s.bodies[i]).xpos, d.xpos[i, :])
np.testing.assert_array_equal(
dx.bind(mx, s.bodies[i]).xfrc_applied, d.xfrc_applied[i, :]
)
Expand All @@ -239,15 +238,18 @@ def test_bind(self):
np.testing.assert_array_equal(mx.bind(s.joints[i]).axis, m.jnt_axis[i, :])
np.testing.assert_array_almost_equal(
dx.bind(mx, s.joints[i]).qpos,
d.qpos[m.jnt_qposadr[i]:m.jnt_qposadr[i] + qposnum[i]], decimal=6
d.qpos[m.jnt_qposadr[i] : m.jnt_qposadr[i] + qposnum[i]],
decimal=6,
)
np.testing.assert_array_almost_equal(
dx.bind(mx, s.joints[i]).qvel,
d.qvel[m.jnt_dofadr[i]:m.jnt_dofadr[i] + dofnum[i]], decimal=6
d.qvel[m.jnt_dofadr[i] : m.jnt_dofadr[i] + dofnum[i]],
decimal=6,
)
np.testing.assert_array_almost_equal(
dx.bind(mx, s.joints[i]).qacc,
d.qacc[m.jnt_dofadr[i]:m.jnt_dofadr[i] + dofnum[i]], decimal=6
d.qacc[m.jnt_dofadr[i] : m.jnt_dofadr[i] + dofnum[i]],
decimal=6,
)
np.testing.assert_array_almost_equal(
dx.bind(mx, s.joints[i]).qfrc_actuator,
Expand All @@ -258,9 +260,7 @@ def test_bind(self):
np.testing.assert_array_equal(dx.bind(mx, s.actuators).ctrl, d.ctrl)
for i in range(m.nu):
np.testing.assert_array_equal(d.bind(s.actuators[i]).ctrl, d.ctrl[i])
np.testing.assert_array_equal(
dx.bind(mx, s.actuators[i]).ctrl, d.ctrl[i]
)
np.testing.assert_array_equal(dx.bind(mx, s.actuators[i]).ctrl, d.ctrl[i])

np.testing.assert_array_equal(
dx.bind(mx, s.sensors).sensordata, d.sensordata
Expand Down Expand Up @@ -363,7 +363,8 @@ def test_bind(self):
batch_size = 16
ds = [d for _ in range(batch_size)]
vdx = jax.vmap(lambda xpos: dx.replace(xpos=xpos))(
jp.array([d.xpos for d in ds], device=jax.devices('cpu')[0]))
jp.array([d.xpos for d in ds], device=jax.devices('cpu')[0])
)
for i in range(m.nbody):
np.testing.assert_array_equal(
vdx.bind(mx, s.bodies[i]).xpos, [d.xpos[i, :]] * batch_size
Expand Down Expand Up @@ -488,7 +489,9 @@ def test_wrap_inside(self):
# radius < mjMINVAL
np.testing.assert_equal(
support.wrap_inside(
jp.array([1, 0, 0, 1]), jp.array([0.1 * mujoco.mjMINVAL]), maxiter,
jp.array([1, 0, 0, 1]),
jp.array([0.1 * mujoco.mjMINVAL]),
maxiter,
tolerance,
z_init,
)[0],
Expand Down Expand Up @@ -782,6 +785,63 @@ def _muscle_dynamics_millard(ctrl, act, prm):
atol=1e-5,
)

def test_model_named_accessors(self):
"""Tests Model.body(), Model.joint(), Model.geom(), etc."""
xml = """
<mujoco model="test_model">
<worldbody>
<geom name="plane" type="plane" size="1 1 1"/>
<body name="body1" pos="0 0 1">
<joint name="joint1" type="slide" axis="1 0 0" range="-5 5"/>
<geom name="box1" type="box" size=".2 .1 .1" rgba=".9 .3 .3 1"/>
<site name="site1" size="0.01"/>
</body>
</worldbody>
<actuator>
<motor joint="joint1" name="motor1"/>
</actuator>
</mujoco>
"""
m = mujoco.MjModel.from_xml_string(xml)
mx = mjx.put_model(m)

# Test body accessor
body = mx.body('body1')
self.assertEqual(body.id, m.body('body1').id)
self.assertEqual(body.name, 'body1')

# Test joint accessor
joint = mx.joint('joint1')
self.assertEqual(joint.id, m.joint('joint1').id)
self.assertEqual(joint.name, 'joint1')

# Test geom accessor
geom = mx.geom('box1')
self.assertEqual(geom.id, m.geom('box1').id)
self.assertEqual(geom.name, 'box1')

# Test site accessor
site = mx.site('site1')
self.assertEqual(site.id, m.site('site1').id)
self.assertEqual(site.name, 'site1')

# Test actuator accessor
actuator = mx.actuator('motor1')
self.assertEqual(actuator.id, m.actuator('motor1').id)
self.assertEqual(actuator.name, 'motor1')

# Test KeyError for non-existent elements
with self.assertRaises(KeyError):
mx.body('nonexistent')
with self.assertRaises(KeyError):
mx.joint('nonexistent')
with self.assertRaises(KeyError):
mx.geom('nonexistent')
with self.assertRaises(KeyError):
mx.site('nonexistent')
with self.assertRaises(KeyError):
mx.actuator('nonexistent')


if __name__ == '__main__':
absltest.main()
184 changes: 182 additions & 2 deletions mjx/mujoco/mjx/_src/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

import dataclasses
import enum
from typing import Any, Tuple, Union
from typing import Any, Optional, Tuple, Union
import warnings

import jax
import numpy as np

import mujoco
from mujoco.mjx._src.dataclasses import PyTreeNode # pylint: disable=g-importing-member
from mujoco.mjx.warp import types as mjxw_types
import numpy as np


class Impl(enum.Enum):
Expand Down Expand Up @@ -533,6 +534,7 @@ class Option(PyTreeNode):

class ModelCPP(PyTreeNode):
"""Minimal Model implementation holding only the pointer."""

# To ensure that we retain the full pointer even if jax.config.enable_x64 is
# set to True, we store the pointer as two 32-bit values. In the FFI call,
# we combine the two values into a single pointer value.
Expand All @@ -543,6 +545,7 @@ class ModelCPP(PyTreeNode):

class DataCPP(PyTreeNode):
"""Minimal Data implementation holding only the pointer."""

# To ensure that we retain the full pointer even if jax.config.enable_x64 is
# set to True, we store the pointer as two 32-bit values. In the FFI call,
# we combine the two values into a single pointer value.
Expand Down Expand Up @@ -672,6 +675,49 @@ class ModelJAX(PyTreeNode):
is_wrap_inside: np.ndarray


@dataclasses.dataclass(frozen=True)
class _ElementView:
"""Base class for named element views in mjx.Model."""

id: int
name: str


@dataclasses.dataclass(frozen=True)
class BodyView(_ElementView):
"""View of a body element by name or id."""

pass


@dataclasses.dataclass(frozen=True)
class JointView(_ElementView):
"""View of a joint element by name or id."""

pass


@dataclasses.dataclass(frozen=True)
class GeomView(_ElementView):
"""View of a geom element by name or id."""

pass


@dataclasses.dataclass(frozen=True)
class SiteView(_ElementView):
"""View of a site element by name or id."""

pass


@dataclasses.dataclass(frozen=True)
class ActuatorView(_ElementView):
"""View of an actuator element by name or id."""

pass


class Model(PyTreeNode):
"""Static model of the scene that remains unchanged with each physics step."""

Expand Down Expand Up @@ -995,6 +1041,140 @@ def __getattr__(self, name: str):
)
return val

def _name2id(self, typ: mujoco._enums.mjtObj, name: str) -> int:
"""Gets the id of an object with the specified mjtObj type and name."""
num_map = {
mujoco.mjtObj.mjOBJ_BODY: self.nbody,
mujoco.mjtObj.mjOBJ_JOINT: self.njnt,
mujoco.mjtObj.mjOBJ_GEOM: self.ngeom,
mujoco.mjtObj.mjOBJ_SITE: self.nsite,
mujoco.mjtObj.mjOBJ_ACTUATOR: self.nu,
}
adr_map = {
mujoco.mjtObj.mjOBJ_BODY: self.name_bodyadr,
mujoco.mjtObj.mjOBJ_JOINT: self.name_jntadr,
mujoco.mjtObj.mjOBJ_GEOM: self.name_geomadr,
mujoco.mjtObj.mjOBJ_SITE: self.name_siteadr,
mujoco.mjtObj.mjOBJ_ACTUATOR: self.name_actuatoradr,
}
num = num_map.get(typ, 0)
adr = adr_map.get(typ)
if adr is None:
return -1
for i in range(num):
obj_name = self.names[adr[i] :].decode('utf-8').split('\x00', 1)[0]
if obj_name == name:
return i
return -1

def _id2name(self, typ: mujoco._enums.mjtObj, i: int) -> Optional[str]:
"""Gets the name of an object with the specified mjtObj type and id."""
num_map = {
mujoco.mjtObj.mjOBJ_BODY: self.nbody,
mujoco.mjtObj.mjOBJ_JOINT: self.njnt,
mujoco.mjtObj.mjOBJ_GEOM: self.ngeom,
mujoco.mjtObj.mjOBJ_SITE: self.nsite,
mujoco.mjtObj.mjOBJ_ACTUATOR: self.nu,
}
adr_map = {
mujoco.mjtObj.mjOBJ_BODY: self.name_bodyadr,
mujoco.mjtObj.mjOBJ_JOINT: self.name_jntadr,
mujoco.mjtObj.mjOBJ_GEOM: self.name_geomadr,
mujoco.mjtObj.mjOBJ_SITE: self.name_siteadr,
mujoco.mjtObj.mjOBJ_ACTUATOR: self.name_actuatoradr,
}
num = num_map.get(typ, 0)
adr = adr_map.get(typ)
if adr is None or i < 0 or i >= num:
return None
name = self.names[adr[i] :].decode('utf-8').split('\x00', 1)[0]
return name or None

def body(self, name: str) -> BodyView:
"""Gets a body by name.

Args:
name: The name of the body.

Returns:
A BodyView with the body's id and name.

Raises:
KeyError: If no body with the given name exists.
"""
id_ = self._name2id(mujoco.mjtObj.mjOBJ_BODY, name)
if id_ < 0:
raise KeyError(f"body '{name}' not found")
return BodyView(id=id_, name=name)

def joint(self, name: str) -> JointView:
"""Gets a joint by name.

Args:
name: The name of the joint.

Returns:
A JointView with the joint's id and name.

Raises:
KeyError: If no joint with the given name exists.
"""
id_ = self._name2id(mujoco.mjtObj.mjOBJ_JOINT, name)
if id_ < 0:
raise KeyError(f"joint '{name}' not found")
return JointView(id=id_, name=name)

def geom(self, name: str) -> GeomView:
"""Gets a geom by name.

Args:
name: The name of the geom.

Returns:
A GeomView with the geom's id and name.

Raises:
KeyError: If no geom with the given name exists.
"""
id_ = self._name2id(mujoco.mjtObj.mjOBJ_GEOM, name)
if id_ < 0:
raise KeyError(f"geom '{name}' not found")
return GeomView(id=id_, name=name)

def site(self, name: str) -> SiteView:
"""Gets a site by name.

Args:
name: The name of the site.

Returns:
A SiteView with the site's id and name.

Raises:
KeyError: If no site with the given name exists.
"""
id_ = self._name2id(mujoco.mjtObj.mjOBJ_SITE, name)
if id_ < 0:
raise KeyError(f"site '{name}' not found")
return SiteView(id=id_, name=name)

def actuator(self, name: str) -> ActuatorView:
"""Gets an actuator by name.

Args:
name: The name of the actuator.

Returns:
An ActuatorView with the actuator's id and name.

Raises:
KeyError: If no actuator with the given name exists.
"""
id_ = self._name2id(mujoco.mjtObj.mjOBJ_ACTUATOR, name)
if id_ < 0:
raise KeyError(f"actuator '{name}' not found")
return ActuatorView(id=id_, name=name)


class Contact(PyTreeNode):
"""Result of collision detection functions.
Expand Down