Skip to content

Commit 8541128

Browse files
thowellcopybara-github
authored andcommitted
Update _get_data_into_warp. Fixes #3109
PiperOrigin-RevId: 871790690 Change-Id: I225579f7559ca5c1de92af6606d02ff763f9f4d1
1 parent 6b197e1 commit 8541128

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

mjx/mujoco/mjx/_src/io.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1486,10 +1486,11 @@ def _get_data_into_warp(
14861486
if field.name in (
14871487
'actuator_moment',
14881488
'contact',
1489-
'efc_J',
14901489
'qM',
14911490
'qLD',
14921491
'qLDiagInv',
1492+
'ten_J',
1493+
'flexedge_J',
14931494
):
14941495
continue
14951496
if field.name.startswith('efc_'):

mjx/mujoco/mjx/_src/io_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,22 @@ def test_get_data_into(self, impl):
744744
self.assertEqual(d_2.contact.frame.shape, (1, 9))
745745
np.testing.assert_allclose(d_2.contact.frame, d.contact.frame)
746746

747+
def test_get_data_into_warp(self):
748+
"""Test get_data_into for impl='warp'."""
749+
750+
# TODO(taylorhowell): After put_data supports impl='warp' update test above
751+
# and remove this test.
752+
if not mjxw.WARP_INSTALLED:
753+
self.skipTest('Warp is not installed.')
754+
if not mjx_io.has_cuda_gpu_device():
755+
self.skipTest('No CUDA GPU device.')
756+
757+
m = mujoco.MjModel.from_xml_string('<mujoco></mujoco>')
758+
d = mujoco.MjData(m)
759+
mx = mjx.put_model(m, impl='warp')
760+
dx = mjx.make_data(m, impl='warp')
761+
mjx.get_data_into(d, mx, dx)
762+
747763
@parameterized.parameters('jax', 'c')
748764
def test_get_data_into_wrong_shape(self, impl):
749765
"""Tests that get_data_into throwsif input and output shapes don't match."""

0 commit comments

Comments
 (0)