Skip to content

Commit 78cd755

Browse files
author
Vincent Moens
committed
[Feature] Gymnasium 1.1 compatibility
ghstack-source-id: e089186 Pull Request resolved: #2898
1 parent 8ce11a8 commit 78cd755

File tree

10 files changed

+506
-49
lines changed

10 files changed

+506
-49
lines changed

.github/unittest/linux/scripts/run_all.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ echo "installing gymnasium"
9797
if [[ "$PYTHON_VERSION" == "3.12" ]]; then
9898
pip3 install ale-py
9999
pip3 install sympy
100-
pip3 install "gymnasium[accept-rom-license,mujoco]<1.0" mo-gymnasium[mujoco]
100+
pip3 install "gymnasium[accept-rom-license,mujoco]>=1.1" mo-gymnasium[mujoco]
101101
else
102-
pip3 install "gymnasium[atari,accept-rom-license,mujoco]<1.0" mo-gymnasium[mujoco]
102+
pip3 install "gymnasium[atari,accept-rom-license,mujoco]>=1.1" mo-gymnasium[mujoco]
103103
fi
104104
pip3 install "mujoco" -U
105105

.github/unittest/linux_distributed/scripts/setup_env.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ if [[ $OSTYPE != 'darwin'* ]]; then
121121
rm ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
122122
fi
123123
echo "installing gymnasium"
124-
pip install "gymnasium[atari,accept-rom-license]<1.0"
124+
pip install "gymnasium[atari,accept-rom-license]>=1.1"
125125
else
126-
pip install "gymnasium[atari,accept-rom-license]<1.0"
126+
pip install "gymnasium[atari,accept-rom-license]>=1.1"
127127
fi

.github/unittest/linux_libs/scripts_gym/batch_scripts.sh

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,27 @@ do
135135
conda env remove --prefix ./cloned_env -y
136136
done
137137

138+
# Prev gymnasium
139+
conda deactivate
140+
conda create --prefix ./cloned_env --clone ./env -y
141+
conda activate ./cloned_env
142+
143+
pip3 install 'gymnasium[accept-rom-license,ale-py,atari]>=1.1.0' mo-gymnasium gymnasium-robotics -U
144+
145+
$DIR/run_test.sh
146+
147+
# delete the conda copy
148+
conda deactivate
149+
conda env remove --prefix ./cloned_env -y
150+
151+
# Skip 1.0.0
152+
138153
# Latest gymnasium
139154
conda deactivate
140155
conda create --prefix ./cloned_env --clone ./env -y
141156
conda activate ./cloned_env
142157

143-
pip3 install 'gymnasium[accept-rom-license,ale-py,atari]<1.0' mo-gymnasium gymnasium-robotics -U
158+
pip3 install 'gymnasium[accept-rom-license,ale-py,atari]>=1.1.0' mo-gymnasium gymnasium-robotics -U
144159

145160
$DIR/run_test.sh
146161

.github/unittest/linux_sota/scripts/run_all.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ python -c """import gym;import d4rl"""
112112

113113
# install ale-py: manylinux names are broken for CentOS so we need to manually download and
114114
# rename them
115-
pip install "gymnasium[atari,accept-rom-license]<1.0"
115+
pip install "gymnasium[atari,accept-rom-license]>=1.1.0"
116116

117117
# ============================================================================================ #
118118
# ================================ PyTorch & TorchRL ========================================= #

test/_utils_internal.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,22 @@ def _set_gym_environments(): # noqa: F811
146146
_BREAKOUT_VERSIONED = "ALE/Breakout-v5"
147147

148148

149-
@implement_for("gymnasium", "1.0.0", None)
149+
@implement_for("gymnasium", "1.0.0", "1.1.0")
def _set_gym_environments():  # noqa: F811
    """Reject the gymnasium version range registered on this overload.

    This overload is registered with ``implement_for`` for the
    ``"1.0.0"``-``"1.1.0"`` gymnasium range and never sets any of the
    versioned environment names.

    Raises:
        ImportError: always.
    """
    raise ImportError(
        "gymnasium>=1.0.0,<1.1.0 is not supported by these test utilities; "
        "install gymnasium<1.0.0 or gymnasium>=1.1.0."
    )
152152

153153

