
Commit 5ee9009

qgallouedec and araffin authored
Add sticky actions for Atari games (#1286)
* repeat_action_probability
* Add test
* Undo atari wrapper doc change since CI fails
* remove action_repeat_probability from make_atari_env
* Add sticky action wrapper and improve documentation
* Update changelog
* handle the case noop_max=0
* Update tests
* Comply to ALE implementation
* Reorder doc
* Add doc warning and don't wrap with sticky action when not needed
* fix docstring and reorder
* Move `action_repeat_probability` args at the last position
* Add ref
* Update doc and wrap with frameskip only if needed
* Update changelog

Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 637988c commit 5ee9009

File tree: 4 files changed (+110, -41 lines)


docs/misc/changelog.rst
Lines changed: 3 additions & 1 deletion

@@ -4,7 +4,7 @@ Changelog
 ==========
 
 
-Release 1.8.0a2 (WIP)
+Release 1.8.0a3 (WIP)
 --------------------------
 
 
@@ -14,6 +14,8 @@ Breaking Changes:
 
 New Features:
 ^^^^^^^^^^^^^
+- Added ``repeat_action_probability`` argument in ``AtariWrapper``.
+- Only use ``NoopResetEnv`` and ``MaxAndSkipEnv`` when needed in ``AtariWrapper``
 
 `SB3-Contrib`_
 ^^^^^^^^^^^^^^
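For orientation, here is a minimal sketch (not part of the commit) of how the new option is exposed through ``make_atari_env``, whose ``wrapper_kwargs`` are forwarded to ``AtariWrapper``; the 0.25 probability simply mirrors the value exercised in the updated test below.

from stable_baselines3.common.env_util import make_atari_env

# Sticky actions stay disabled by default (action_repeat_probability=0.0);
# a positive value makes AtariWrapper insert the new StickyActionEnv.
venv = make_atari_env(
    "BreakoutNoFrameskip-v4",
    n_envs=2,
    seed=0,
    wrapper_kwargs={"action_repeat_probability": 0.25},
)
obs = venv.reset()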

stable_baselines3/common/atari_wrappers.py
Lines changed: 62 additions & 19 deletions

@@ -12,13 +12,39 @@
 from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
 
 
+class StickyActionEnv(gym.Wrapper):
+    """
+    Sticky action.
+
+    Paper: https://arxiv.org/abs/1709.06009
+    Official implementation: https://github.com/mgbellemare/Arcade-Learning-Environment
+
+    :param env: Environment to wrap
+    :param action_repeat_probability: Probability of repeating the last action
+    """
+
+    def __init__(self, env: gym.Env, action_repeat_probability: float) -> None:
+        super().__init__(env)
+        self.action_repeat_probability = action_repeat_probability
+        assert env.unwrapped.get_action_meanings()[0] == "NOOP"
+
+    def reset(self, **kwargs) -> GymObs:
+        self._sticky_action = 0  # NOOP
+        return self.env.reset(**kwargs)
+
+    def step(self, action: int) -> GymStepReturn:
+        if self.np_random.random() >= self.action_repeat_probability:
+            self._sticky_action = action
+        return self.env.step(self._sticky_action)
+
+
 class NoopResetEnv(gym.Wrapper):
     """
     Sample initial states by taking random number of no-ops on reset.
     No-op is assumed to be action 0.
 
-    :param env: the environment to wrap
-    :param noop_max: the maximum value of no-ops to run
+    :param env: Environment to wrap
+    :param noop_max: Maximum value of no-ops to run
     """
 
     def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
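As a hedged illustration of the wrapper added above (assuming ale-py and the Atari ROMs are installed, and the 4-tuple ``step`` API used by SB3 at this version), each call either forwards the sampled action or repeats the previous one:

import gym

from stable_baselines3.common.atari_wrappers import StickyActionEnv

env = StickyActionEnv(gym.make("BreakoutNoFrameskip-v4"), action_repeat_probability=0.25)
obs = env.reset()  # the "previous" action is reset to NOOP here
for _ in range(10):
    # With probability 0.25 the last executed action is repeated
    # instead of the freshly sampled one.
    obs, reward, done, info = env.step(env.action_space.sample())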
@@ -47,7 +73,7 @@ class FireResetEnv(gym.Wrapper):
     """
     Take action on reset for environments that are fixed until firing.
 
-    :param env: the environment to wrap
+    :param env: Environment to wrap
     """
 
     def __init__(self, env: gym.Env) -> None:
@@ -71,7 +97,7 @@ class EpisodicLifeEnv(gym.Wrapper):
     Make end-of-life == end-of-episode, but only reset on true game over.
     Done by DeepMind for the DQN and co. since it helps value estimation.
 
-    :param env: the environment to wrap
+    :param env: Environment to wrap
     """
 
     def __init__(self, env: gym.Env) -> None:
@@ -120,9 +146,11 @@ def reset(self, **kwargs) -> np.ndarray:
 class MaxAndSkipEnv(gym.Wrapper):
     """
     Return only every ``skip``-th frame (frameskipping)
+    and return the max between the two last frames.
 
-    :param env: the environment
-    :param skip: number of ``skip``-th frame
+    :param env: Environment to wrap
+    :param skip: Number of ``skip``-th frame
+        The same action will be taken ``skip`` times.
     """
 
     def __init__(self, env: gym.Env, skip: int = 4) -> None:
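To make "the max between the two last frames" concrete, a small numpy sketch of the pooling MaxAndSkipEnv applies to its internal two-frame buffer (shapes and data are illustrative only):

import numpy as np

# Two consecutive raw Atari frames (illustrative random data, 210x160 RGB).
obs_buffer = np.random.randint(0, 255, size=(2, 210, 160, 3), dtype=np.uint8)

# Element-wise maximum over the two most recent frames removes the sprite
# flickering that some Atari games show on alternating frames.
max_frame = obs_buffer.max(axis=0)
assert max_frame.shape == (210, 160, 3)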
@@ -159,9 +187,9 @@ def step(self, action: int) -> GymStepReturn:
 
 class ClipRewardEnv(gym.RewardWrapper):
     """
-    Clips the reward to {+1, 0, -1} by its sign.
+    Clip the reward to {+1, 0, -1} by its sign.
 
-    :param env: the environment
+    :param env: Environment to wrap
     """
 
     def __init__(self, env: gym.Env) -> None:
@@ -182,9 +210,9 @@ class WarpFrame(gym.ObservationWrapper):
     Convert to grayscale and warp frames to 84x84 (default)
     as done in the Nature paper and later work.
 
-    :param env: the environment
-    :param width:
-    :param height:
+    :param env: Environment to wrap
+    :param width: New frame width
+    :param height: New frame height
     """
 
     def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None:
@@ -213,20 +241,29 @@ class AtariWrapper(gym.Wrapper):
 
     Specifically:
 
-    * NoopReset: obtain initial state by taking random number of no-ops on reset.
+    * Noop reset: obtain initial state by taking random number of no-ops on reset.
     * Frame skipping: 4 by default
     * Max-pooling: most recent two observations
     * Termination signal when a life is lost.
     * Resize to a square image: 84x84 by default
     * Grayscale observation
     * Clip reward to {-1, 0, 1}
+    * Sticky actions: disabled by default
+
+    See https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/
+    for a visual explanation.
+
+    .. warning::
+        Use this wrapper only with Atari v4 without frame skip: ``env_id = "*NoFrameskip-v4"``.
 
-    :param env: gym environment
-    :param noop_max: max number of no-ops
-    :param frame_skip: the frequency at which the agent experiences the game.
-    :param screen_size: resize Atari frame
-    :param terminal_on_life_loss: if True, then step() returns done=True whenever a life is lost.
+    :param env: Environment to wrap
+    :param noop_max: Max number of no-ops
+    :param frame_skip: Frequency at which the agent experiences the game.
+        This correspond to repeating the action ``frame_skip`` times.
+    :param screen_size: Resize Atari frame
+    :param terminal_on_life_loss: If True, then step() returns done=True whenever a life is lost.
     :param clip_reward: If True (default), the reward is clip to {-1, 0, 1} depending on its sign.
+    :param action_repeat_probability: Probability of repeating the last action
     """
 
     def __init__(
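Taken together, the documented preprocessing can be exercised as in this hedged sketch (again assuming a NoFrameskip-v4 ALE env is available); with default settings the wrapped observation is a single 84x84 grayscale frame:

import gym

from stable_baselines3.common.atari_wrappers import AtariWrapper

# Use a *NoFrameskip-v4 env id, as the warning above recommends, so that
# frame skipping is done by the wrapper rather than by ALE itself.
env = AtariWrapper(gym.make("BreakoutNoFrameskip-v4"), action_repeat_probability=0.25)
obs = env.reset()
print(obs.shape)  # (84, 84, 1): grayscale, resized to screen_size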
@@ -237,9 +274,15 @@ def __init__(
         screen_size: int = 84,
         terminal_on_life_loss: bool = True,
         clip_reward: bool = True,
+        action_repeat_probability: float = 0.0,
     ) -> None:
-        env = NoopResetEnv(env, noop_max=noop_max)
-        env = MaxAndSkipEnv(env, skip=frame_skip)
+        if action_repeat_probability > 0.0:
+            env = StickyActionEnv(env, action_repeat_probability)
+        if noop_max > 0:
+            env = NoopResetEnv(env, noop_max=noop_max)
+        # frame_skip=1 is the same as no frame-skip (action repeat)
+        if frame_skip > 1:
+            env = MaxAndSkipEnv(env, skip=frame_skip)
         if terminal_on_life_loss:
             env = EpisodicLifeEnv(env)
         if "FIRE" in env.unwrapped.get_action_meanings():

stable_baselines3/version.txt
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-1.8.0a2
+1.8.0a3

tests/test_utils.py
Lines changed: 44 additions & 20 deletions

@@ -9,7 +9,7 @@
 
 import stable_baselines3 as sb3
 from stable_baselines3 import A2C
-from stable_baselines3.common.atari_wrappers import ClipRewardEnv, MaxAndSkipEnv
+from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
 from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper
 from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.monitor import Monitor
@@ -55,30 +55,54 @@ def test_make_vec_env_func_checker():
     env.close()
 
 
-@pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4"])
-@pytest.mark.parametrize("n_envs", [1, 2])
-@pytest.mark.parametrize("wrapper_kwargs", [None, dict(clip_reward=False, screen_size=60)])
-def test_make_atari_env(env_id, n_envs, wrapper_kwargs):
-    env = make_atari_env(env_id, n_envs, wrapper_kwargs=wrapper_kwargs, monitor_dir=None, seed=0)
+# Use Asterix as it does not requires fire reset
+@pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4", "AsterixNoFrameskip-v4"])
+@pytest.mark.parametrize("noop_max", [0, 10])
+@pytest.mark.parametrize("action_repeat_probability", [0.0, 0.25])
+@pytest.mark.parametrize("frame_skip", [1, 4])
+@pytest.mark.parametrize("screen_size", [60])
+@pytest.mark.parametrize("terminal_on_life_loss", [True, False])
+@pytest.mark.parametrize("clip_reward", [True])
+def test_make_atari_env(
+    env_id, noop_max, action_repeat_probability, frame_skip, screen_size, terminal_on_life_loss, clip_reward
+):
+    n_envs = 2
+    wrapper_kwargs = {
+        "noop_max": noop_max,
+        "action_repeat_probability": action_repeat_probability,
+        "frame_skip": frame_skip,
+        "screen_size": screen_size,
+        "terminal_on_life_loss": terminal_on_life_loss,
+        "clip_reward": clip_reward,
+    }
+    venv = make_atari_env(
+        env_id,
+        n_envs=2,
+        wrapper_kwargs=wrapper_kwargs,
+        monitor_dir=None,
+        seed=0,
+    )
 
-    assert env.num_envs == n_envs
+    assert venv.num_envs == n_envs
 
-    obs = env.reset()
+    needs_fire_reset = env_id == "BreakoutNoFrameskip-v4"
+    expected_frame_number_low = frame_skip * 2 if needs_fire_reset else 0  # FIRE - UP on reset
+    expected_frame_number_high = expected_frame_number_low + noop_max
+    expected_shape = (n_envs, screen_size, screen_size, 1)
 
-    new_obs, reward, _, _ = env.step([env.action_space.sample() for _ in range(n_envs)])
+    obs = venv.reset()
+    frame_numbers = [env.unwrapped.ale.getEpisodeFrameNumber() for env in venv.envs]
+    for frame_number in frame_numbers:
+        assert expected_frame_number_low <= frame_number <= expected_frame_number_high
+    assert obs.shape == expected_shape
 
-    assert obs.shape == new_obs.shape
+    new_obs, reward, _, _ = venv.step([venv.action_space.sample() for _ in range(n_envs)])
 
-    # Wrapped into DummyVecEnv
-    wrapped_atari_env = env.envs[0]
-    if wrapper_kwargs is not None:
-        assert obs.shape == (n_envs, 60, 60, 1)
-        assert wrapped_atari_env.observation_space.shape == (60, 60, 1)
-        assert not isinstance(wrapped_atari_env.env, ClipRewardEnv)
-    else:
-        assert obs.shape == (n_envs, 84, 84, 1)
-        assert wrapped_atari_env.observation_space.shape == (84, 84, 1)
-        assert isinstance(wrapped_atari_env.env, ClipRewardEnv)
+    new_frame_numbers = [env.unwrapped.ale.getEpisodeFrameNumber() for env in venv.envs]
+    for frame_number, new_frame_number in zip(frame_numbers, new_frame_numbers):
+        assert new_frame_number - frame_number == frame_skip
+    assert new_obs.shape == expected_shape
+    if clip_reward:
         assert np.max(np.abs(reward)) < 1.0
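The frame-number bounds asserted after reset follow from the wrapper ordering: ``FireResetEnv`` sits above ``MaxAndSkipEnv``, so each of its two reset actions consumes ``frame_skip`` emulator frames, while the random no-ops are applied below the frame skip and add at most ``noop_max`` single frames. A hedged worked example with the parametrized values:

# Breakout needs a fire reset (two wrapped actions); Asterix does not, so its lower bound is 0.
frame_skip, noop_max = 4, 10
expected_frame_number_low = frame_skip * 2                          # 8 frames consumed on reset
expected_frame_number_high = expected_frame_number_low + noop_max   # plus at most 10 no-op frames
assert (expected_frame_number_low, expected_frame_number_high) == (8, 18)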

Comments (0)