Skip to content

Commit 489b1fd

Browse files
sidney-tioqgallouedecaraffin
authored
Add the argument dtype (default to float32) to the noise (#1301)
* Fixed noise to return float32 * Updated changelog * Fixed test to use numpy arrays instead of python floats * Sorted imports for tests * Added dtype to constructor * Removed dtype parameter for VectorizedActionNoise * __init__ -> None; Capitalize and period in docstring when needed; fix dtype type hint; dtype in docstring * fix dtype type hint * Update version * Clarify changelog [skip ci] * empty commit to run ci * Update docs/misc/changelog.rst --------- Co-authored-by: Quentin Gallouédec <[email protected]> Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 2e4a450 commit 489b1fd

File tree

3 files changed

+30
-20
lines changed

3 files changed

+30
-20
lines changed

docs/misc/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ New Features:
2727
Bug Fixes:
2828
^^^^^^^^^^
2929
- Fixed Atari wrapper that missed the reset condition (@luizapozzobon)
30+
- Added the argument ``dtype`` (default to ``float32``) to the noise for consistency with gym action (@sidney-tio)
3031
- Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly)
3132

3233
Deprecations:

stable_baselines3/common/noise.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Iterable, List, Optional
44

55
import numpy as np
6+
from numpy.typing import DTypeLike
67

78

89
class ActionNoise(ABC):
@@ -15,7 +16,7 @@ def __init__(self) -> None:
1516

1617
def reset(self) -> None:
1718
"""
18-
call end of episode reset for the noise
19+
Call end of episode reset for the noise
1920
"""
2021
pass
2122

@@ -26,19 +27,21 @@ def __call__(self) -> np.ndarray:
2627

2728
class NormalActionNoise(ActionNoise):
2829
"""
29-
A Gaussian action noise
30+
A Gaussian action noise.
3031
31-
:param mean: the mean value of the noise
32-
:param sigma: the scale of the noise (std here)
32+
:param mean: Mean value of the noise
33+
:param sigma: Scale of the noise (std here)
34+
:param dtype: Type of the output noise
3335
"""
3436

35-
def __init__(self, mean: np.ndarray, sigma: np.ndarray):
37+
def __init__(self, mean: np.ndarray, sigma: np.ndarray, dtype: DTypeLike = np.float32) -> None:
3638
self._mu = mean
3739
self._sigma = sigma
40+
self._dtype = dtype
3841
super().__init__()
3942

4043
def __call__(self) -> np.ndarray:
41-
return np.random.normal(self._mu, self._sigma)
44+
return np.random.normal(self._mu, self._sigma).astype(self._dtype)
4245

4346
def __repr__(self) -> str:
4447
return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"
@@ -50,11 +53,12 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
5053
5154
Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
5255
53-
:param mean: the mean of the noise
54-
:param sigma: the scale of the noise
55-
:param theta: the rate of mean reversion
56-
:param dt: the timestep for the noise
57-
:param initial_noise: the initial value for the noise output, (if None: 0)
56+
:param mean: Mean of the noise
57+
:param sigma: Scale of the noise
58+
:param theta: Rate of mean reversion
59+
:param dt: Timestep for the noise
60+
:param initial_noise: Initial value for the noise output, (if None: 0)
61+
:param dtype: Type of the output noise
5862
"""
5963

6064
def __init__(
@@ -64,11 +68,13 @@ def __init__(
6468
theta: float = 0.15,
6569
dt: float = 1e-2,
6670
initial_noise: Optional[np.ndarray] = None,
67-
):
71+
dtype: DTypeLike = np.float32,
72+
) -> None:
6873
self._theta = theta
6974
self._mu = mean
7075
self._sigma = sigma
7176
self._dt = dt
77+
self._dtype = dtype
7278
self.initial_noise = initial_noise
7379
self.noise_prev = np.zeros_like(self._mu)
7480
self.reset()
@@ -81,7 +87,7 @@ def __call__(self) -> np.ndarray:
8187
+ self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
8288
)
8389
self.noise_prev = noise
84-
return noise
90+
return noise.astype(self._dtype)
8591

8692
def reset(self) -> None:
8793
"""
@@ -97,11 +103,11 @@ class VectorizedActionNoise(ActionNoise):
97103
"""
98104
A Vectorized action noise for parallel environments.
99105
100-
:param base_noise: ActionNoise The noise generator to use
101-
:param n_envs: The number of parallel environments
106+
:param base_noise: Noise generator to use
107+
:param n_envs: Number of parallel environments
102108
"""
103109

104-
def __init__(self, base_noise: ActionNoise, n_envs: int):
110+
def __init__(self, base_noise: ActionNoise, n_envs: int) -> None:
105111
try:
106112
self.n_envs = int(n_envs)
107113
assert self.n_envs > 0
@@ -113,9 +119,9 @@ def __init__(self, base_noise: ActionNoise, n_envs: int):
113119

114120
def reset(self, indices: Optional[Iterable[int]] = None) -> None:
115121
"""
116-
Reset all the noise processes, or those listed in indices
122+
Reset all the noise processes, or those listed in indices.
117123
118-
:param indices: Optional[Iterable[int]] The indices to reset. Default: None.
124+
:param indices: The indices to reset. Default: None.
119125
If the parameter is None, then all processes are reset to their initial position.
120126
"""
121127
if indices is None:
@@ -129,7 +135,7 @@ def __repr__(self) -> str:
129135

130136
def __call__(self) -> np.ndarray:
131137
"""
132-
Generate and stack the action noise from each noise object
138+
Generate and stack the action noise from each noise object.
133139
"""
134140
noise = np.stack([noise() for noise in self.noises])
135141
return noise

tests/test_deterministic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import pytest
23

34
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
@@ -15,7 +16,9 @@ def test_deterministic_training_common(algo):
1516
kwargs = {"policy_kwargs": dict(net_arch=[64])}
1617
env_id = "Pendulum-v1"
1718
if algo in [TD3, SAC]:
18-
kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4})
19+
kwargs.update(
20+
{"action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)), "learning_starts": 100, "train_freq": 4}
21+
)
1922
else:
2023
if algo == DQN:
2124
env_id = "CartPole-v1"

0 commit comments

Comments
 (0)