File tree Expand file tree Collapse file tree 2 files changed +18
-1
lines changed
Expand file tree Collapse file tree 2 files changed +18
-1
lines changed Original file line number Diff line number Diff line change @@ -1486,10 +1486,11 @@ def _get_data_into_warp(
14861486 if field .name in (
14871487 'actuator_moment' ,
14881488 'contact' ,
1489- 'efc_J' ,
14901489 'qM' ,
14911490 'qLD' ,
14921491 'qLDiagInv' ,
1492+ 'ten_J' ,
1493+ 'flexedge_J' ,
14931494 ):
14941495 continue
14951496 if field .name .startswith ('efc_' ):
Original file line number Diff line number Diff line change @@ -744,6 +744,22 @@ def test_get_data_into(self, impl):
744744 self .assertEqual (d_2 .contact .frame .shape , (1 , 9 ))
745745 np .testing .assert_allclose (d_2 .contact .frame , d .contact .frame )
746746
747+ def test_get_data_into_warp (self ):
748+ """Test get_data_into for impl='warp'."""
749+
750+ # TODO(taylorhowell): After put_data supports impl='warp' update test above
751+ # and remove this test.
752+ if not mjxw .WARP_INSTALLED :
753+ self .skipTest ('Warp is not installed.' )
754+ if not mjx_io .has_cuda_gpu_device ():
755+ self .skipTest ('No CUDA GPU device.' )
756+
757+ m = mujoco .MjModel .from_xml_string ('<mujoco></mujoco>' )
758+ d = mujoco .MjData (m )
759+ mx = mjx .put_model (m , impl = 'warp' )
760+ dx = mjx .make_data (m , impl = 'warp' )
761+ mjx .get_data_into (d , mx , dx )
762+
747763 @parameterized .parameters ('jax' , 'c' )
748764 def test_get_data_into_wrong_shape (self , impl ):
749765 """Tests that get_data_into throwsif input and output shapes don't match."""
You can’t perform that action at this time.
0 commit comments