154+
@implement_for("gymnasium", "1.1.0")
def _set_gym_environments():  # noqa: F811
    """Set the module-level versioned environment ids.

    Overload registered with ``implement_for`` for gymnasium starting at
    ``"1.1.0"``. It rebinds the five ``_*_VERSIONED`` module globals.
    """
    global _CARTPOLE_VERSIONED
    global _HALFCHEETAH_VERSIONED
    global _PENDULUM_VERSIONED
    global _PONG_VERSIONED
    global _BREAKOUT_VERSIONED

    (
        _CARTPOLE_VERSIONED,
        _HALFCHEETAH_VERSIONED,
        _PENDULUM_VERSIONED,
        _PONG_VERSIONED,
        _BREAKOUT_VERSIONED,
    ) = (
        "CartPole-v1",
        "HalfCheetah-v4",
        "Pendulum-v1",
        "ALE/Pong-v5",
        "ALE/Breakout-v5",
    )
163+
164+
154165
if _has_gym:
155166
_set_gym_environments()
156167

test/test_libs.py

Lines changed: 196 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,21 @@ def _make_spec( # noqa: F811
318318
shape=batch_size,
319319
)
320320

321+
@implement_for("gymnasium", "1.1.0")
def _make_spec(  # noqa: F811
    self, batch_size, cat, cat_shape, multicat, multicat_shape
):
    """Build the composite spec used by the spec-cast tests.

    Args:
        batch_size: batch dimensions shared by every entry.
        cat: callable building a categorical-like spec; called as
            ``cat(5, shape=cat_shape, dtype=torch.int64)``.
        cat_shape: shape passed to ``cat``.
        multicat: callable building a multi-categorical-like spec; called
            with ``[2, 3]`` as first argument.
        multicat_shape: trailing dimension appended to ``batch_size`` for
            the ``multicat`` entry.

    Returns:
        A ``Composite`` with entries ``a`` to ``h`` and shape ``batch_size``.
    """
    # Shape shared by all entries that carry a single trailing element.
    unit_shape = (*batch_size, 1)
    entries = {
        "a": Unbounded(shape=unit_shape),
        "b": Composite(
            c=cat(5, shape=cat_shape, dtype=torch.int64),
            shape=batch_size,
        ),
        "d": cat(5, shape=cat_shape, dtype=torch.int64),
        "e": multicat(
            [2, 3],
            shape=(*batch_size, multicat_shape),
            dtype=torch.int64,
        ),
        "f": Bounded(-3, 4, shape=unit_shape),
        "g": UnboundedDiscreteTensorSpec(shape=unit_shape, dtype=torch.long),
        "h": Binary(n=5, shape=(*batch_size, 5)),
    }
    return Composite(**entries, shape=batch_size)
335+
321336
@pytest.mark.parametrize("categorical", [True, False])
322337
def test_gym_spec_cast(self, categorical):
323338
batch_size = [3, 4]
@@ -379,10 +394,17 @@ def test_gym_spec_cast_tuple_sequential(self, order):
379394
torchrl_logger.info("Sequence not available in gym")
380395
return
381396

382-
# @pytest.mark.parametrize("order", ["seq_tuple", "tuple_seq"])
397+
@pytest.mark.parametrize("order", ["tuple_seq"])
@implement_for("gymnasium", "1.1.0")
def test_gym_spec_cast_tuple_sequential(self, order):  # noqa: F811
    """Overload registered for gymnasium from ``"1.1.0"``.

    Delegates to the shared ``_test_gym_spec_cast_tuple_sequential`` body.
    """
    self._test_gym_spec_cast_tuple_sequential(order=order)
401+
383402
@pytest.mark.parametrize("order", ["tuple_seq"])
@implement_for("gymnasium", None, "1.0.0")
def test_gym_spec_cast_tuple_sequential(self, order):  # noqa: F811
    """Overload registered for gymnasium up to ``"1.0.0"``.

    Delegates to the shared ``_test_gym_spec_cast_tuple_sequential`` body.
    """
    self._test_gym_spec_cast_tuple_sequential(order=order)
406+
407+
def _test_gym_spec_cast_tuple_sequential(self, order): # noqa: F811
386408
with set_gym_backend("gymnasium"):
387409
if order == "seq_tuple":
388410
# Requires nested tensors to be created along dim=1, disabling
@@ -974,8 +996,15 @@ def info_reader(info, tensordict):
974996
finally:
975997
set_gym_backend(gb).set()
976998

977-
@implement_for("gymnasium", None, "1.0.0")
999+
@implement_for("gymnasium", "1.1.0")
def test_one_hot_and_categorical(self):
    """Overload registered for gymnasium from ``"1.1.0"``.

    Runs the shared ``_test_one_hot_and_categorical`` body.
    """
    self._test_one_hot_and_categorical()
1002+
1003+
@implement_for("gymnasium", None, "1.0.0")
def test_one_hot_and_categorical(self):  # noqa
    """Overload registered for gymnasium up to ``"1.0.0"``.

    Runs the shared ``_test_one_hot_and_categorical`` body.
    """
    self._test_one_hot_and_categorical()
