@@ -318,6 +318,21 @@ def _make_spec( # noqa: F811
318
318
shape = batch_size ,
319
319
)
320
320
321
+ @implement_for ("gymnasium" , "1.1.0" )
322
+ def _make_spec ( # noqa: F811
323
+ self , batch_size , cat , cat_shape , multicat , multicat_shape
324
+ ):
325
+ return Composite (
326
+ a = Unbounded (shape = (* batch_size , 1 )),
327
+ b = Composite (c = cat (5 , shape = cat_shape , dtype = torch .int64 ), shape = batch_size ),
328
+ d = cat (5 , shape = cat_shape , dtype = torch .int64 ),
329
+ e = multicat ([2 , 3 ], shape = (* batch_size , multicat_shape ), dtype = torch .int64 ),
330
+ f = Bounded (- 3 , 4 , shape = (* batch_size , 1 )),
331
+ g = UnboundedDiscreteTensorSpec (shape = (* batch_size , 1 ), dtype = torch .long ),
332
+ h = Binary (n = 5 , shape = (* batch_size , 5 )),
333
+ shape = batch_size ,
334
+ )
335
+
321
336
@pytest .mark .parametrize ("categorical" , [True , False ])
322
337
def test_gym_spec_cast (self , categorical ):
323
338
batch_size = [3 , 4 ]
@@ -379,10 +394,17 @@ def test_gym_spec_cast_tuple_sequential(self, order):
379
394
torchrl_logger .info ("Sequence not available in gym" )
380
395
return
381
396
382
- # @pytest.mark.parametrize("order", ["seq_tuple", "tuple_seq"])
397
+ @pytest .mark .parametrize ("order" , ["tuple_seq" ])
398
+ @implement_for ("gymnasium" , "1.1.0" )
399
+ def test_gym_spec_cast_tuple_sequential (self , order ): # noqa: F811
400
+ self ._test_gym_spec_cast_tuple_sequential (order )
401
+
383
402
@pytest .mark .parametrize ("order" , ["tuple_seq" ])
384
403
@implement_for ("gymnasium" , None , "1.0.0" )
385
404
def test_gym_spec_cast_tuple_sequential (self , order ): # noqa: F811
405
+ self ._test_gym_spec_cast_tuple_sequential (order )
406
+
407
+ def _test_gym_spec_cast_tuple_sequential (self , order ): # noqa: F811
386
408
with set_gym_backend ("gymnasium" ):
387
409
if order == "seq_tuple" :
388
410
# Requires nested tensors to be created along dim=1, disabling
@@ -974,8 +996,15 @@ def info_reader(info, tensordict):
974
996
finally :
975
997
set_gym_backend (gb ).set ()
976
998
977
- @implement_for ("gymnasium" , None , "1.0 .0" )
999
+ @implement_for ("gymnasium" , "1.1 .0" )
978
1000
def test_one_hot_and_categorical (self ):
1001
+ self ._test_one_hot_and_categorical ()
1002
+
1003
+ @implement_for ("gymnasium" , None , "1.0.0" )
1004
+ def test_one_hot_and_categorical (self ): # noqa
1005
+ self ._test_one_hot_and_categorical ()
1006
+
1007
+ def _test_one_hot_and_categorical (self ):
979
1008
# tests that one-hot and categorical work ok when an integer is expected as action
980
1009
cliff_walking = GymEnv ("CliffWalking-v0" , categorical_action_encoding = True )
981
1010
cliff_walking .rollout (10 )
@@ -993,14 +1022,27 @@ def test_one_hot_and_categorical(self): # noqa: F811
993
1022
# versions.
994
1023
return
995
1024
996
- @implement_for ("gymnasium" , None , "1.0 .0" )
1025
+ @implement_for ("gymnasium" , "1.1 .0" )
997
1026
@pytest .mark .parametrize (
998
1027
"envname" ,
999
1028
["HalfCheetah-v4" , "CartPole-v1" , "ALE/Pong-v5" ]
1000
1029
+ (["FetchReach-v2" ] if _has_gym_robotics else []),
1001
1030
)
1002
1031
@pytest .mark .flaky (reruns = 5 , reruns_delay = 1 )
1003
1032
def test_vecenvs_wrapper (self , envname ):
1033
+ self ._test_vecenvs_wrapper (envname )
1034
+
1035
+ @implement_for ("gymnasium" , None , "1.0.0" )
1036
+ @pytest .mark .parametrize (
1037
+ "envname" ,
1038
+ ["HalfCheetah-v4" , "CartPole-v1" , "ALE/Pong-v5" ]
1039
+ + (["FetchReach-v2" ] if _has_gym_robotics else []),
1040
+ )
1041
+ @pytest .mark .flaky (reruns = 5 , reruns_delay = 1 )
1042
+ def test_vecenvs_wrapper (self , envname ): # noqa
1043
+ self ._test_vecenvs_wrapper (envname )
1044
+
1045
+ def _test_vecenvs_wrapper (self , envname ):
1004
1046
import gymnasium
1005
1047
1006
1048
# we can't use parametrize with implement_for
@@ -1019,7 +1061,7 @@ def test_vecenvs_wrapper(self, envname):
1019
1061
assert env .batch_size == torch .Size ([2 ])
1020
1062
check_env_specs (env )
1021
1063
1022
- @implement_for ("gymnasium" , None , "1.0 .0" )
1064
+ @implement_for ("gymnasium" , "1.1 .0" )
1023
1065
# this env has Dict-based observation which is a nice thing to test
1024
1066
@pytest .mark .parametrize (
1025
1067
"envname" ,
@@ -1028,6 +1070,21 @@ def test_vecenvs_wrapper(self, envname):
1028
1070
)
1029
1071
@pytest .mark .flaky (reruns = 5 , reruns_delay = 1 )
1030
1072
def test_vecenvs_env (self , envname ):
1073
+ self ._test_vecenvs_env (envname )
1074
+
1075
+ @implement_for ("gymnasium" , None , "1.0.0" )
1076
+ # this env has Dict-based observation which is a nice thing to test
1077
+ @pytest .mark .parametrize (
1078
+ "envname" ,
1079
+ ["HalfCheetah-v4" , "CartPole-v1" , "ALE/Pong-v5" ]
1080
+ + (["FetchReach-v2" ] if _has_gym_robotics else []),
1081
+ )
1082
+ @pytest .mark .flaky (reruns = 5 , reruns_delay = 1 )
1083
+ def test_vecenvs_env (self , envname ): # noqa
1084
+ self ._test_vecenvs_env (envname )
1085
+
1086
+ def _test_vecenvs_env (self , envname ):
1087
+
1031
1088
gb = gym_backend ()
1032
1089
try :
1033
1090
with set_gym_backend ("gymnasium" ):
@@ -1181,9 +1238,17 @@ def test_gym_output_num(self, wrapper): # noqa: F811
1181
1238
finally :
1182
1239
set_gym_backend (gym ).set ()
1183
1240
1241
+ @implement_for ("gymnasium" , "1.1.0" )
1242
+ @pytest .mark .parametrize ("wrapper" , [True , False ])
1243
+ def test_gym_output_num (self , wrapper ): # noqa: F811
1244
+ self ._test_gym_output_num (wrapper )
1245
+
1184
1246
@implement_for ("gymnasium" , None , "1.0.0" )
1185
1247
@pytest .mark .parametrize ("wrapper" , [True , False ])
1186
1248
def test_gym_output_num (self , wrapper ): # noqa: F811
1249
+ self ._test_gym_output_num (wrapper )
1250
+
1251
+ def _test_gym_output_num (self , wrapper ): # noqa: F811
1187
1252
# gym has 5 outputs, with truncation
1188
1253
gym = gym_backend ()
1189
1254
try :
@@ -1284,8 +1349,15 @@ def test_vecenvs_nan(self): # noqa: F811
1284
1349
del c
1285
1350
return
1286
1351
1352
+ @implement_for ("gymnasium" , "1.1.0" )
1353
+ def test_vecenvs_nan (self ): # noqa: F811
1354
+ self ._test_vecenvs_nan ()
1355
+
1287
1356
@implement_for ("gymnasium" , None , "1.0.0" )
1288
1357
def test_vecenvs_nan (self ): # noqa: F811
1358
+ self ._test_vecenvs_nan ()
1359
+
1360
+ def _test_vecenvs_nan (self ): # noqa: F811
1289
1361
# new versions of gym must never return nan for next values when there is a done state
1290
1362
torch .manual_seed (0 )
1291
1363
env = GymEnv ("CartPole-v1" , num_envs = 2 )
@@ -1352,8 +1424,118 @@ def step(self, action):
1352
1424
1353
1425
return CustomEnv (** kwargs )
1354
1426
1427
+ def counting_env (self ):
1428
+ import gymnasium as gym
1429
+ from gymnasium import Env
1430
+
1431
+ class CountingEnvRandomReset (Env ):
1432
+ def __init__ (self , i = 0 ):
1433
+ self .counter = 1
1434
+ self .i = i
1435
+ self .observation_space = gym .spaces .Box (- np .inf , np .inf , shape = (1 ,))
1436
+ self .action_space = gym .spaces .Box (- np .inf , np .inf , shape = (1 ,))
1437
+ self .rng = np .random .RandomState (0 )
1438
+
1439
+ def step (self , action ):
1440
+ self .counter += 1
1441
+ done = bool (self .rng .random () < 0.05 )
1442
+ return (
1443
+ np .asarray (
1444
+ [
1445
+ self .counter ,
1446
+ ]
1447
+ ),
1448
+ 0 ,
1449
+ done ,
1450
+ done ,
1451
+ {},
1452
+ )
1453
+
1454
+ def reset (
1455
+ self ,
1456
+ * ,
1457
+ seed : int | None = None ,
1458
+ options = None ,
1459
+ ):
1460
+ self .counter = 1
1461
+ if seed is not None :
1462
+ self .rng = np .random .RandomState (seed )
1463
+ return (
1464
+ np .asarray (
1465
+ [
1466
+ self .counter ,
1467
+ ]
1468
+ ),
1469
+ {},
1470
+ )
1471
+
1472
+ yield CountingEnvRandomReset
1473
+
1474
+ @implement_for ("gym" )
1475
+ def test_gymnasium_autoreset (self , venv ):
1476
+ return
1477
+
1478
+ @implement_for ("gymnasium" , None , "1.1.0" )
1479
+ def test_gymnasium_autoreset (self , venv ): # noqa
1480
+ return
1481
+
1482
+ @implement_for ("gymnasium" , "1.1.0" )
1483
+ @pytest .mark .parametrize ("venv" , ["sync" , "async" ])
1484
+ def test_gymnasium_autoreset (self , venv ): # noqa
1485
+ import gymnasium as gym
1486
+
1487
+ counting_env = self .counting_env ()
1488
+ if venv == "sync" :
1489
+ venv = gym .vector .SyncVectorEnv
1490
+ else :
1491
+ venv = gym .vector .AsyncVectorEnv
1492
+ envs0 = venv (
1493
+ [lambda i = i : counting_env (i ) for i in range (2 )],
1494
+ autoreset_mode = gym .vector .AutoresetMode .DISABLED ,
1495
+ )
1496
+ env = GymWrapper (envs0 )
1497
+ envs0 .reset (seed = 0 )
1498
+ torch .manual_seed (0 )
1499
+ r0 = env .rollout (20 , break_when_any_done = False )
1500
+ envs1 = venv (
1501
+ [lambda i = i : counting_env (i ) for i in range (2 )],
1502
+ autoreset_mode = gym .vector .AutoresetMode .SAME_STEP ,
1503
+ )
1504
+ env = GymWrapper (envs1 )
1505
+ envs1 .reset (seed = 0 )
1506
+ # env.set_seed(0)
1507
+ torch .manual_seed (0 )
1508
+ r1 = []
1509
+ t_ = env .reset ()
1510
+ for s in r0 .unbind (- 1 ):
1511
+ t_ .set ("action" , s ["action" ])
1512
+ t , t_ = env .step_and_maybe_reset (t_ )
1513
+ r1 .append (t )
1514
+ r1 = torch .stack (r1 , - 1 )
1515
+ torch .testing .assert_close (r0 ["observation" ], r1 ["observation" ])
1516
+ torch .testing .assert_close (r0 ["next" , "observation" ], r1 ["next" , "observation" ])
1517
+ torch .testing .assert_close (r0 ["next" , "done" ], r1 ["next" , "done" ])
1518
+
1519
+ @implement_for ("gym" )
1355
1520
@pytest .mark .parametrize ("heterogeneous" , [False , True ])
1356
1521
def test_resetting_strategies (self , heterogeneous ):
1522
+ return
1523
+
1524
+ @implement_for ("gymnasium" , None , "1.0.0" )
1525
+ @pytest .mark .parametrize ("heterogeneous" , [False , True ])
1526
+ def test_resetting_strategies (self , heterogeneous ): # noqa
1527
+ self ._test_resetting_strategies (heterogeneous , {})
1528
+
1529
+ @implement_for ("gymnasium" , "1.1.0" )
1530
+ @pytest .mark .parametrize ("heterogeneous" , [False , True ])
1531
+ def test_resetting_strategies (self , heterogeneous ): # noqa
1532
+ import gymnasium as gym
1533
+
1534
+ self ._test_resetting_strategies (
1535
+ heterogeneous , {"autoreset_mode" : gym .vector .AutoresetMode .SAME_STEP }
1536
+ )
1537
+
1538
+ def _test_resetting_strategies (self , heterogeneous , kwargs ):
1357
1539
if _has_gymnasium :
1358
1540
backend = "gymnasium"
1359
1541
else :
@@ -1369,7 +1551,8 @@ def test_resetting_strategies(self, heterogeneous):
1369
1551
env = GymWrapper (
1370
1552
gym_backend ().vector .AsyncVectorEnv (
1371
1553
[functools .partial (self ._get_dummy_gym_env , backend = backend )]
1372
- * 4
1554
+ * 4 ,
1555
+ ** kwargs ,
1373
1556
)
1374
1557
)
1375
1558
else :
@@ -1382,7 +1565,8 @@ def test_resetting_strategies(self, heterogeneous):
1382
1565
backend = backend ,
1383
1566
)
1384
1567
for i in range (4 )
1385
- ]
1568
+ ],
1569
+ ** kwargs ,
1386
1570
)
1387
1571
)
1388
1572
try :
@@ -1461,6 +1645,12 @@ def _make_gym_environment(env_name): # noqa: F811
1461
1645
return gym .make (env_name , render_mode = "rgb_array" )
1462
1646
1463
1647
1648
+ @implement_for ("gymnasium" , "1.1.0" )
1649
+ def _make_gym_environment (env_name ): # noqa: F811
1650
+ gym = gym_backend ()
1651
+ return gym .make (env_name , render_mode = "rgb_array" )
1652
+
1653
+
1464
1654
@pytest .mark .skipif (not _has_dmc , reason = "no dm_control library found" )
1465
1655
class TestDMControl :
1466
1656
@pytest .mark .parametrize ("env_name,task" , [["cheetah" , "run" ]])
0 commit comments