|
1 |
| -from typing import Any, Dict, List, Optional, Tuple, Union |
| 1 | +from typing import Any, Dict, List, Mapping, Optional, Tuple, Union |
2 | 2 |
|
3 | 3 | import numpy as np
|
4 | 4 | from gym import spaces
|
5 | 5 |
|
6 | 6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
|
7 |
| -from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations |
| 7 | +from stable_baselines3.common.vec_env.stacked_observations import StackedObservations |
8 | 8 |
|
9 | 9 |
|
class VecFrameStack(VecEnvWrapper):
    """
    Frame stacking wrapper for vectorized environments. Designed for image observations.

    The stacking itself is delegated to a ``StackedObservations`` instance
    (stored as ``self.stacked_obs``), which also provides the observation space
    exposed by this wrapper.

    :param venv: Vectorized environment to wrap
    :param n_stack: Number of frames to stack
    :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
        If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
        Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces
    :raises AssertionError: If ``venv.observation_space`` is neither a ``spaces.Box`` nor a ``spaces.Dict``
    """

    def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Mapping[str, str]]] = None) -> None:
        # Explicit raise instead of an ``assert`` statement so the check is not
        # stripped when Python runs with ``-O``. AssertionError is kept so the
        # exception type seen by callers does not change.
        if not isinstance(venv.observation_space, (spaces.Box, spaces.Dict)):
            raise AssertionError("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces")

        # The helper must exist before the parent constructor runs, because the
        # parent is given the stacked observation space the helper exposes.
        self.stacked_obs = StackedObservations(venv.num_envs, n_stack, venv.observation_space, channels_order)
        observation_space = self.stacked_obs.stacked_observation_space
        super().__init__(venv, observation_space=observation_space)

    def step_wait(
        self,
    ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]]]:
        """
        Collect the step result from the wrapped environment and pass it through the frame stack.

        :return: Tuple ``(observations, rewards, dones, infos)`` where ``observations`` and ``infos``
            are the values returned by ``self.stacked_obs.update`` and ``rewards`` / ``dones``
            are forwarded unchanged from the wrapped environment
        """
        observations, rewards, dones, infos = self.venv.step_wait()
        # ``dones`` and ``infos`` are handed to the helper together with the new
        # observations; the helper returns the (possibly modified) infos.
        observations, infos = self.stacked_obs.update(observations, dones, infos)
        return observations, rewards, dones, infos

    def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Reset the wrapped environment and the frame stack.

        :return: The value returned by ``self.stacked_obs.reset`` for the first observation
        """
        observation = self.venv.reset()  # pytype:disable=annotation-type-mismatch
        observation = self.stacked_obs.reset(observation)
        return observation
0 commit comments