1006+
1007+
def _test_one_hot_and_categorical(self):
9791008
# tests that one-hot and categorical work ok when an integer is expected as action
9801009
cliff_walking = GymEnv("CliffWalking-v0", categorical_action_encoding=True)
9811010
cliff_walking.rollout(10)
@@ -993,14 +1022,27 @@ def test_one_hot_and_categorical(self): # noqa: F811
9931022
# versions.
9941023
return
9951024

996-
@implement_for("gymnasium", None, "1.0.0")
1025+
@implement_for("gymnasium", "1.1.0")
@pytest.mark.parametrize(
    "envname",
    [
        "HalfCheetah-v4",
        "CartPole-v1",
        "ALE/Pong-v5",
        # FetchReach is only parametrized when gymnasium-robotics is available.
        *(["FetchReach-v2"] if _has_gym_robotics else []),
    ],
)
@pytest.mark.flaky(reruns=5, reruns_delay=1)
def test_vecenvs_wrapper(self, envname):
    """Overload registered for gymnasium from ``"1.1.0"``.

    Runs the shared ``_test_vecenvs_wrapper`` body for ``envname``.
    """
    self._test_vecenvs_wrapper(envname=envname)
1034+
1035+
@implement_for("gymnasium", None, "1.0.0")
@pytest.mark.parametrize(
    "envname",
    [
        "HalfCheetah-v4",
        "CartPole-v1",
        "ALE/Pong-v5",
        # FetchReach is only parametrized when gymnasium-robotics is available.
        *(["FetchReach-v2"] if _has_gym_robotics else []),
    ],
)
@pytest.mark.flaky(reruns=5, reruns_delay=1)
def test_vecenvs_wrapper(self, envname):  # noqa
    """Overload registered for gymnasium up to ``"1.0.0"``.

    Runs the shared ``_test_vecenvs_wrapper`` body for ``envname``.
    """
    self._test_vecenvs_wrapper(envname=envname)
1044+
1045+
def _test_vecenvs_wrapper(self, envname):
10041046
import gymnasium
10051047

10061048
# we can't use parametrize with implement_for
@@ -1019,7 +1061,7 @@ def test_vecenvs_wrapper(self, envname):
10191061
assert env.batch_size == torch.Size([2])
10201062
check_env_specs(env)
10211063

1022-
@implement_for("gymnasium", None, "1.0.0")
1064+
@implement_for("gymnasium", "1.1.0")
10231065
# this env has Dict-based observation which is a nice thing to test
10241066
@pytest.mark.parametrize(
10251067
"envname",
@@ -1028,6 +1070,21 @@ def test_vecenvs_wrapper(self, envname):
10281070
)
10291071
@pytest.mark.flaky(reruns=5, reruns_delay=1)
10301072
def test_vecenvs_env(self, envname):
1073+
self._test_vecenvs_env(envname)
1074+
1075+
@implement_for("gymnasium", None, "1.0.0")
# this env has Dict-based observation which is a nice thing to test
@pytest.mark.parametrize(
    "envname",
    [
        "HalfCheetah-v4",
        "CartPole-v1",
        "ALE/Pong-v5",
        # FetchReach is only parametrized when gymnasium-robotics is available.
        *(["FetchReach-v2"] if _has_gym_robotics else []),
    ],
)
@pytest.mark.flaky(reruns=5, reruns_delay=1)
def test_vecenvs_env(self, envname):  # noqa
    """Overload registered for gymnasium up to ``"1.0.0"``.

    Runs the shared ``_test_vecenvs_env`` body for ``envname``.
    """
    self._test_vecenvs_env(envname=envname)
1085+
1086+
def _test_vecenvs_env(self, envname):
1087+
10311088
gb = gym_backend()
10321089
try:
10331090
with set_gym_backend("gymnasium"):
@@ -1181,9 +1238,17 @@ def test_gym_output_num(self, wrapper): # noqa: F811
11811238
finally:
11821239
set_gym_backend(gym).set()
11831240

1241+
@implement_for("gymnasium", "1.1.0")
@pytest.mark.parametrize("wrapper", [True, False])
def test_gym_output_num(self, wrapper):  # noqa: F811
    """Overload registered for gymnasium from ``"1.1.0"``.

    Runs the shared ``_test_gym_output_num`` body with ``wrapper``.
    """
    self._test_gym_output_num(wrapper=wrapper)
