Skip to content

Commit 2e4a450

Browse files
qgallouedecaraffin
andauthored
Refactor observation stacking (#1238)
* refactor stacking obs * Improve docstring * remove all StackedDictObservations * Update tests and make stacked obs clearer * Fix type check * fix stacked_observation_space * undo init change, deprecate StackedDictObservations * deprecate stack_observation_space * type hints * ignore pytype errors * undo vecenv doc change * Deprecation warning in StackedDictObs doctstring * Fix vec_env.rst * Fix __all__ sorting * fix pytype ignore statement * Update docstring * stack * Remove n_stack * Update changelog * Simplify code * Rename test file * Re-use variable for shift * Fix doc build * Remove pytype comment * Disable pytype error --------- Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 411ff69 commit 2e4a450

File tree

8 files changed

+459
-234
lines changed

8 files changed

+459
-234
lines changed

docs/guide/vec_envs.rst

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,6 @@ StackedObservations
122122
.. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedObservations
123123
:members:
124124

125-
StackedDictObservations
126-
~~~~~~~~~~~~~~~~~~~~~~~
127-
128-
.. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedDictObservations
129-
:members:
130-
131125
VecNormalize
132126
~~~~~~~~~~~~
133127

docs/misc/changelog.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ Changelog
44
==========
55

66

7-
Release 1.8.0a3 (WIP)
7+
Release 1.8.0a4 (WIP)
88
--------------------------
99

1010

1111
Breaking Changes:
1212
^^^^^^^^^^^^^^^^^
1313
- Removed shared layers in ``mlp_extractor`` (@AlexPasqua)
14+
- Refactored ``StackedObservations`` (it now handles dict obs, ``StackedDictObservations`` was removed)
1415

1516
New Features:
1617
^^^^^^^^^^^^^
@@ -36,6 +37,7 @@ Others:
3637
- Fixed ``tests/test_tensorboard.py`` type hint
3738
- Fixed ``tests/test_vec_normalize.py`` type hint
3839
- Fixed ``stable_baselines3/common/monitor.py`` type hint
40+
- Added tests for StackedObservations
3941

4042
Documentation:
4143
^^^^^^^^^^^^^^

setup.cfg

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ exclude = (?x)(
4848
| stable_baselines3/common/vec_env/__init__.py$
4949
| stable_baselines3/common/vec_env/base_vec_env.py$
5050
| stable_baselines3/common/vec_env/dummy_vec_env.py$
51-
| stable_baselines3/common/vec_env/stacked_observations.py$
5251
| stable_baselines3/common/vec_env/subproc_vec_env.py$
5352
| stable_baselines3/common/vec_env/util.py$
5453
| stable_baselines3/common/vec_env/vec_extract_dict_obs.py$

stable_baselines3/common/vec_env/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
66
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
7-
from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
7+
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
88
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
99
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
1010
from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs
@@ -78,7 +78,6 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
7878
"VecEnv",
7979
"VecEnvWrapper",
8080
"DummyVecEnv",
81-
"StackedDictObservations",
8281
"StackedObservations",
8382
"SubprocVecEnv",
8483
"VecCheckNan",

stable_baselines3/common/vec_env/stacked_observations.py

Lines changed: 128 additions & 188 deletions
Large diffs are not rendered by default.
Lines changed: 12 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,40 @@
1-
from typing import Any, Dict, List, Optional, Tuple, Union
1+
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
22

33
import numpy as np
44
from gym import spaces
55

66
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
7-
from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
7+
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
88

99

1010
class VecFrameStack(VecEnvWrapper):
1111
"""
1212
Frame stacking wrapper for vectorized environment. Designed for image observations.
1313
14-
Uses the StackedObservations class, or StackedDictObservations depending on the observations space
15-
16-
:param venv: the vectorized environment to wrap
14+
:param venv: Vectorized environment to wrap
1715
:param n_stack: Number of frames to stack
1816
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
1917
If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
2018
Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces
2119
"""
2220

23-
def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None):
24-
self.venv = venv
25-
self.n_stack = n_stack
26-
27-
wrapped_obs_space = venv.observation_space
28-
29-
if isinstance(wrapped_obs_space, spaces.Box):
30-
assert not isinstance(
31-
channels_order, dict
32-
), f"Expected None or string for channels_order but received {channels_order}"
33-
self.stackedobs = StackedObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order)
34-
35-
elif isinstance(wrapped_obs_space, spaces.Dict):
36-
self.stackedobs = StackedDictObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order)
37-
38-
else:
39-
raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces")
21+
def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Mapping[str, str]]] = None) -> None:
22+
assert isinstance(
23+
venv.observation_space, (spaces.Box, spaces.Dict)
24+
), "VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces"
4025

41-
observation_space = self.stackedobs.stack_observation_space(wrapped_obs_space)
42-
VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
26+
self.stacked_obs = StackedObservations(venv.num_envs, n_stack, venv.observation_space, channels_order)
27+
observation_space = self.stacked_obs.stacked_observation_space
28+
super().__init__(venv, observation_space=observation_space)
4329

4430
def step_wait(
4531
self,
4632
) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]:
4733
observations, rewards, dones, infos = self.venv.step_wait()
48-
49-
observations, infos = self.stackedobs.update(observations, dones, infos)
50-
34+
observations, infos = self.stacked_obs.update(observations, dones, infos)
5135
return observations, rewards, dones, infos
5236

5337
def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
54-
"""
55-
Reset all environments
56-
"""
5738
observation = self.venv.reset() # pytype:disable=annotation-type-mismatch
58-
59-
observation = self.stackedobs.reset(observation)
39+
observation = self.stacked_obs.reset(observation)
6040
return observation
61-
62-
def close(self) -> None:
63-
self.venv.close()

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.8.0a3
1+
1.8.0a4

0 commit comments

Comments
 (0)