File tree Expand file tree Collapse file tree 2 files changed +11
-1
lines changed Expand file tree Collapse file tree 2 files changed +11
-1
lines changed Original file line number Diff line number Diff line change @@ -681,4 +681,5 @@ def put_data(
681
681
# copy because device_put is async:
682
682
data = types .Data (** {k : copy .copy (v ) for k , v in fields .items ()})
683
683
684
- return jax .device_put (data , device = device )
684
+ data = jax .device_put (data , device = device )
685
+ return _strip_weak_type (data )
Original file line number Diff line number Diff line change @@ -310,6 +310,15 @@ def test_put_data(self):
310
310
np .testing .assert_allclose (dx .cvel , d .cvel )
311
311
np .testing .assert_allclose (dx .cdof_dot , d .cdof_dot )
312
312
313
+ # check that there are no weak types
314
+ self .assertFalse (
315
+ any (
316
+ jax .tree_util .tree_flatten (
317
+ jax .tree_util .tree_map (lambda x : x .weak_type , dx )
318
+ )[0 ]
319
+ )
320
+ )
321
+
313
322
# check that qM is transformed properly
314
323
qm = np .zeros ((m .nv , m .nv ), dtype = np .float64 )
315
324
mujoco .mj_fullM (m , qm , d .qM )
You can’t perform that action at this time.
0 commit comments