Skip to content

Commit 95f90b1

Browse files
authored
update atari_wrappers (#686)
1 parent d1eef93 commit 95f90b1

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

parl/env/atari_wrappers.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,8 @@ def __init__(self, env, eval_episodes=3):
287287
self._eval_episodes = eval_episodes
288288
self._was_real_done = False
289289
self._eval_rewards = None
290-
self._end_episode = len(env.get_episode_rewards()) + eval_episodes
290+
self._end_episode = len(
291+
self._monitor.get_episode_rewards()) + eval_episodes
291292

292293
def step(self, action):
293294
ob, reward, done, info = self._env.step(action)
@@ -318,12 +319,20 @@ def _get_curr_episode(self):
318319
return len(self._monitor.get_episode_rewards())
319320

320321

321-
def wrap_deepmind(env, dim=84, framestack=True, obs_format='NHWC', test=False):
322+
def wrap_deepmind(env,
323+
dim=84,
324+
framestack=True,
325+
obs_format='NHWC',
326+
test=False,
327+
eval_episodes=3):
322328
"""Configure environment for DeepMind-style Atari.
323329
324330
Args:
325331
dim (int): Dimension to resize observations to (dim x dim).
326332
framestack (bool): Whether to framestack observations.
333+
obs_format (str): observation output format
334+
test (bool): whether this is a test env
335+
eval_episodes (int): when test, number of episodes for each evaluation
327336
"""
328337
env = MonitorEnv(env)
329338
env = NoopResetEnv(env, noop_max=30)
@@ -337,5 +346,5 @@ def wrap_deepmind(env, dim=84, framestack=True, obs_format='NHWC', test=False):
337346
if framestack:
338347
env = FrameStack(env, 4, obs_format)
339348
if test:
340-
env = TestEnv(env)
349+
env = TestEnv(env, eval_episodes)
341350
return env

0 commit comments

Comments
 (0)