Skip to content

Commit 41c2918

Browse files
quaglacopybara-github
authored andcommitted
Add batch support to MJX data bind.
PiperOrigin-RevId: 729033713 Change-Id: I331317a54a7882dc67bf329da8aeaff7246754e4
1 parent 19624ae commit 41c2918

File tree

3 files changed

+79
-7
lines changed

3 files changed

+79
-7
lines changed

mjx/mujoco/mjx/_src/support.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import jax
2020
from jax import numpy as jp
2121
import mujoco
22+
from mujoco.introspect import mjxmacro
2223
from mujoco.mjx._src import math
2324
from mujoco.mjx._src import scan
2425
# pylint: disable=g-importing-member
@@ -368,8 +369,17 @@ def __init__(self, model: Model, specs: Sequence[mujoco.MjStruct]):
368369
else:
369370
self.id = ids
370371

372+
def _slice(self, name: str, idx: Union[int, slice, Sequence[int]]):
373+
_, expected_dim = mjxmacro.MJMODEL[name]
374+
var = getattr(self.model, name)
375+
if expected_dim == '1':
376+
return var[..., idx]
377+
elif expected_dim == '9':
378+
return var[..., idx, :, :]
379+
return var[..., idx, :]
380+
371381
def __getattr__(self, name: str):
372-
return getattr(self.model, self.prefix + name)[self.id, ...]
382+
return self._slice(self.prefix + name, self.id)
373383

374384

375385
def _bind_model(
@@ -453,6 +463,15 @@ def __getname(self, name: str):
453463
else:
454464
return self.prefix + name
455465

466+
def _slice(self, name: str, idx: Union[int, slice, Sequence[int]]):
467+
_, expected_dim = mjxmacro.MJDATA[name]
468+
var = getattr(self.data, name)
469+
if expected_dim == '1':
470+
return var[..., idx]
471+
elif expected_dim == '9':
472+
return var[..., idx, :, :]
473+
return var[..., idx, :]
474+
456475
def __getattr__(self, name: str):
457476
if name in ('sensordata', 'qpos', 'qvel', 'qacc'):
458477
adr = num = 0
@@ -471,12 +490,12 @@ def __getattr__(self, name: str):
471490
idx = []
472491
for a, n in zip(adr, num):
473492
idx.extend(a + j for j in range(n))
474-
return getattr(self.data, self.__getname(name))[idx, ...]
493+
return self._slice(self.__getname(name), idx)
475494
elif num > 1:
476-
return getattr(self.data, self.__getname(name))[adr : adr + num, ...]
495+
return self._slice(self.__getname(name), slice(adr, adr + num))
477496
else:
478-
return getattr(self.data, self.__getname(name))[adr, ...]
479-
return getattr(self.data, self.__getname(name))[self.id, ...]
497+
return self._slice(self.__getname(name), adr)
498+
return self._slice(self.__getname(name), self.id)
480499

481500
def set(self, name: str, value: jax.Array) -> Data:
482501
"""Set the value of an array in an MJX Data."""

mjx/mujoco/mjx/_src/support_test.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,11 +305,11 @@ def test_bind(self):
305305
):
306306
print(dx.bind(mx, s.geoms).ctrl)
307307
with self.assertRaises(
308-
AttributeError, msg='ctrl is not available for this type'
308+
KeyError, msg='actuator_actuator_ctrl'
309309
):
310310
print(dx.bind(mx, s.actuators).actuator_ctrl)
311311
with self.assertRaises(
312-
AttributeError, msg='ctrl is not available for this type'
312+
AttributeError, msg='actuator_actuator_ctrl'
313313
):
314314
print(dx.bind(mx, s.actuators).set('actuator_ctrl', [1, 2, 3]))
315315
with self.assertRaises(
@@ -323,6 +323,16 @@ def test_bind(self):
323323
s.geoms[0].name = 'invalid_geom_name'
324324
print(mx.bind(s.geoms).pos)
325325

326+
# test batched data
327+
batch_size = 16
328+
ds = [d for _ in range(batch_size)]
329+
vdx = jax.vmap(lambda xpos: dx.replace(xpos=xpos))(
330+
jp.array([d.xpos for d in ds], device=jax.devices('cpu')[0]))
331+
for i in range(m.nbody):
332+
np.testing.assert_array_equal(
333+
vdx.bind(mx, s.bodies[i]).xpos, [d.xpos[i, :]] * batch_size
334+
)
335+
326336
_CONTACTS = """
327337
<mujoco>
328338
<worldbody>
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2025 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Generate X macros for Mujoco structs."""
16+
17+
from . import structs
18+
19+
MJMODEL_S = structs.STRUCTS['mjModel']
20+
MJDATA_S = structs.STRUCTS['mjData']
21+
22+
MJMODEL = dict()
23+
MJDATA = dict()
24+
25+
for field in MJMODEL_S.fields:
26+
if not isinstance(field, structs.StructFieldDecl):
27+
continue
28+
if field.array_extent is None:
29+
continue
30+
if len(field.array_extent) == 1:
31+
MJMODEL[field.name] = (field.array_extent[0], '1')
32+
else:
33+
MJMODEL[field.name] = (field.array_extent[0], str(field.array_extent[1]))
34+
35+
for field in MJDATA_S.fields:
36+
if not isinstance(field, structs.StructFieldDecl):
37+
continue
38+
if field.array_extent is None:
39+
continue
40+
if len(field.array_extent) == 1:
41+
MJDATA[field.name] = (field.array_extent[0], '1')
42+
else:
43+
MJDATA[field.name] = (field.array_extent[0], str(field.array_extent[1]))

0 commit comments

Comments
 (0)