Skip to content

Commit 6f0bd0a

Browse files
quaglacopybara-github
authored andcommitted
Fix setting multidimensional arrays in MJX bind.
PiperOrigin-RevId: 726956380 Change-Id: I51ab91b73a3a34938ac7bcd03e9691fd634cb296
1 parent 5b924fe commit 6f0bd0a

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

mjx/mujoco/mjx/_src/support.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,7 @@ def set(self, name: str, value: jax.Array) -> Data:
476476
if name == 'sensordata':
477477
raise AttributeError('sensordata is readonly')
478478
array = getattr(self.data, self.__getname(name))
479+
dim = 1 if len(array.shape) == 1 else array.shape[-1]
479480
try:
480481
iter(value)
481482
except TypeError:
@@ -490,10 +491,10 @@ def set(self, name: str, value: jax.Array) -> Data:
490491
num = sum((typ == jt) * jt.dof_width() for jt in JointType)
491492
elif isinstance(self.id, list):
492493
adr = self.id
493-
num = [1 for _ in range(len(self.id))]
494+
num = [dim for _ in range(len(self.id))]
494495
else:
495496
adr = [self.id]
496-
num = [1]
497+
num = [dim]
497498
i = 0
498499
for a, n in zip(adr, num):
499500
array = array.at[a: a + n].set(value[i: i + n])

mjx/mujoco/mjx/_src/support_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,9 @@ def test_bind(self):
207207
np.testing.assert_array_equal(
208208
dx.bind(mx, s.bodies[i]).xpos, d.xpos[i, :]
209209
)
210+
np.testing.assert_array_equal(
211+
dx.bind(mx, s.bodies[i]).xfrc_applied, d.xfrc_applied[i, :]
212+
)
210213

211214
np.testing.assert_array_equal(mx.bind(s.geoms).size, m.geom_size)
212215
np.testing.assert_array_equal(dx.bind(mx, s.geoms).xpos, d.geom_xpos)
@@ -269,6 +272,11 @@ def test_bind(self):
269272
np.testing.assert_array_equal(dx6.bind(mx, s.joints).qpos, qpos_desired)
270273
np.testing.assert_array_almost_equal(dx.bind(mx, s.joints).qpos, d.qpos)
271274

275+
dx7 = dx.bind(mx, s.bodies[1]).set('xfrc_applied', [1, 2, 3, 4, 5, 6])
276+
np.testing.assert_array_equal(
277+
dx7.bind(mx, s.bodies[1]).xfrc_applied, [1, 2, 3, 4, 5, 6]
278+
)
279+
272280
# test invalid name
273281
with self.assertRaises(AttributeError):
274282
print(dx.bind(mx, s.geoms).ctrl)

0 commit comments

Comments
 (0)