@@ -287,7 +287,8 @@ def __init__(self, env, eval_episodes=3):
287287 self ._eval_episodes = eval_episodes
288288 self ._was_real_done = False
289289 self ._eval_rewards = None
290- self ._end_episode = len (env .get_episode_rewards ()) + eval_episodes
290+ self ._end_episode = len (
291+ self ._monitor .get_episode_rewards ()) + eval_episodes
291292
292293 def step (self , action ):
293294 ob , reward , done , info = self ._env .step (action )
@@ -318,12 +319,20 @@ def _get_curr_episode(self):
318319 return len (self ._monitor .get_episode_rewards ())
319320
320321
321- def wrap_deepmind (env , dim = 84 , framestack = True , obs_format = 'NHWC' , test = False ):
322+ def wrap_deepmind (env ,
323+ dim = 84 ,
324+ framestack = True ,
325+ obs_format = 'NHWC' ,
326+ test = False ,
327+ eval_episodes = 3 ):
322328 """Configure environment for DeepMind-style Atari.
323329
324330 Args:
325331 dim (int): Dimension to resize observations to (dim x dim).
326332 framestack (bool): Whether to framestack observations.
333+ obs_format (str): observation output format
334+ test (bool): whether this is a test env
335+ eval_episodes (int): when test, number of episodes for each evaluation
327336 """
328337 env = MonitorEnv (env )
329338 env = NoopResetEnv (env , noop_max = 30 )
@@ -337,5 +346,5 @@ def wrap_deepmind(env, dim=84, framestack=True, obs_format='NHWC', test=False):
337346 if framestack :
338347 env = FrameStack (env , 4 , obs_format )
339348 if test :
340- env = TestEnv (env )
349+ env = TestEnv (env , eval_episodes )
341350 return env
0 commit comments