@@ -1471,12 +1471,29 @@ def make_env():
1471
1471
"transformed_in,transformed_out" , [[True , True ], [False , False ]]
1472
1472
) # 1226: effociency
1473
1473
@pytest .mark .parametrize ("static_seed" , [False , True ])
1474
+ @pytest .mark .parametrize ("penv_device" , ["cpu" , None ])
1475
+ @pytest .mark .parametrize ("env_device" , ["cpu" , None ])
1476
+ @pytest .mark .parametrize ("bwad" , [True , False ])
1474
1477
def test_parallel_env_seed (
1475
- self , env_name , frame_skip , transformed_in , transformed_out , static_seed
1478
+ self ,
1479
+ env_name ,
1480
+ frame_skip ,
1481
+ transformed_in ,
1482
+ transformed_out ,
1483
+ static_seed ,
1484
+ penv_device ,
1485
+ env_device ,
1486
+ bwad ,
1476
1487
):
1477
1488
env_name = env_name ()
1478
1489
env_parallel , env_serial , _ , _ = _make_envs (
1479
- env_name , frame_skip , transformed_in , transformed_out , 5
1490
+ env_name ,
1491
+ frame_skip ,
1492
+ transformed_in ,
1493
+ transformed_out ,
1494
+ 5 ,
1495
+ p_env_device = penv_device ,
1496
+ env_device = env_device ,
1480
1497
)
1481
1498
try :
1482
1499
out_seed_serial = env_serial .set_seed (0 , static_seed = static_seed )
@@ -1486,7 +1503,10 @@ def test_parallel_env_seed(
1486
1503
torch .manual_seed (0 )
1487
1504
1488
1505
td_serial = env_serial .rollout (
1489
- max_steps = 10 , auto_reset = False , tensordict = td0_serial
1506
+ max_steps = 10 ,
1507
+ auto_reset = False ,
1508
+ tensordict = td0_serial ,
1509
+ break_when_any_done = bwad ,
1490
1510
).contiguous ()
1491
1511
key = "pixels" if "pixels" in td_serial .keys () else "observation"
1492
1512
torch .testing .assert_close (
@@ -1501,7 +1521,10 @@ def test_parallel_env_seed(
1501
1521
torch .manual_seed (0 )
1502
1522
assert out_seed_parallel == out_seed_serial
1503
1523
td_parallel = env_parallel .rollout (
1504
- max_steps = 10 , auto_reset = False , tensordict = td0_parallel
1524
+ max_steps = 10 ,
1525
+ auto_reset = False ,
1526
+ tensordict = td0_parallel ,
1527
+ break_when_any_done = bwad ,
1505
1528
).contiguous ()
1506
1529
torch .testing .assert_close (
1507
1530
td_parallel [:, :- 1 ].get (("next" , key )), td_parallel [:, 1 :].get (key )
@@ -1677,7 +1700,7 @@ def test_parallel_env_device(
1677
1700
frame_skip ,
1678
1701
transformed_in = transformed_in ,
1679
1702
transformed_out = transformed_out ,
1680
- device = device ,
1703
+ env_device = device ,
1681
1704
N = N ,
1682
1705
local_mp_ctx = "spawn" ,
1683
1706
)
0 commit comments