100100
101101
102102@pytest .mark .skipif (not _has_gym , reason = "no gym library found" )
103- @pytest .mark .parametrize (
104- "env_name" ,
105- [
106- PONG_VERSIONED ,
107- # PENDULUM_VERSIONED,
108- HALFCHEETAH_VERSIONED ,
109- ],
110- )
111- @pytest .mark .parametrize ("frame_skip" , [1 , 3 ])
112- @pytest .mark .parametrize (
113- "from_pixels,pixels_only" ,
114- [
115- [False , False ],
116- [True , True ],
117- [True , False ],
118- ],
119- )
120103class TestGym :
104+ @pytest .mark .parametrize (
105+ "env_name" ,
106+ [
107+ PONG_VERSIONED ,
108+ # PENDULUM_VERSIONED,
109+ HALFCHEETAH_VERSIONED ,
110+ ],
111+ )
112+ @pytest .mark .parametrize ("frame_skip" , [1 , 3 ])
113+ @pytest .mark .parametrize (
114+ "from_pixels,pixels_only" ,
115+ [
116+ [False , False ],
117+ [True , True ],
118+ [True , False ],
119+ ],
120+ )
121121 def test_gym (self , env_name , frame_skip , from_pixels , pixels_only ):
122122 if env_name == PONG_VERSIONED and not from_pixels :
123123 # raise pytest.skip("already pixel")
@@ -176,6 +176,23 @@ def test_gym(self, env_name, frame_skip, from_pixels, pixels_only):
176176 assert final_seed0 == final_seed2
177177 assert_allclose_td (tdrollout [0 ], rollout2 , rtol = RTOL , atol = ATOL )
178178
179+ @pytest .mark .parametrize (
180+ "env_name" ,
181+ [
182+ PONG_VERSIONED ,
183+ # PENDULUM_VERSIONED,
184+ HALFCHEETAH_VERSIONED ,
185+ ],
186+ )
187+ @pytest .mark .parametrize ("frame_skip" , [1 , 3 ])
188+ @pytest .mark .parametrize (
189+ "from_pixels,pixels_only" ,
190+ [
191+ [False , False ],
192+ [True , True ],
193+ [True , False ],
194+ ],
195+ )
179196 def test_gym_fake_td (self , env_name , frame_skip , from_pixels , pixels_only ):
180197 if env_name == PONG_VERSIONED and not from_pixels :
181198 # raise pytest.skip("already pixel")
@@ -195,6 +212,37 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only):
195212 )
196213 check_env_specs (env )
197214
215+ def test_info_reader (self ):
216+ try :
217+ import gym_super_mario_bros as mario_gym
218+ except ImportError as err :
219+ try :
220+ import gym
221+
222+ # with 0.26 we must have installed gym_super_mario_bros
223+ # Since we capture the skips as errors, we raise a skip in this case
224+ # Otherwise, we just return
225+ if (
226+ version .parse ("0.26.0" )
227+ <= version .parse (gym .__version__ )
228+ < version .parse ("0.27.0" )
229+ ):
230+ raise pytest .skip (f"no super mario bros: error=\n { err } " )
231+ except ImportError :
232+ pass
233+ return
234+
235+ env = mario_gym .make ("SuperMarioBros-v0" , apply_api_compatibility = True )
236+ env = GymWrapper (env )
237+
238+ def info_reader (info , tensordict ):
239+ assert isinstance (info , dict ) # failed before bugfix
240+
241+ env .info_dict_reader = info_reader
242+ env .reset ()
243+ env .rand_step ()
244+ env .rollout (3 )
245+
198246
199247@implement_for ("gym" , None , "0.26" )
200248def _make_gym_environment (env_name ): # noqa: F811
0 commit comments