@@ -19,13 +19,27 @@ class UniformSampler(Sampler):
19
19
"""
20
20
21
21
def __init__ (
22
- self , min_value : Union [int , float ], max_value : Union [int , float ], ** kwargs
22
+ self ,
23
+ min_value : Union [int , float ],
24
+ max_value : Union [int , float ],
25
+ seed : Optional [int ] = None ,
26
+ ** kwargs
23
27
) -> None :
28
+ """
29
+ :param min_value: minimum value of the range to be sampled uniformly from
30
+ :param max_value: maximum value of the range to be sampled uniformly from
31
+ :param seed: Random seed used for making draws from the uniform sampler
32
+ """
24
33
self .min_value = min_value
25
34
self .max_value = max_value
35
+ # Draw from random state to allow for consistent reset parameter draw for a seed
36
+ self .random_state = np .random .RandomState (seed )
26
37
27
38
def sample_parameter (self ) -> float :
28
- return np .random .uniform (self .min_value , self .max_value )
39
+ """
40
+ Draws and returns a sample from the specified interval
41
+ """
42
+ return self .random_state .uniform (self .min_value , self .max_value )
29
43
30
44
31
45
class MultiRangeUniformSampler (Sampler ):
@@ -36,19 +50,33 @@ class MultiRangeUniformSampler(Sampler):
36
50
it proceeds to pick a value uniformly in that range.
37
51
"""
38
52
39
- def __init__ (self , intervals : List [List [Union [int , float ]]], ** kwargs ) -> None :
53
+ def __init__ (
54
+ self ,
55
+ intervals : List [List [Union [int , float ]]],
56
+ seed : Optional [int ] = None ,
57
+ ** kwargs
58
+ ) -> None :
59
+ """
60
+ :param intervals: List of intervals to draw uniform samples from
61
+ :param seed: Random seed used for making uniform draws from the specified intervals
62
+ """
40
63
self .intervals = intervals
41
64
# Measure the length of the intervals
42
65
interval_lengths = [abs (x [1 ] - x [0 ]) for x in self .intervals ]
43
66
cum_interval_length = sum (interval_lengths )
44
67
# Assign weights to an interval proportionate to the interval size
45
68
self .interval_weights = [x / cum_interval_length for x in interval_lengths ]
69
+ # Draw from random state to allow for consistent reset parameter draw for a seed
70
+ self .random_state = np .random .RandomState (seed )
46
71
47
72
def sample_parameter (self ) -> float :
73
+ """
74
+ Selects an interval to pick and then draws a uniform sample from the picked interval
75
+ """
48
76
cur_min , cur_max = self .intervals [
49
- np . random .choice (len (self .intervals ), p = self .interval_weights )
77
+ self . random_state .choice (len (self .intervals ), p = self .interval_weights )
50
78
]
51
- return np . random .uniform (cur_min , cur_max )
79
+ return self . random_state .uniform (cur_min , cur_max )
52
80
53
81
54
82
class GaussianSampler (Sampler ):
@@ -58,13 +86,27 @@ class GaussianSampler(Sampler):
58
86
"""
59
87
60
88
def __init__ (
61
- self , mean : Union [float , int ], st_dev : Union [float , int ], ** kwargs
89
+ self ,
90
+ mean : Union [float , int ],
91
+ st_dev : Union [float , int ],
92
+ seed : Optional [int ] = None ,
93
+ ** kwargs
62
94
) -> None :
95
+ """
96
+ :param mean: Specifies the mean of the gaussian distribution to draw from
97
+ :param st_dev: Specifies the standard devation of the gaussian distribution to draw from
98
+ :param seed: Random seed used for making gaussian draws from the sample
99
+ """
63
100
self .mean = mean
64
101
self .st_dev = st_dev
102
+ # Draw from random state to allow for consistent reset parameter draw for a seed
103
+ self .random_state = np .random .RandomState (seed )
65
104
66
105
def sample_parameter (self ) -> float :
67
- return np .random .normal (self .mean , self .st_dev )
106
+ """
107
+ Returns a draw from the specified Gaussian distribution
108
+ """
109
+ return self .random_state .normal (self .mean , self .st_dev )
68
110
69
111
70
112
class SamplerFactory :
@@ -81,17 +123,31 @@ class SamplerFactory:
81
123
82
124
@staticmethod
83
125
def register_sampler (name : str , sampler_cls : Type [Sampler ]) -> None :
126
+ """
127
+ Registers the sampe in the Sampler Factory to be used later
128
+ :param name: String name to set as key for the sampler_cls in the factory
129
+ :param sampler_cls: Sampler object to associate to the name in the factory
130
+ """
84
131
SamplerFactory .NAME_TO_CLASS [name ] = sampler_cls
85
132
86
133
@staticmethod
87
- def init_sampler_class (name : str , params : Dict [str , Any ]):
134
+ def init_sampler_class (
135
+ name : str , params : Dict [str , Any ], seed : Optional [int ] = None
136
+ ) -> Sampler :
137
+ """
138
+ Initializes the sampler class associated with the name with the params
139
+ :param name: Name of the sampler in the factory to initialize
140
+ :param params: Parameters associated to the sampler attached to the name
141
+ :param seed: Random seed to be used to set deterministic random draws for the sampler
142
+ """
88
143
if name not in SamplerFactory .NAME_TO_CLASS :
89
144
raise SamplerException (
90
145
name + " sampler is not registered in the SamplerFactory."
91
146
" Use the register_sample method to register the string"
92
147
" associated to your sampler in the SamplerFactory."
93
148
)
94
149
sampler_cls = SamplerFactory .NAME_TO_CLASS [name ]
150
+ params ["seed" ] = seed
95
151
try :
96
152
return sampler_cls (** params )
97
153
except TypeError :
@@ -103,7 +159,13 @@ def init_sampler_class(name: str, params: Dict[str, Any]):
103
159
104
160
105
161
class SamplerManager :
106
- def __init__ (self , reset_param_dict : Dict [str , Any ]) -> None :
162
+ def __init__ (
163
+ self , reset_param_dict : Dict [str , Any ], seed : Optional [int ] = None
164
+ ) -> None :
165
+ """
166
+ :param reset_param_dict: Arguments needed for initializing the samplers
167
+ :param seed: Random seed to be used for drawing samples from the samplers
168
+ """
107
169
self .reset_param_dict = reset_param_dict if reset_param_dict else {}
108
170
assert isinstance (self .reset_param_dict , dict )
109
171
self .samplers : Dict [str , Sampler ] = {}
@@ -116,7 +178,7 @@ def __init__(self, reset_param_dict: Dict[str, Any]) -> None:
116
178
)
117
179
sampler_name = cur_param_dict .pop ("sampler-type" )
118
180
param_sampler = SamplerFactory .init_sampler_class (
119
- sampler_name , cur_param_dict
181
+ sampler_name , cur_param_dict , seed
120
182
)
121
183
122
184
self .samplers [param_name ] = param_sampler
@@ -128,6 +190,10 @@ def is_empty(self) -> bool:
128
190
return not bool (self .samplers )
129
191
130
192
def sample_all (self ) -> Dict [str , float ]:
193
+ """
194
+ Loop over all samplers and draw a sample from each one for generating
195
+ next set of reset parameter values.
196
+ """
131
197
res = {}
132
198
for param_name , param_sampler in list (self .samplers .items ()):
133
199
res [param_name ] = param_sampler .sample_parameter ()
0 commit comments