1245+
11841246
@implement_for("gymnasium", None, "1.0.0")
@pytest.mark.parametrize("wrapper", [True, False])
def test_gym_output_num(self, wrapper):  # noqa: F811
    """Overload registered for gymnasium up to ``"1.0.0"``.

    Runs the shared ``_test_gym_output_num`` body with ``wrapper``.
    """
    self._test_gym_output_num(wrapper=wrapper)
1250+
1251+
def _test_gym_output_num(self, wrapper): # noqa: F811
11871252
# gym has 5 outputs, with truncation
11881253
gym = gym_backend()
11891254
try:
@@ -1284,8 +1349,15 @@ def test_vecenvs_nan(self): # noqa: F811
12841349
del c
12851350
return
12861351

1352+
@implement_for("gymnasium", "1.1.0")
def test_vecenvs_nan(self):  # noqa: F811
    """Overload registered for gymnasium from ``"1.1.0"``.

    Runs the shared ``_test_vecenvs_nan`` body.
    """
    self._test_vecenvs_nan()
1355+
12871356
@implement_for("gymnasium", None, "1.0.0")
def test_vecenvs_nan(self):  # noqa: F811
    """Overload registered for gymnasium up to ``"1.0.0"``.

    Runs the shared ``_test_vecenvs_nan`` body.
    """
    self._test_vecenvs_nan()
1359+
1360+
def _test_vecenvs_nan(self): # noqa: F811
12891361
# new versions of gym must never return nan for next values when there is a done state
12901362
torch.manual_seed(0)
12911363
env = GymEnv("CartPole-v1", num_envs=2)
@@ -1352,8 +1424,118 @@ def step(self, action):
13521424

13531425
return CustomEnv(**kwargs)
13541426

1427+
def counting_env(self):
    """Return a gymnasium ``Env`` subclass whose observation is a step counter.

    The class (not an instance) is returned so that callers can build
    several environments from it, e.g. ``counting_env(i)``. It used to be
    ``yield``-ed, which turned this method into a generator function:
    calling ``self.counting_env()`` then produced a generator object that
    cannot be called to instantiate an environment.

    Returns:
        type: the ``CountingEnvRandomReset`` class.
    """
    import gymnasium as gym
    from gymnasium import Env

    class CountingEnvRandomReset(Env):
        """Env whose observation counts steps and restarts at 1 on reset."""

        def __init__(self, i=0):
            self.counter = 1
            self.i = i
            self.observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(1,))
            self.action_space = gym.spaces.Box(-np.inf, np.inf, shape=(1,))
            self.rng = np.random.RandomState(0)

        def step(self, action):
            self.counter += 1
            # Each step ends the episode with probability 0.05. The same flag
            # fills both the third and fourth slots of the step tuple.
            done = bool(self.rng.random() < 0.05)
            return (
                np.asarray(
                    [
                        self.counter,
                    ]
                ),
                0,
                done,
                done,
                {},
            )

        def reset(
            self,
            *,
            seed: int | None = None,
            options=None,
        ):
            self.counter = 1
            # Re-seed the termination RNG only when a seed is given.
            if seed is not None:
                self.rng = np.random.RandomState(seed)
            return (
                np.asarray(
                    [
                        self.counter,
                    ]
                ),
                {},
            )

    return CountingEnvRandomReset
1473+
1474+
@implement_for("gym")
@pytest.mark.parametrize("venv", ["sync", "async"])
def test_gymnasium_autoreset(self, venv):
    """No-op overload registered for the ``gym`` backend.

    ``venv`` is parametrized with the same values as the gymnasium 1.1.0
    overload so that pytest does not try to resolve it as a fixture.
    """
    return
1477+
1478+
@implement_for("gymnasium", None, "1.1.0")
@pytest.mark.parametrize("venv", ["sync", "async"])
def test_gymnasium_autoreset(self, venv):  # noqa
    """No-op overload registered for gymnasium up to ``"1.1.0"``.

    ``venv`` is parametrized with the same values as the gymnasium 1.1.0
    overload so that pytest does not try to resolve it as a fixture.
    """
    return
