Commit 4fa17dc

Standardize the use of `from gym import spaces` (#1240)

* generalize the use of `from gym import spaces`
* command line get system info
* Documentation line length for doc
* update changelog
* add space before os platform to avoid ref to other issue
* format
* `get_system_info` update in changelog
* fix type check error
* fix get system info
* add comment about regex
* update version

1 parent 2bb8ef5 · commit 4fa17dc
34 files changed: +219 −196 lines
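
The change is mechanical but repo-wide; a minimal sketch of the convention before and after (illustrative, not taken verbatim from any one file in the diff):

```python
import gym
from gym import spaces

# Before: spaces accessed through the top-level package
assert isinstance(gym.spaces.Discrete(2), gym.spaces.Space)

# After: one canonical import, shorter call sites
assert isinstance(spaces.Discrete(2), spaces.Space)
```

Both spellings resolve to the same classes; the commit only standardizes which one the codebase uses.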

.github/ISSUE_TEMPLATE/bug_report.yml

Lines changed: 2 additions & 3 deletions

````diff
@@ -52,9 +52,8 @@ body:
       * Versions of any other relevant libraries
 
       You can use `sb3.get_system_info()` to print relevant packages info:
-      ```python
-      import stable_baselines3 as sb3
-      sb3.get_system_info()
+      ```sh
+      python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
       ```
   - type: checkboxes
     id: terms
````
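
The new one-liner prints the report directly from a shell, which avoids the copy-paste issues on GitHub mentioned in the changelog. For programmatic use, a sketch (the `print_info` flag and return shape are assumptions about the SB3 helper, not shown in this diff):

```python
import stable_baselines3 as sb3

# Assumed signature: returns the info dict plus its pretty-printed string
env_info, env_info_str = sb3.get_system_info(print_info=False)
print(env_info_str)
```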

.github/ISSUE_TEMPLATE/custom_env.yml

Lines changed: 15 additions & 15 deletions

````diff
@@ -36,27 +36,28 @@ body:
       ```python
       import gym
       import numpy as np
+      from gym import spaces
 
       from stable_baselines3 import A2C
       from stable_baselines3.common.env_checker import check_env
 
 
       class CustomEnv(gym.Env):
 
-        def __init__(self):
-          super().__init__()
-          self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
-          self.action_space = gym.spaces.Box(low=-1, high=1, shape=(6,))
+          def __init__(self):
+              super().__init__()
+              self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
+              self.action_space = spaces.Box(low=-1, high=1, shape=(6,))
 
-        def reset(self):
-          return self.observation_space.sample()
+          def reset(self):
+              return self.observation_space.sample()
 
-        def step(self, action):
-          obs = self.observation_space.sample()
-          reward = 1.0
-          done = False
-          info = {}
-          return obs, reward, done, info
+          def step(self, action):
+              obs = self.observation_space.sample()
+              reward = 1.0
+              done = False
+              info = {}
+              return obs, reward, done, info
 
       env = CustomEnv()
       check_env(env)
@@ -86,9 +87,8 @@ body:
       * Versions of any other relevant libraries
 
       You can use `sb3.get_system_info()` to print relevant packages info:
-      ```python
-      import stable_baselines3 as sb3
-      sb3.get_system_info()
+      ```sh
+      python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
       ```
   - type: checkboxes
     id: terms
````

CONTRIBUTING.md

Lines changed: 2 additions & 1 deletion

```diff
@@ -38,7 +38,8 @@ pip install -e .[docs,tests,extra]
 
 ## Codestyle
 
-We are using [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports.
+We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports.
+For the documentation, we use the default line length of 88 characters per line.
 
 **Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`.
 
```
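
As shell commands, the workflow described above (targets as named in the text; their exact recipes live in the repository Makefile):

```sh
make format           # reformat with black + isort
make check-codestyle  # verify formatting only
make lint             # run the linter
```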

docs/guide/custom_env.rst

Lines changed: 9 additions & 3 deletions

```diff
@@ -27,14 +27,17 @@ That is to say, your environment must implement the following methods (and inher
 .. code-block:: python
 
   import gym
+  import numpy as np
   from gym import spaces
 
+
   class CustomEnv(gym.Env):
-      """Custom Environment that follows gym interface"""
+      """Custom Environment that follows gym interface."""
+
       metadata = {"render.modes": ["human"]}
 
       def __init__(self, arg1, arg2, ...):
-          super(CustomEnv, self).__init__()
+          super().__init__()
           # Define action and observation space
           # They must be gym.spaces objects
           # Example when using discrete actions:
@@ -46,12 +49,15 @@ That is to say, your environment must implement the following methods (and inher
       def step(self, action):
           ...
           return observation, reward, done, info
+
       def reset(self):
           ...
           return observation  # reward, done, info can't be included
+
       def render(self, mode="human"):
           ...
-      def close (self):
+
+      def close(self):
           ...
 
```
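
Filling in the skeleton's `...` placeholders gives a runnable toy version (the random observations and constant reward are placeholders, mirroring the issue-template example above):

```python
import gym
import numpy as np
from gym import spaces

from stable_baselines3.common.env_checker import check_env


class CustomEnv(gym.Env):
    """Custom Environment that follows gym interface."""

    metadata = {"render.modes": ["human"]}

    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
        self.action_space = spaces.Box(low=-1, high=1, shape=(6,))

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        obs = self.observation_space.sample()
        return obs, 1.0, False, {}

    def render(self, mode="human"):
        pass

    def close(self):
        pass


check_env(CustomEnv())  # warns if the env deviates from the gym interface
```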

docs/guide/custom_policy.rst

Lines changed: 6 additions & 6 deletions

```diff
@@ -125,9 +125,9 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
 
 .. code-block:: python
 
-  import gym
   import torch as th
   import torch.nn as nn
+  from gym import spaces
 
   from stable_baselines3 import PPO
   from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
@@ -140,7 +140,7 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
           This corresponds to the number of unit for the last layer.
       """
 
-      def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
+      def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
           super().__init__(observation_space, features_dim)
           # We assume CxHxW images (channels first)
           # Re-ordering will be done by pre-preprocessing or wrapper
@@ -199,7 +199,7 @@ downsampling and "vector" with a single linear layer.
   from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
 
   class CustomCombinedExtractor(BaseFeaturesExtractor):
-      def __init__(self, observation_space: gym.spaces.Dict):
+      def __init__(self, observation_space: spaces.Dict):
           # We do not know features-dim here before going over all the items,
           # so put something dummy for now. PyTorch requires calling
           # nn.Module.__init__ before adding modules
@@ -310,7 +310,7 @@ If your task requires even more granular control over the policy/value architect
 
   from typing import Callable, Dict, List, Optional, Tuple, Type, Union
 
-  import gym
+  from gym import spaces
   import torch as th
   from torch import nn
 
@@ -367,8 +367,8 @@ If your task requires even more granular control over the policy/value architect
   class CustomActorCriticPolicy(ActorCriticPolicy):
       def __init__(
           self,
-          observation_space: gym.spaces.Space,
-          action_space: gym.spaces.Space,
+          observation_space: spaces.Space,
+          action_space: spaces.Space,
           lr_schedule: Callable[[float], float],
           net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
           activation_fn: Type[nn.Module] = nn.Tanh,
```
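
A minimal sketch of wiring a custom extractor into a model with the standardized import (the toy linear extractor and its parameters are illustrative; `features_extractor_class` is the policy kwarg SB3 uses for this):

```python
import torch as th
import torch.nn as nn
from gym import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class LinearExtractor(BaseFeaturesExtractor):
    """Toy extractor: a single linear layer over a flattened Box observation."""

    def __init__(self, observation_space: spaces.Box, features_dim: int = 16):
        super().__init__(observation_space, features_dim)
        n_input = int(th.prod(th.as_tensor(observation_space.shape)))
        self.linear = nn.Linear(n_input, features_dim)

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(th.flatten(observations, start_dim=1))


model = PPO(
    "MlpPolicy",
    "CartPole-v1",
    policy_kwargs=dict(features_extractor_class=LinearExtractor),
)
```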

docs/misc/changelog.rst

Lines changed: 3 additions & 1 deletion

```diff
@@ -4,7 +4,7 @@ Changelog
 ==========
 
 
-Release 1.7.0a10 (WIP)
+Release 1.7.0a11 (WIP)
 --------------------------
 
 .. note::
@@ -71,6 +71,8 @@ Others:
 - Upgraded GitHub CI/setup-python to v4 and checkout to v3
 - Set tensors construction directly on the device (~8% speed boost on GPU)
 - Monkey-patched ``np.bool = bool`` so gym 0.21 is compatible with NumPy 1.24+
+- Standardized the use of ``from gym import spaces``
+- Modified ``get_system_info`` to avoid issue linked to copy-pasting on GitHub issue
 
 Documentation:
 ^^^^^^^^^^^^^^
```

stable_baselines3/common/base_class.py

Lines changed: 8 additions & 7 deletions

```diff
@@ -11,6 +11,7 @@
 import gym
 import numpy as np
 import torch as th
+from gym import spaces
 
 from stable_baselines3.common import utils
 from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
@@ -101,7 +102,7 @@ def __init__(
         seed: Optional[int] = None,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
-        supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
+        supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
     ):
         if isinstance(policy, str):
             self.policy_class = self._get_policy_from_name(policy)
@@ -117,8 +118,8 @@ def __init__(
         self._vec_normalize_env = unwrap_vec_normalize(env)
         self.verbose = verbose
         self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
-        self.observation_space = None  # type: Optional[gym.spaces.Space]
-        self.action_space = None  # type: Optional[gym.spaces.Space]
+        self.observation_space = None  # type: Optional[spaces.Space]
+        self.action_space = None  # type: Optional[spaces.Space]
         self.n_envs = None
         self.num_timesteps = 0
         # Used for updating schedules
@@ -175,13 +176,13 @@ def __init__(
             )
 
             # Catch common mistake: using MlpPolicy/CnnPolicy instead of MultiInputPolicy
-            if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, gym.spaces.Dict):
+            if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, spaces.Dict):
                 raise ValueError(f"You must use `MultiInputPolicy` when working with dict observation space, not {policy}")
 
-            if self.use_sde and not isinstance(self.action_space, gym.spaces.Box):
+            if self.use_sde and not isinstance(self.action_space, spaces.Box):
                 raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")
 
-            if isinstance(self.action_space, gym.spaces.Box):
+            if isinstance(self.action_space, spaces.Box):
                 assert np.all(
                     np.isfinite(np.array([self.action_space.low, self.action_space.high]))
                 ), "Continuous action space must have a finite lower and upper bound"
@@ -212,7 +213,7 @@ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> Ve
 
         if not is_vecenv_wrapped(env, VecTransposeImage):
             wrap_with_vectranspose = False
-            if isinstance(env.observation_space, gym.spaces.Dict):
+            if isinstance(env.observation_space, spaces.Dict):
                 # If even one of the keys is a image-space in need of transpose, apply transpose
                 # If the image spaces are not consistent (for instance one is channel first,
                 # the other channel last), VecTransposeImage will throw an error
```
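
The MlpPolicy/CnnPolicy guard above is easy to trip by accident; a sketch of what a user would see (the toy dict-observation env is hypothetical, the error message is the one from the diff):

```python
import gym
from gym import spaces

from stable_baselines3 import A2C


class DictObsEnv(gym.Env):
    """Toy env with a Dict observation space, only to trigger the check."""

    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Dict({"vec": spaces.Box(low=-1, high=1, shape=(4,))})
        self.action_space = spaces.Discrete(2)

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        return self.observation_space.sample(), 1.0, False, {}


try:
    A2C("MlpPolicy", DictObsEnv())
except ValueError as e:
    print(e)  # You must use `MultiInputPolicy` when working with dict observation space, not MlpPolicy
```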

stable_baselines3/common/distributions.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -3,7 +3,6 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
 
-import gym
 import numpy as np
 import torch as th
 from gym import spaces
@@ -659,7 +658,7 @@ def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
 
 
 def make_proba_distribution(
-    action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
+    action_space: spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
 ) -> Distribution:
     """
     Return an instance of Distribution for the correct type of action space
```
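
A sketch of calling the helper with the new annotation (mapping a `Box` action space to a Gaussian distribution is standard SB3 behavior, though not shown in this diff):

```python
from gym import spaces

from stable_baselines3.common.distributions import make_proba_distribution

dist = make_proba_distribution(spaces.Box(low=-1, high=1, shape=(3,)))
print(type(dist).__name__)  # DiagGaussianDistribution for a Box action space
```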

stable_baselines3/common/env_checker.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -266,7 +266,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
 
 def _check_spaces(env: gym.Env) -> None:
     """
-    Check that the observation and action spaces are defined and inherit from gym.spaces.Space. For
+    Check that the observation and action spaces are defined and inherit from spaces.Space. For
     envs that follow the goal-conditioned standard (previously, the gym.GoalEnv interface) we check
     the observation space is gym.spaces.Dict
     """
```

stable_baselines3/common/envs/identity_env.py

Lines changed: 21 additions & 19 deletions

```diff
@@ -1,14 +1,16 @@
-from typing import Optional, Union
+from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union
 
+import gym
 import numpy as np
-from gym import Env, Space
-from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete
+from gym import spaces
 
 from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
 
+T = TypeVar("T", int, np.ndarray)
 
-class IdentityEnv(Env):
-    def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_length: int = 100):
+
+class IdentityEnv(gym.Env, Generic[T]):
+    def __init__(self, dim: Optional[int] = None, space: Optional[spaces.Space] = None, ep_length: int = 100):
         """
         Identity environment for testing purposes
 
@@ -22,7 +24,7 @@ def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_
         if space is None:
             if dim is None:
                 dim = 1
-            space = Discrete(dim)
+            space = spaces.Discrete(dim)
         else:
             assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed"
 
@@ -32,13 +34,13 @@ def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_
         self.num_resets = -1  # Becomes 0 after __init__ exits.
         self.reset()
 
-    def reset(self) -> GymObs:
+    def reset(self) -> T:
         self.current_step = 0
         self.num_resets += 1
         self._choose_next_state()
         return self.state
 
-    def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
+    def step(self, action: T) -> Tuple[T, float, bool, Dict[str, Any]]:
         reward = self._get_reward(action)
         self._choose_next_state()
         self.current_step += 1
@@ -48,14 +50,14 @@ def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
     def _choose_next_state(self) -> None:
         self.state = self.action_space.sample()
 
-    def _get_reward(self, action: Union[int, np.ndarray]) -> float:
+    def _get_reward(self, action: T) -> float:
         return 1.0 if np.all(self.state == action) else 0.0
 
     def render(self, mode: str = "human") -> None:
         pass
 
 
-class IdentityEnvBox(IdentityEnv):
+class IdentityEnvBox(IdentityEnv[np.ndarray]):
     def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_length: int = 100):
         """
         Identity environment for testing purposes
@@ -65,7 +67,7 @@ def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_l
         :param eps: the epsilon bound for correct value
         :param ep_length: the length of each episode in timesteps
         """
-        space = Box(low=low, high=high, shape=(1,), dtype=np.float32)
+        space = spaces.Box(low=low, high=high, shape=(1,), dtype=np.float32)
         super().__init__(ep_length=ep_length, space=space)
         self.eps = eps
 
@@ -80,31 +82,31 @@ def _get_reward(self, action: np.ndarray) -> float:
         return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0
 
 
-class IdentityEnvMultiDiscrete(IdentityEnv):
+class IdentityEnvMultiDiscrete(IdentityEnv[np.ndarray]):
     def __init__(self, dim: int = 1, ep_length: int = 100):
         """
         Identity environment for testing purposes
 
         :param dim: the size of the dimensions you want to learn
         :param ep_length: the length of each episode in timesteps
         """
-        space = MultiDiscrete([dim, dim])
+        space = spaces.MultiDiscrete([dim, dim])
        super().__init__(ep_length=ep_length, space=space)
 
 
-class IdentityEnvMultiBinary(IdentityEnv):
+class IdentityEnvMultiBinary(IdentityEnv[np.ndarray]):
     def __init__(self, dim: int = 1, ep_length: int = 100):
         """
         Identity environment for testing purposes
 
         :param dim: the size of the dimensions you want to learn
         :param ep_length: the length of each episode in timesteps
         """
-        space = MultiBinary(dim)
+        space = spaces.MultiBinary(dim)
         super().__init__(ep_length=ep_length, space=space)
 
 
-class FakeImageEnv(Env):
+class FakeImageEnv(gym.Env):
     """
     Fake image environment for testing purposes, it mimics Atari games.
 
@@ -128,11 +130,11 @@ def __init__(
         self.observation_shape = (screen_height, screen_width, n_channels)
         if channel_first:
             self.observation_shape = (n_channels, screen_height, screen_width)
-        self.observation_space = Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)
+        self.observation_space = spaces.Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)
         if discrete:
-            self.action_space = Discrete(action_dim)
+            self.action_space = spaces.Discrete(action_dim)
         else:
-            self.action_space = Box(low=-1, high=1, shape=(5,), dtype=np.float32)
+            self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)
         self.ep_length = 10
         self.current_step = 0
```
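
The `Generic[T]` refactor mostly helps type checkers; a sketch of what subclassing buys (assuming these envs are importable from `stable_baselines3.common.envs`, which this diff does not show):

```python
import numpy as np

from stable_baselines3.common.envs import IdentityEnv, IdentityEnvBox

env = IdentityEnvBox()  # an IdentityEnv[np.ndarray]: observations are arrays
obs = env.reset()       # type checkers can now infer np.ndarray here
obs, reward, done, info = env.step(env.action_space.sample())
assert isinstance(obs, np.ndarray)

discrete_env = IdentityEnv(dim=3)  # Discrete space: observations are plain ints
```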
