2424from torchrl .envs import MultiThreadedEnv , ObservationNorm
2525from torchrl .envs .batched_envs import ParallelEnv , SerialEnv
2626from torchrl .envs .libs .envpool import _has_envpool
27- from torchrl .envs .libs .gym import _has_gym , GymEnv
27+ from torchrl .envs .libs .gym import _has_gym , gym_backend , GymEnv
2828from torchrl .envs .transforms import (
2929 Compose ,
3030 RewardClipping ,
3535# Specified for test_utils.py
3636__version__ = "0.3"
3737
38- # Default versions of the environments.
39- CARTPOLE_VERSIONED = "CartPole-v1"
40- HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
41- PENDULUM_VERSIONED = "Pendulum-v1"
42- PONG_VERSIONED = "ALE/Pong-v5"
38+
39+ def CARTPOLE_VERSIONED ():
40+ # load gym
41+ if gym_backend () is not None :
42+ _set_gym_environments ()
43+ return _CARTPOLE_VERSIONED
44+
45+
46+ def HALFCHEETAH_VERSIONED ():
47+ # load gym
48+ if gym_backend () is not None :
49+ _set_gym_environments ()
50+ return _HALFCHEETAH_VERSIONED
51+
52+
53+ def PONG_VERSIONED ():
54+ # load gym
55+ if gym_backend () is not None :
56+ _set_gym_environments ()
57+ return _PONG_VERSIONED
58+
59+
60+ def PENDULUM_VERSIONED ():
61+ # load gym
62+ if gym_backend () is not None :
63+ _set_gym_environments ()
64+ return _PENDULUM_VERSIONED
65+
66+
67+ def _set_gym_environments ():
68+ global _CARTPOLE_VERSIONED , _HALFCHEETAH_VERSIONED , _PENDULUM_VERSIONED , _PONG_VERSIONED
69+
70+ _CARTPOLE_VERSIONED = None
71+ _HALFCHEETAH_VERSIONED = None
72+ _PENDULUM_VERSIONED = None
73+ _PONG_VERSIONED = None
4374
4475
4576@implement_for ("gym" , None , "0.21.0" )
4677def _set_gym_environments (): # noqa: F811
47- global CARTPOLE_VERSIONED , HALFCHEETAH_VERSIONED , PENDULUM_VERSIONED , PONG_VERSIONED
78+ global _CARTPOLE_VERSIONED , _HALFCHEETAH_VERSIONED , _PENDULUM_VERSIONED , _PONG_VERSIONED
4879
49- CARTPOLE_VERSIONED = "CartPole-v0"
50- HALFCHEETAH_VERSIONED = "HalfCheetah-v2"
51- PENDULUM_VERSIONED = "Pendulum-v0"
52- PONG_VERSIONED = "Pong-v4"
80+ _CARTPOLE_VERSIONED = "CartPole-v0"
81+ _HALFCHEETAH_VERSIONED = "HalfCheetah-v2"
82+ _PENDULUM_VERSIONED = "Pendulum-v0"
83+ _PONG_VERSIONED = "Pong-v4"
5384
5485
5586@implement_for ("gym" , "0.21.0" , None )
5687def _set_gym_environments (): # noqa: F811
57- global CARTPOLE_VERSIONED , HALFCHEETAH_VERSIONED , PENDULUM_VERSIONED , PONG_VERSIONED
88+ global _CARTPOLE_VERSIONED , _HALFCHEETAH_VERSIONED , _PENDULUM_VERSIONED , _PONG_VERSIONED
5889
59- CARTPOLE_VERSIONED = "CartPole-v1"
60- HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
61- PENDULUM_VERSIONED = "Pendulum-v1"
62- PONG_VERSIONED = "ALE/Pong-v5"
90+ _CARTPOLE_VERSIONED = "CartPole-v1"
91+ _HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
92+ _PENDULUM_VERSIONED = "Pendulum-v1"
93+ _PONG_VERSIONED = "ALE/Pong-v5"
6394
6495
6596@implement_for ("gymnasium" )
6697def _set_gym_environments (): # noqa: F811
67- global CARTPOLE_VERSIONED , HALFCHEETAH_VERSIONED , PENDULUM_VERSIONED , PONG_VERSIONED
98+ global _CARTPOLE_VERSIONED , _HALFCHEETAH_VERSIONED , _PENDULUM_VERSIONED , _PONG_VERSIONED
6899
69- CARTPOLE_VERSIONED = "CartPole-v1"
70- HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
71- PENDULUM_VERSIONED = "Pendulum-v1"
72- PONG_VERSIONED = "ALE/Pong-v5"
100+ _CARTPOLE_VERSIONED = "CartPole-v1"
101+ _HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
102+ _PENDULUM_VERSIONED = "Pendulum-v1"
103+ _PONG_VERSIONED = "ALE/Pong-v5"
73104
74105
75106if _has_gym :
@@ -171,7 +202,7 @@ def create_env_fn():
171202 return GymEnv (env_name , frame_skip = frame_skip , device = device )
172203
173204 else :
174- if env_name == PONG_VERSIONED :
205+ if env_name == PONG_VERSIONED () :
175206
176207 def create_env_fn ():
177208 base_env = GymEnv (env_name , frame_skip = frame_skip , device = device )
@@ -250,7 +281,7 @@ def _make_multithreaded_env(
250281
251282 torch .manual_seed (0 )
252283 multithreaded_kwargs = (
253- {"frame_skip" : frame_skip } if env_name == PONG_VERSIONED else {}
284+ {"frame_skip" : frame_skip } if env_name == PONG_VERSIONED () else {}
254285 )
255286 env_multithread = MultiThreadedEnv (
256287 N ,
@@ -274,7 +305,7 @@ def _make_multithreaded_env(
274305
275306def get_transform_out (env_name , transformed_in , obs_key = None ):
276307
277- if env_name == PONG_VERSIONED :
308+ if env_name == PONG_VERSIONED () :
278309 if obs_key is None :
279310 obs_key = "pixels"
280311
0 commit comments