3
3
from typing import Iterable , List , Optional
4
4
5
5
import numpy as np
6
+ from numpy .typing import DTypeLike
6
7
7
8
8
9
class ActionNoise (ABC ):
@@ -15,7 +16,7 @@ def __init__(self) -> None:
15
16
16
17
def reset (self ) -> None :
17
18
"""
18
- call end of episode reset for the noise
19
+ Call end of episode reset for the noise
19
20
"""
20
21
pass
21
22
@@ -26,19 +27,21 @@ def __call__(self) -> np.ndarray:
26
27
27
28
class NormalActionNoise (ActionNoise ):
28
29
"""
29
- A Gaussian action noise
30
+ A Gaussian action noise.
30
31
31
- :param mean: the mean value of the noise
32
- :param sigma: the scale of the noise (std here)
32
+ :param mean: Mean value of the noise
33
+ :param sigma: Scale of the noise (std here)
34
+ :param dtype: Type of the output noise
33
35
"""
34
36
35
- def __init__ (self , mean : np .ndarray , sigma : np .ndarray ) :
37
+ def __init__ (self , mean : np .ndarray , sigma : np .ndarray , dtype : DTypeLike = np . float32 ) -> None :
36
38
self ._mu = mean
37
39
self ._sigma = sigma
40
+ self ._dtype = dtype
38
41
super ().__init__ ()
39
42
40
43
def __call__ (self ) -> np .ndarray :
41
- return np .random .normal (self ._mu , self ._sigma )
44
+ return np .random .normal (self ._mu , self ._sigma ). astype ( self . _dtype )
42
45
43
46
def __repr__ (self ) -> str :
44
47
return f"NormalActionNoise(mu={ self ._mu } , sigma={ self ._sigma } )"
@@ -50,11 +53,12 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
50
53
51
54
Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
52
55
53
- :param mean: the mean of the noise
54
- :param sigma: the scale of the noise
55
- :param theta: the rate of mean reversion
56
- :param dt: the timestep for the noise
57
- :param initial_noise: the initial value for the noise output, (if None: 0)
56
+ :param mean: Mean of the noise
57
+ :param sigma: Scale of the noise
58
+ :param theta: Rate of mean reversion
59
+ :param dt: Timestep for the noise
60
+ :param initial_noise: Initial value for the noise output, (if None: 0)
61
+ :param dtype: Type of the output noise
58
62
"""
59
63
60
64
def __init__ (
@@ -64,11 +68,13 @@ def __init__(
64
68
theta : float = 0.15 ,
65
69
dt : float = 1e-2 ,
66
70
initial_noise : Optional [np .ndarray ] = None ,
67
- ):
71
+ dtype : DTypeLike = np .float32 ,
72
+ ) -> None :
68
73
self ._theta = theta
69
74
self ._mu = mean
70
75
self ._sigma = sigma
71
76
self ._dt = dt
77
+ self ._dtype = dtype
72
78
self .initial_noise = initial_noise
73
79
self .noise_prev = np .zeros_like (self ._mu )
74
80
self .reset ()
@@ -81,7 +87,7 @@ def __call__(self) -> np.ndarray:
81
87
+ self ._sigma * np .sqrt (self ._dt ) * np .random .normal (size = self ._mu .shape )
82
88
)
83
89
self .noise_prev = noise
84
- return noise
90
+ return noise . astype ( self . _dtype )
85
91
86
92
def reset (self ) -> None :
87
93
"""
@@ -97,11 +103,11 @@ class VectorizedActionNoise(ActionNoise):
97
103
"""
98
104
A Vectorized action noise for parallel environments.
99
105
100
- :param base_noise: ActionNoise The noise generator to use
101
- :param n_envs: The number of parallel environments
106
+ :param base_noise: Noise generator to use
107
+ :param n_envs: Number of parallel environments
102
108
"""
103
109
104
- def __init__ (self , base_noise : ActionNoise , n_envs : int ):
110
+ def __init__ (self , base_noise : ActionNoise , n_envs : int ) -> None :
105
111
try :
106
112
self .n_envs = int (n_envs )
107
113
assert self .n_envs > 0
@@ -113,9 +119,9 @@ def __init__(self, base_noise: ActionNoise, n_envs: int):
113
119
114
120
def reset (self , indices : Optional [Iterable [int ]] = None ) -> None :
115
121
"""
116
- Reset all the noise processes, or those listed in indices
122
+ Reset all the noise processes, or those listed in indices.
117
123
118
- :param indices: Optional[Iterable[int]] The indices to reset. Default: None.
124
+ :param indices: The indices to reset. Default: None.
119
125
If the parameter is None, then all processes are reset to their initial position.
120
126
"""
121
127
if indices is None :
@@ -129,7 +135,7 @@ def __repr__(self) -> str:
129
135
130
136
def __call__ (self ) -> np .ndarray :
131
137
"""
132
- Generate and stack the action noise from each noise object
138
+ Generate and stack the action noise from each noise object.
133
139
"""
134
140
noise = np .stack ([noise () for noise in self .noises ])
135
141
return noise
0 commit comments