Skip to content

Commit 5b924fe

Browse files
quaglacopybara-github
authored andcommitted
Add support for binding to arrays of mjs element.
Fixes #2402. PiperOrigin-RevId: 726449592 Change-Id: I0cde44889837a71a222b20cf91fa56568b4af689
1 parent 1b4258d commit 5b924fe

File tree

3 files changed

+101
-3
lines changed

3 files changed

+101
-3
lines changed

python/mujoco/__init__.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import os
2020
import platform
2121
import subprocess
22-
from typing import IO, Union
22+
from typing import Any, IO, Union, Sequence
2323
from typing_extensions import TypeAlias
2424
import warnings
2525
import zipfile
@@ -50,6 +50,7 @@
5050
'native, arm64 build of Python.')
5151

5252
from mujoco import _specs
53+
from mujoco import _structs
5354
from mujoco._callbacks import *
5455
from mujoco._constants import *
5556
from mujoco._enums import *
@@ -88,6 +89,7 @@
8889
_specs.MjsPlugin,
8990
]
9091

92+
9193
def to_zip(spec: _specs.MjSpec, file: Union[str, IO[bytes]]) -> None:
9294
"""Converts a spec to a zip file.
9395
@@ -106,7 +108,68 @@ def to_zip(spec: _specs.MjSpec, file: Union[str, IO[bytes]]) -> None:
106108
zip_info = zipfile.ZipInfo(os.path.join(spec.modelname, filename))
107109
zip_file.writestr(zip_info, contents)
108110

111+
112+
class _MjBindModel:
113+
def __init__(self, elements: Sequence[Any]):
114+
self.elements = elements
115+
116+
def __getattr__(self, key: str):
117+
items = []
118+
for e in self.elements:
119+
items.extend(getattr(e, key))
120+
return items
121+
122+
123+
class _MjBindData:
124+
def __init__(self, elements: Sequence[Any]):
125+
self.elements = elements
126+
127+
def __getattr__(self, key: str):
128+
items = []
129+
for e in self.elements:
130+
items.extend(getattr(e, key))
131+
return items
132+
133+
134+
def _bind_model(
135+
model: _structs.MjModel, specs: Union[Sequence[MjStruct], MjStruct]
136+
):
137+
"""Bind a Mujoco spec to a mjModel.
138+
139+
Args:
140+
model: The mjModel to bind to.
141+
specs: The mjSpec elements to use for binding, can be a single element or a
142+
sequence.
143+
Returns:
144+
A MjModelGroupedViews object or a list of the same type.
145+
"""
146+
if isinstance(specs, Sequence):
147+
return _MjBindModel([model.bind_scalar(s) for s in specs])
148+
else:
149+
return model.bind_scalar(specs)
150+
151+
152+
def _bind_data(
153+
data: _structs.MjData, specs: Union[Sequence[MjStruct], MjStruct]
154+
):
155+
"""Bind a Mujoco spec to a mjData.
156+
157+
Args:
158+
data: The mjData to bind to.
159+
specs: The mjSpec elements to use for binding, can be a single element or a
160+
sequence.
161+
Returns:
162+
A MjDataGroupedViews object or a list of the same type.
163+
"""
164+
if isinstance(specs, Sequence):
165+
return _MjBindData([data.bind_scalar(s) for s in specs])
166+
else:
167+
return data.bind_scalar(specs)
168+
169+
109170
_specs.MjSpec.to_zip = to_zip
171+
_structs.MjData.bind = _bind_data
172+
_structs.MjModel.bind = _bind_model
110173

111174
HEADERS_DIR = os.path.join(os.path.dirname(__file__), 'include/mujoco')
112175
PLUGINS_DIR = os.path.join(os.path.dirname(__file__), 'plugin')

python/mujoco/specs_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,5 +1100,40 @@ def test_attach_to_frame(self):
11001100
with self.assertRaisesRegex(ValueError, 'Frame not found.'):
11011101
parent.attach(child4, frame='invalid_frame', prefix='child3-')
11021102

1103+
def test_bind(self):
1104+
spec = mujoco.MjSpec.from_string("""
1105+
<mujoco>
1106+
<worldbody>
1107+
<body name="main">
1108+
<geom name="main" size="0.15 0.15 0.15" mass="1" type="box"/>
1109+
<freejoint/>
1110+
<body name="box">
1111+
<joint name="box" type="hinge" range="-1 +1"/>
1112+
<geom name="box" size="0.15 0.15 0.15" mass="1" type="box"/>
1113+
</body>
1114+
<body name="sphere">
1115+
<joint name="sphere" type="hinge" range="-1 +1"/>
1116+
<geom name="sphere" size="0.15 0.15 0.15" mass="1" type="box"/>
1117+
</body>
1118+
</body>
1119+
</worldbody>
1120+
</mujoco>
1121+
""")
1122+
joint_box = spec.joint('box')
1123+
joint_sphere = spec.joint('sphere')
1124+
joints = [joint_box, joint_sphere]
1125+
mj_model = spec.compile()
1126+
mj_data = mujoco.MjData(mj_model)
1127+
np.testing.assert_array_equal(mj_data.bind(joint_box).qpos, 0)
1128+
np.testing.assert_array_equal(mj_model.bind(joint_box).qposadr, 7)
1129+
np.testing.assert_array_equal(mj_data.bind(joints).qpos, [0, 0])
1130+
np.testing.assert_array_equal(mj_model.bind(joints).qposadr, [7, 8])
1131+
np.testing.assert_array_equal(mj_data.bind([]).qpos, [])
1132+
np.testing.assert_array_equal(mj_model.bind([]).qposadr, [])
1133+
with self.assertRaisesRegex(
1134+
AttributeError, "object has no attribute 'invalid'"
1135+
):
1136+
print(mj_model.bind(joints).invalid)
1137+
11031138
if __name__ == '__main__':
11041139
absltest.main()

python/mujoco/structs.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,7 +1720,7 @@ This is useful for example when the MJB is not available as a file on disk.)"));
17201720

17211721
#define XGROUP(spectype, field) \
17221722
mjModel.def( \
1723-
"bind", \
1723+
"bind_scalar", \
17241724
[](MjModelWrapper& m, spectype& spec) -> auto& { \
17251725
return m.indexer().field##_by_name(mjs_getString(spec.name)); \
17261726
}, \
@@ -2066,7 +2066,7 @@ This is useful for example when the MJB is not available as a file on disk.)"));
20662066

20672067
#define XGROUP(spectype, field) \
20682068
mjData.def( \
2069-
"bind", \
2069+
"bind_scalar", \
20702070
[](MjDataWrapper& d, spectype& spec) -> auto& { \
20712071
return d.indexer().field##_by_name(mjs_getString(spec.name)); \
20722072
}, \

0 commit comments

Comments
 (0)