Skip to content

Commit 0be7e6a

Browse files
btabacopybara-github
authored andcommitted
Fix #2306.
PiperOrigin-RevId: 712575320 Change-Id: Ia74ef0b31b4e7098647340760572a1f6368eab64
1 parent f607d95 commit 0be7e6a

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

mjx/mujoco/mjx/_src/io.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,4 +681,5 @@ def put_data(
681681
# copy because device_put is async:
682682
data = types.Data(**{k: copy.copy(v) for k, v in fields.items()})
683683

684-
return jax.device_put(data, device=device)
684+
data = jax.device_put(data, device=device)
685+
return _strip_weak_type(data)

mjx/mujoco/mjx/_src/io_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,15 @@ def test_put_data(self):
310310
np.testing.assert_allclose(dx.cvel, d.cvel)
311311
np.testing.assert_allclose(dx.cdof_dot, d.cdof_dot)
312312

313+
# check that there are no weak types
314+
self.assertFalse(
315+
any(
316+
jax.tree_util.tree_flatten(
317+
jax.tree_util.tree_map(lambda x: x.weak_type, dx)
318+
)[0]
319+
)
320+
)
321+
313322
# check that qM is transformed properly
314323
qm = np.zeros((m.nv, m.nv), dtype=np.float64)
315324
mujoco.mj_fullM(m, qm, d.qM)

0 commit comments

Comments
 (0)