Skip to content

Commit f247345

Browse files
authored
[BugFix,Test] Update CliffWalking version (#3045)
1 parent ad9c5ee commit f247345

File tree

2 files changed

+27
-7
lines changed

2 files changed

+27
-7
lines changed

test/_utils_internal.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import os
1010
import os.path
1111
import sys
12+
# NOTE: `sys.version_info` is an attribute of the `sys` module imported above,
# not a submodule. `import sys.version_info` raises
# "ModuleNotFoundError: No module named 'sys.version_info'; 'sys' is not a package",
# so no separate import is needed (or possible) for it.
1213
import time
1314
import unittest
1415
import warnings
@@ -53,6 +54,8 @@
5354
else:
5455
mp_ctx = "fork"
5556

57+
# True on any Python 3 interpreter whose minor version is 9 or lower (i.e. 3.9
# *and older*, despite the name).
PYTHON_3_9 = (3, 0) <= sys.version_info[:2] <= (3, 9)
58+
5659

5760
def CARTPOLE_VERSIONED():
5861
# load gym
@@ -82,6 +85,12 @@ def PONG_VERSIONED():
8285
return _PONG_VERSIONED
8386

8487

88+
def CLIFFWALKING_VERSIONED():
    """Return the CliffWalking environment id for the active gym backend.

    Returns:
        The value of the module-level ``_CLIFFWALKING_VERSIONED`` after
        ``_set_gym_environments()`` has been re-run, or ``None`` when
        ``gym_backend()`` reports no backend.
    """
    # NOTE(review): the source rendering lost indentation, so whether the
    # original `return` sat inside or after the `if` cannot be determined.
    # The explicit guard below yields `None` without a backend either way.
    if gym_backend() is None:
        return None
    # load gym
    _set_gym_environments()
    return _CLIFFWALKING_VERSIONED
92+
93+
8594
def BREAKOUT_VERSIONED():
8695
# load gym
8796
# Gymnasium says that the ale_py behavior changes from 1.0
@@ -104,46 +113,50 @@ def PENDULUM_VERSIONED():
104113

105114

106115
def _set_gym_environments():
    """Reset every versioned environment id global to ``None``.

    This is the undecorated definition; the ``implement_for``-decorated
    definitions of the same name that follow it in this file assign
    backend-specific ids instead.
    """
    # Split across two statements only to keep the line length reasonable.
    global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED
    global _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED

    _CARTPOLE_VERSIONED = None
    _HALFCHEETAH_VERSIONED = None
    _PENDULUM_VERSIONED = None
    _PONG_VERSIONED = None
    _BREAKOUT_VERSIONED = None
    _CLIFFWALKING_VERSIONED = None
114124

115125

116126
@implement_for("gym", None, "0.21.0")
def _set_gym_environments():  # noqa: F811
    """Assign the environment ids used with the ``gym`` backend.

    Registered through ``implement_for`` with no lower version bound and
    ``"0.21.0"`` as the upper bound.
    """
    # Split across two statements only to keep the line length reasonable.
    global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED
    global _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED

    _CARTPOLE_VERSIONED = "CartPole-v0"
    _HALFCHEETAH_VERSIONED = "HalfCheetah-v2"
    _PENDULUM_VERSIONED = "Pendulum-v0"
    _PONG_VERSIONED = "Pong-v4"
    _BREAKOUT_VERSIONED = "Breakout-v4"
    _CLIFFWALKING_VERSIONED = "CliffWalking-v0"
125136

126137

127138
@implement_for("gym", "0.21.0", None)
def _set_gym_environments():  # noqa: F811
    """Assign the environment ids used with the ``gym`` backend.

    Registered through ``implement_for`` with ``"0.21.0"`` as the lower
    version bound and no upper bound.
    """
    # Split across two statements only to keep the line length reasonable.
    global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED
    global _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED

    _CARTPOLE_VERSIONED = "CartPole-v1"
    _HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
    _PENDULUM_VERSIONED = "Pendulum-v1"
    _PONG_VERSIONED = "ALE/Pong-v5"
    _BREAKOUT_VERSIONED = "ALE/Breakout-v5"
    _CLIFFWALKING_VERSIONED = "CliffWalking-v0"
136148

137149

138150
@implement_for("gymnasium", None, "1.0.0")
def _set_gym_environments():  # noqa: F811
    """Assign the environment ids used with the ``gymnasium`` backend.

    Registered through ``implement_for`` with no lower version bound and
    ``"1.0.0"`` as the upper bound.
    """
    # Split across two statements only to keep the line length reasonable.
    global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED
    global _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED

    _CARTPOLE_VERSIONED = "CartPole-v1"
    _HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
    _PENDULUM_VERSIONED = "Pendulum-v1"
    _PONG_VERSIONED = "ALE/Pong-v5"
    _BREAKOUT_VERSIONED = "ALE/Breakout-v5"
    _CLIFFWALKING_VERSIONED = "CliffWalking-v0"
147160

148161

149162
@implement_for("gymnasium", "1.0.0", "1.1.0")
@@ -153,13 +166,14 @@ def _set_gym_environments(): # noqa: F811
153166

154167
@implement_for("gymnasium", "1.1.0")
def _set_gym_environments():  # noqa: F811
    """Assign the environment ids used with the ``gymnasium`` backend.

    Registered through ``implement_for`` with ``"1.1.0"`` as the lower
    version bound.
    """
    # Split across two statements only to keep the line length reasonable.
    global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED
    global _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED

    _CARTPOLE_VERSIONED = "CartPole-v1"
    _HALFCHEETAH_VERSIONED = "HalfCheetah-v5"
    _PENDULUM_VERSIONED = "Pendulum-v1"
    _PONG_VERSIONED = "ALE/Pong-v5"
    _BREAKOUT_VERSIONED = "ALE/Breakout-v5"
    # Interpreters flagged by PYTHON_3_9 keep the v0 id; all others use v1.
    # NOTE(review): presumably CliffWalking-v1 is not available in the
    # gymnasium builds installable on those interpreters — TODO confirm.
    _CLIFFWALKING_VERSIONED = "CliffWalking-v0" if PYTHON_3_9 else "CliffWalking-v1"
163177

164178

165179
if _has_gym:

test/test_libs.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
from pytorch.rl.test._utils_internal import (
134134
_make_multithreaded_env,
135135
CARTPOLE_VERSIONED,
136+
CLIFFWALKING_VERSIONED,
136137
get_available_devices,
137138
get_default_devices,
138139
HALFCHEETAH_VERSIONED,
@@ -146,6 +147,7 @@
146147
from _utils_internal import (
147148
_make_multithreaded_env,
148149
CARTPOLE_VERSIONED,
150+
CLIFFWALKING_VERSIONED,
149151
get_available_devices,
150152
get_default_devices,
151153
HALFCHEETAH_VERSIONED,
@@ -1028,11 +1030,15 @@ def test_one_hot_and_categorical(self): # noqa
10281030

10291031
def _test_one_hot_and_categorical(self):
    # tests that one-hot and categorical work ok when an integer is expected as action
    # Same rollout + spec check for both action encodings, categorical first.
    for categorical in (True, False):
        cliff_walking = GymEnv(
            CLIFFWALKING_VERSIONED(), categorical_action_encoding=categorical
        )
        cliff_walking.rollout(10)
        check_env_specs(cliff_walking)
10381044

0 commit comments

Comments
 (0)