1481+
1482+
@implement_for("gymnasium", "1.1.0")
@pytest.mark.parametrize("venv", ["sync", "async"])
def test_gymnasium_autoreset(self, venv):  # noqa
    """Compare rollouts under the DISABLED and SAME_STEP autoreset modes.

    A 20-step rollout is collected from a vector env built with
    ``AutoresetMode.DISABLED``. Its actions are then replayed, step by step,
    through a second vector env built with ``AutoresetMode.SAME_STEP``.
    Observations, next observations and next done flags must match.
    """
    import gymnasium as gym

    counting_env = self.counting_env()
    vector_cls = (
        gym.vector.SyncVectorEnv if venv == "sync" else gym.vector.AsyncVectorEnv
    )

    def build(autoreset_mode):
        # Two sub-environments, each constructed with its own index.
        return vector_cls(
            [lambda i=i: counting_env(i) for i in range(2)],
            autoreset_mode=autoreset_mode,
        )

    # Reference rollout with autoreset disabled in the vector env.
    disabled_envs = build(gym.vector.AutoresetMode.DISABLED)
    env = GymWrapper(disabled_envs)
    disabled_envs.reset(seed=0)
    torch.manual_seed(0)
    r0 = env.rollout(20, break_when_any_done=False)

    # Replay the reference actions with same-step autoreset.
    same_step_envs = build(gym.vector.AutoresetMode.SAME_STEP)
    env = GymWrapper(same_step_envs)
    same_step_envs.reset(seed=0)
    torch.manual_seed(0)
    collected = []
    carry = env.reset()
    for reference_step in r0.unbind(-1):
        carry.set("action", reference_step["action"])
        step_out, carry = env.step_and_maybe_reset(carry)
        collected.append(step_out)
    r1 = torch.stack(collected, -1)

    torch.testing.assert_close(r0["observation"], r1["observation"])
    torch.testing.assert_close(r0["next", "observation"], r1["next", "observation"])
    torch.testing.assert_close(r0["next", "done"], r1["next", "done"])
1518+
1519+
@implement_for("gym")
@pytest.mark.parametrize("heterogeneous", [False, True])
def test_resetting_strategies(self, heterogeneous):
    """No-op overload registered for the ``gym`` backend."""
    return None
1523+
1524+
@implement_for("gymnasium", None, "1.0.0")
@pytest.mark.parametrize("heterogeneous", [False, True])
def test_resetting_strategies(self, heterogeneous):  # noqa
    """Overload registered for gymnasium up to ``"1.0.0"``.

    Runs the shared ``_test_resetting_strategies`` body with no extra
    vector-env keyword arguments.
    """
    no_extra_kwargs = {}
    self._test_resetting_strategies(heterogeneous, no_extra_kwargs)
1528+
1529+
@implement_for("gymnasium", "1.1.0")
@pytest.mark.parametrize("heterogeneous", [False, True])
def test_resetting_strategies(self, heterogeneous):  # noqa
    """Overload registered for gymnasium from ``"1.1.0"``.

    Runs the shared ``_test_resetting_strategies`` body, asking for
    ``AutoresetMode.SAME_STEP`` through the extra keyword arguments.
    """
    import gymnasium as gym

    vector_env_kwargs = {"autoreset_mode": gym.vector.AutoresetMode.SAME_STEP}
    self._test_resetting_strategies(heterogeneous, vector_env_kwargs)
1537+
1538+
def _test_resetting_strategies(self, heterogeneous, kwargs):
13571539
if _has_gymnasium:
13581540
backend = "gymnasium"
13591541
else:
@@ -1369,7 +1551,8 @@ def test_resetting_strategies(self, heterogeneous):
13691551
env = GymWrapper(
13701552
gym_backend().vector.AsyncVectorEnv(
13711553
[functools.partial(self._get_dummy_gym_env, backend=backend)]
1372-
* 4
1554+
* 4,
1555+
**kwargs,
13731556
)
13741557
)
13751558
else:
@@ -1382,7 +1565,8 @@ def test_resetting_strategies(self, heterogeneous):
13821565
backend=backend,
13831566
)
13841567
for i in range(4)
1385-
]
1568+
],
1569+
**kwargs,
13861570
)
13871571
)
13881572
try:
@@ -1461,6 +1645,12 @@ def _make_gym_environment(env_name): # noqa: F811
14611645
return gym.make(env_name, render_mode="rgb_array")
14621646

14631647

1648+
@implement_for("gymnasium", "1.1.0")
def _make_gym_environment(env_name):  # noqa: F811
    """Create ``env_name`` through the active gym backend.

    Overload registered for gymnasium from ``"1.1.0"``.

    Args:
        env_name: environment id passed to the backend's ``make``.

    Returns:
        The environment returned by ``make`` with ``render_mode="rgb_array"``.
    """
    backend = gym_backend()
    environment = backend.make(env_name, render_mode="rgb_array")
    return environment
1652+
1653+
14641654
@pytest.mark.skipif(not _has_dmc, reason="no dm_control library found")
14651655
class TestDMControl:
14661656
@pytest.mark.parametrize("env_name,task", [["cheetah", "run"]])

0 commit comments

Comments
 (0)