Skip to content

Commit 2a76490

Browse files
sankalp04 authored and Ervin T committed
Change samplers to use random state to allow consistency in reset par… (#2398)
* Change samplers to use random state to allow consistency in reset parameter draws for a specified seed
1 parent b243314 commit 2a76490

File tree

2 files changed

+79
-13
lines changed

2 files changed

+79
-13
lines changed

ml-agents-envs/mlagents/envs/sampler_class.py

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,27 @@ class UniformSampler(Sampler):
1919
"""
2020

2121
def __init__(
22-
self, min_value: Union[int, float], max_value: Union[int, float], **kwargs
22+
self,
23+
min_value: Union[int, float],
24+
max_value: Union[int, float],
25+
seed: Optional[int] = None,
26+
**kwargs
2327
) -> None:
28+
"""
29+
:param min_value: minimum value of the range to be sampled uniformly from
30+
:param max_value: maximum value of the range to be sampled uniformly from
31+
:param seed: Random seed used for making draws from the uniform sampler
32+
"""
2433
self.min_value = min_value
2534
self.max_value = max_value
35+
# Draw from random state to allow for consistent reset parameter draw for a seed
36+
self.random_state = np.random.RandomState(seed)
2637

2738
def sample_parameter(self) -> float:
28-
return np.random.uniform(self.min_value, self.max_value)
39+
"""
40+
Draws and returns a sample from the specified interval
41+
"""
42+
return self.random_state.uniform(self.min_value, self.max_value)
2943

3044

3145
class MultiRangeUniformSampler(Sampler):
@@ -36,19 +50,33 @@ class MultiRangeUniformSampler(Sampler):
3650
it proceeds to pick a value uniformly in that range.
3751
"""
3852

39-
def __init__(self, intervals: List[List[Union[int, float]]], **kwargs) -> None:
53+
def __init__(
54+
self,
55+
intervals: List[List[Union[int, float]]],
56+
seed: Optional[int] = None,
57+
**kwargs
58+
) -> None:
59+
"""
60+
:param intervals: List of intervals to draw uniform samples from
61+
:param seed: Random seed used for making uniform draws from the specified intervals
62+
"""
4063
self.intervals = intervals
4164
# Measure the length of the intervals
4265
interval_lengths = [abs(x[1] - x[0]) for x in self.intervals]
4366
cum_interval_length = sum(interval_lengths)
4467
# Assign weights to an interval proportionate to the interval size
4568
self.interval_weights = [x / cum_interval_length for x in interval_lengths]
69+
# Draw from random state to allow for consistent reset parameter draw for a seed
70+
self.random_state = np.random.RandomState(seed)
4671

4772
def sample_parameter(self) -> float:
73+
"""
74+
Selects an interval to pick and then draws a uniform sample from the picked interval
75+
"""
4876
cur_min, cur_max = self.intervals[
49-
np.random.choice(len(self.intervals), p=self.interval_weights)
77+
self.random_state.choice(len(self.intervals), p=self.interval_weights)
5078
]
51-
return np.random.uniform(cur_min, cur_max)
79+
return self.random_state.uniform(cur_min, cur_max)
5280

5381

5482
class GaussianSampler(Sampler):
@@ -58,13 +86,27 @@ class GaussianSampler(Sampler):
5886
"""
5987

6088
def __init__(
61-
self, mean: Union[float, int], st_dev: Union[float, int], **kwargs
89+
self,
90+
mean: Union[float, int],
91+
st_dev: Union[float, int],
92+
seed: Optional[int] = None,
93+
**kwargs
6294
) -> None:
95+
"""
96+
:param mean: Specifies the mean of the gaussian distribution to draw from
97+
:param st_dev: Specifies the standard deviation of the gaussian distribution to draw from
98+
:param seed: Random seed used for making gaussian draws from the sample
99+
"""
63100
self.mean = mean
64101
self.st_dev = st_dev
102+
# Draw from random state to allow for consistent reset parameter draw for a seed
103+
self.random_state = np.random.RandomState(seed)
65104

66105
def sample_parameter(self) -> float:
67-
return np.random.normal(self.mean, self.st_dev)
106+
"""
107+
Returns a draw from the specified Gaussian distribution
108+
"""
109+
return self.random_state.normal(self.mean, self.st_dev)
68110

69111

70112
class SamplerFactory:
@@ -81,17 +123,31 @@ class SamplerFactory:
81123

82124
@staticmethod
83125
def register_sampler(name: str, sampler_cls: Type[Sampler]) -> None:
126+
"""
127+
Registers the sampler in the Sampler Factory to be used later
128+
:param name: String name to set as key for the sampler_cls in the factory
129+
:param sampler_cls: Sampler class to associate with the name in the factory
130+
"""
84131
SamplerFactory.NAME_TO_CLASS[name] = sampler_cls
85132

86133
@staticmethod
87-
def init_sampler_class(name: str, params: Dict[str, Any]):
134+
def init_sampler_class(
135+
name: str, params: Dict[str, Any], seed: Optional[int] = None
136+
) -> Sampler:
137+
"""
138+
Initializes the sampler class associated with the name with the params
139+
:param name: Name of the sampler in the factory to initialize
140+
:param params: Parameters associated to the sampler attached to the name
141+
:param seed: Random seed to be used to set deterministic random draws for the sampler
142+
"""
88143
if name not in SamplerFactory.NAME_TO_CLASS:
89144
raise SamplerException(
90145
name + " sampler is not registered in the SamplerFactory."
91146
" Use the register_sample method to register the string"
92147
" associated to your sampler in the SamplerFactory."
93148
)
94149
sampler_cls = SamplerFactory.NAME_TO_CLASS[name]
150+
params["seed"] = seed
95151
try:
96152
return sampler_cls(**params)
97153
except TypeError:
@@ -103,7 +159,13 @@ def init_sampler_class(name: str, params: Dict[str, Any]):
103159

104160

105161
class SamplerManager:
106-
def __init__(self, reset_param_dict: Dict[str, Any]) -> None:
162+
def __init__(
163+
self, reset_param_dict: Dict[str, Any], seed: Optional[int] = None
164+
) -> None:
165+
"""
166+
:param reset_param_dict: Arguments needed for initializing the samplers
167+
:param seed: Random seed to be used for drawing samples from the samplers
168+
"""
107169
self.reset_param_dict = reset_param_dict if reset_param_dict else {}
108170
assert isinstance(self.reset_param_dict, dict)
109171
self.samplers: Dict[str, Sampler] = {}
@@ -116,7 +178,7 @@ def __init__(self, reset_param_dict: Dict[str, Any]) -> None:
116178
)
117179
sampler_name = cur_param_dict.pop("sampler-type")
118180
param_sampler = SamplerFactory.init_sampler_class(
119-
sampler_name, cur_param_dict
181+
sampler_name, cur_param_dict, seed
120182
)
121183

122184
self.samplers[param_name] = param_sampler
@@ -128,6 +190,10 @@ def is_empty(self) -> bool:
128190
return not bool(self.samplers)
129191

130192
def sample_all(self) -> Dict[str, float]:
193+
"""
194+
Loop over all samplers and draw a sample from each one for generating
195+
next set of reset parameter values.
196+
"""
131197
res = {}
132198
for param_name, param_sampler in list(self.samplers.items()):
133199
res[param_name] = param_sampler.sample_parameter()

ml-agents/mlagents/trainers/learn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def run_training(
9191
env = SubprocessEnvManager(env_factory, num_envs)
9292
maybe_meta_curriculum = try_create_meta_curriculum(curriculum_folder, env)
9393
sampler_manager, resampling_interval = create_sampler_manager(
94-
sampler_file_path, env.reset_parameters
94+
sampler_file_path, env.reset_parameters, run_seed
9595
)
9696

9797
# Create controller and begin training.
@@ -118,7 +118,7 @@ def run_training(
118118
tc.start_learning(env, trainer_config)
119119

120120

121-
def create_sampler_manager(sampler_file_path, env_reset_params):
121+
def create_sampler_manager(sampler_file_path, env_reset_params, run_seed=None):
122122
sampler_config = None
123123
resample_interval = None
124124
if sampler_file_path is not None:
@@ -136,7 +136,7 @@ def create_sampler_manager(sampler_file_path, env_reset_params):
136136
"Resampling interval was not specified in the sampler file."
137137
" Please specify it with the 'resampling-interval' key in the sampler config file."
138138
)
139-
sampler_manager = SamplerManager(sampler_config)
139+
sampler_manager = SamplerManager(sampler_config, run_seed)
140140
return sampler_manager, resample_interval
141141

142142

0 commit comments

Comments
 (0)