Skip to content

Commit 5d7dd57

Browse files
sankalp04Ervin T
authored andcommitted
Enable generalization training (#2232)
* Add Sampler and SamplerManager * Enable resampling of reset parameters during training * Documentation for Sampler and example YAML configuration file
1 parent b7dcda6 commit 5d7dd57

File tree

14 files changed

+525
-43
lines changed

14 files changed

+525
-43
lines changed

config/generalize_test.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
resampling-interval: 5000
2+
3+
mass:
4+
sampler-type: "uniform"
5+
min_value: 0.5
6+
max_value: 10
7+
8+
gravity:
9+
sampler-type: "uniform"
10+
min_value: 7
11+
max_value: 12
12+
13+
scale:
14+
sampler-type: "uniform"
15+
min_value: 0.75
16+
max_value: 3
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Training Generalized Reinforcement Learning Agents
2+
3+
Reinforcement learning has a rather unique setup as opposed to supervised and
4+
unsupervised learning. Agents here are trained and tested on the same exact
5+
environment, which is analogous to a model being trained and tested on an
6+
identical dataset in supervised learning! This setting results in overfitting;
7+
the inability of the agent to generalize to slight tweaks or variations in the
8+
environment. This is problematic in instances when environments are randomly
9+
instantiated with varying properties. To make agents robust, one approach is to
10+
train an agent over multiple variations of the environment. The agent is
11+
trained in this approach with the intent that it learns to adapt its performance
12+
to future unseen variations of the environment.
13+
14+
Ball scale of 0.5 | Ball scale of 4
15+
:-------------------------:|:-------------------------:
16+
![](images/3dball_small.png) | ![](images/3dball_big.png)
17+
18+
_Variations of the 3D Ball environment._
19+
20+
To vary environments, we first decide what parameters to vary in an
21+
environment. These parameters are known as `Reset Parameters`. In the 3D ball
22+
environment example displayed in the figure above, the reset parameters are `gravity`, `ball_mass` and `ball_scale`.
23+
24+
25+
## How-to
26+
27+
For generalization training, we need to provide a way to modify the environment
28+
by supplying a set of reset parameters, and vary them over time. This provision
29+
can be done either deterministically or randomly.
30+
31+
This is done by assigning each reset parameter a sampler, which samples a reset
32+
parameter value (such as a uniform sampler). If a sampler isn't provided for a
33+
reset parameter, the parameter maintains the default value throughout the
34+
training, remaining unchanged. The samplers for all the reset parameters are
35+
handled by a **Sampler Manager**, which also handles the generation of new
36+
values for the reset parameters when needed.
37+
38+
To setup the Sampler Manager, we setup a YAML file that specifies how we wish to
39+
generate new samples. In this file, we specify the samplers and the
40+
`resampling-duration` (number of simulation steps after which reset parameters are
41+
resampled). Below is an example of a sampler file for the 3D ball environment.
42+
43+
```yaml
44+
episode-length: 5000
45+
46+
mass:
47+
sampler-type: "uniform"
48+
min_value: 0.5
49+
max_value: 10
50+
51+
gravity:
52+
sampler-type: "multirange_uniform"
53+
intervals: [[7, 10], [15, 20]]
54+
55+
scale:
56+
sampler-type: "uniform"
57+
min_value: 0.75
58+
max_value: 3
59+
60+
```
61+
62+
* `resampling-duration` (int) - Specifies the number of steps for agent to
63+
train under a particular environment configuration before resetting the
64+
environment with a new sample of reset parameters.
65+
66+
* `parameter_name` - Name of the reset parameter. This should match the name
67+
specified in the academy of the intended environment for which the agent is
68+
being trained. If a parameter specified in the file doesn't exist in the
69+
environment, then this specification will be ignored.
70+
71+
* `sampler-type` - Specify the sampler type to use for the reset parameter.
72+
This is a string that should exist in the `Sampler Factory` (explained
73+
below).
74+
75+
* `sub-arguments` - Specify the characteristic parameters for the sampler.
76+
In the example sampler file above, this would correspond to the `intervals`
77+
key under the `multirange_uniform` sampler for the gravity reset parameter.
78+
The key name should match the name of the corresponding argument in the sampler definition. (Look at defining a new sampler method)
79+
80+
The sampler manager allocates a sampler for a reset parameter by using the *Sampler Factory*, which maintains a dictionary mapping of string keys to sampler objects. The available samplers to be used for reset parameter resampling is as available in the Sampler Factory.
81+
82+
The implementation of the samplers can be found at `ml-agents-envs/mlagents/envs/sampler_class.py`.
83+
84+
### Defining a new sampler method
85+
86+
Custom sampling techniques must inherit from the *Sampler* base class (included in the `sampler_class` file) and preserve the interface. Once the class for the required method is specified, it must be registered in the Sampler Factory.
87+
88+
This can be done by subscribing to the *register_sampler* method of the SamplerFactory. The command is as follows:
89+
90+
`SamplerFactory.register_sampler(*custom_sampler_string_key*, *custom_sampler_object*)`
91+
92+
Once the Sampler Factory reflects the new register, the custom sampler can be used for resampling reset parameter. For demonstration, lets say our sampler was implemented as below, and we register the `CustomSampler` class with the string `custom-sampler` in the Sampler Factory.
93+
94+
```python
95+
class CustomSampler(Sampler):
96+
97+
def __init__(self, argA, argB, argC):
98+
self.possible_vals = [argA, argB, argC]
99+
100+
def sample_all(self):
101+
return np.random.choice(self.possible_vals)
102+
```
103+
104+
Now we need to specify this sampler in the sampler file. Lets say we wish to use this sampler for the reset parameter *mass*; the sampler file would specify the same for mass as the following (any order of the subarguments is valid).
105+
106+
```yaml
107+
mass:
108+
sampler-type: "custom-sampler"
109+
argB: 1
110+
argA: 2
111+
argC: 3
112+
```
113+
114+
With the sampler file setup, we can proceed to train our agent as explained in the next section.
115+
116+
### Training with Generalization Learning
117+
118+
We first begin with setting up the sampler file. After the sampler file is defined and configured, we proceed by launching `mlagents-learn` and specify our configured sampler file with the `--sampler` flag. To demonstrate, if we wanted to train a 3D ball agent with generalization using the `config/generalization-test.yaml` sampling setup, we can run
119+
120+
```sh
121+
mlagents-learn config/trainer_config.yaml --sampler=config/generalize_test.yaml --run-id=3D-Ball-generalization --train
122+
```
123+
124+
We can observe progress and metrics via Tensorboard.

docs/Training-ML-Agents.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ environment, you can set the following command line options when invoking
101101
* `--curriculum=<file>` – Specify a curriculum JSON file for defining the
102102
lessons for curriculum training. See [Curriculum
103103
Training](Training-Curriculum-Learning.md) for more information.
104+
* `--sampler=<file>` - Specify a sampler YAML file for defining the
105+
sampler for generalization training. See [Generalization
106+
Training](Training-Generalization-Learning.md) for more information.
104107
* `--keep-checkpoints=<n>` – Specify the maximum number of model checkpoints to
105108
keep. Checkpoints are saved after the number of steps specified by the
106109
`save-freq` option. Once the maximum number of checkpoints has been reached,
@@ -194,6 +197,7 @@ are conducting, see:
194197
* [Training with PPO](Training-PPO.md)
195198
* [Using Recurrent Neural Networks](Feature-Memory.md)
196199
* [Training with Curriculum Learning](Training-Curriculum-Learning.md)
200+
* [Training with Generalization](Training-Generalization-Learning.md)
197201
* [Training with Imitation Learning](Training-Imitation-Learning.md)
198202

199203
You can also compare the

docs/images/3dball_big.png

196 KB
Loading

docs/images/3dball_small.png

139 KB
Loading

ml-agents-envs/mlagents/envs/exception.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ class UnityActionException(UnityException):
2727
pass
2828

2929

30+
class SamplerException(UnityException):
31+
"""
32+
Related to errors with the sampler actions.
33+
"""
34+
35+
pass
36+
37+
3038
class UnityTimeOutException(UnityException):
3139
"""
3240
Related to errors with communication timeouts.
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import numpy as np
2+
from typing import *
3+
from functools import *
4+
from collections import OrderedDict
5+
from abc import ABC, abstractmethod
6+
7+
from .exception import SamplerException
8+
9+
10+
class Sampler(ABC):
11+
@abstractmethod
12+
def sample_parameter(self) -> float:
13+
pass
14+
15+
16+
class UniformSampler(Sampler):
17+
"""
18+
Uniformly draws a single sample in the range [min_value, max_value).
19+
"""
20+
21+
def __init__(
22+
self, min_value: Union[int, float], max_value: Union[int, float], **kwargs
23+
) -> None:
24+
self.min_value = min_value
25+
self.max_value = max_value
26+
27+
def sample_parameter(self) -> float:
28+
return np.random.uniform(self.min_value, self.max_value)
29+
30+
31+
class MultiRangeUniformSampler(Sampler):
32+
"""
33+
Draws a single sample uniformly from the intervals provided. The sampler
34+
first picks an interval based on a weighted selection, with the weights
35+
assigned to an interval based on its range. After picking the range,
36+
it proceeds to pick a value uniformly in that range.
37+
"""
38+
39+
def __init__(self, intervals: List[List[Union[int, float]]], **kwargs) -> None:
40+
self.intervals = intervals
41+
# Measure the length of the intervals
42+
interval_lengths = [abs(x[1] - x[0]) for x in self.intervals]
43+
cum_interval_length = sum(interval_lengths)
44+
# Assign weights to an interval proportionate to the interval size
45+
self.interval_weights = [x / cum_interval_length for x in interval_lengths]
46+
47+
def sample_parameter(self) -> float:
48+
cur_min, cur_max = self.intervals[
49+
np.random.choice(len(self.intervals), p=self.interval_weights)
50+
]
51+
return np.random.uniform(cur_min, cur_max)
52+
53+
54+
class GaussianSampler(Sampler):
55+
"""
56+
Draw a single sample value from a normal (gaussian) distribution.
57+
This sampler is characterized by the mean and the standard deviation.
58+
"""
59+
60+
def __init__(
61+
self, mean: Union[float, int], st_dev: Union[float, int], **kwargs
62+
) -> None:
63+
self.mean = mean
64+
self.st_dev = st_dev
65+
66+
def sample_parameter(self) -> float:
67+
return np.random.normal(self.mean, self.st_dev)
68+
69+
70+
class SamplerFactory:
71+
"""
72+
Maintain a directory of all samplers available.
73+
Add new samplers using the register_sampler method.
74+
"""
75+
76+
NAME_TO_CLASS = {
77+
"uniform": UniformSampler,
78+
"gaussian": GaussianSampler,
79+
"multirange_uniform": MultiRangeUniformSampler,
80+
}
81+
82+
@staticmethod
83+
def register_sampler(name: str, sampler_cls: Type[Sampler]) -> None:
84+
SamplerFactory.NAME_TO_CLASS[name] = sampler_cls
85+
86+
@staticmethod
87+
def init_sampler_class(name: str, params: Dict[str, Any]):
88+
if name not in SamplerFactory.NAME_TO_CLASS:
89+
raise SamplerException(
90+
name + " sampler is not registered in the SamplerFactory."
91+
" Use the register_sample method to register the string"
92+
" associated to your sampler in the SamplerFactory."
93+
)
94+
sampler_cls = SamplerFactory.NAME_TO_CLASS[name]
95+
try:
96+
return sampler_cls(**params)
97+
except TypeError:
98+
raise SamplerException(
99+
"The sampler class associated to the " + name + " key in the factory "
100+
"was not provided the required arguments. Please ensure that the sampler "
101+
"config file consists of the appropriate keys for this sampler class."
102+
)
103+
104+
105+
class SamplerManager:
106+
def __init__(self, reset_param_dict: Dict[str, Any]) -> None:
107+
self.reset_param_dict = reset_param_dict if reset_param_dict else {}
108+
assert isinstance(self.reset_param_dict, dict)
109+
self.samplers: Dict[str, Sampler] = {}
110+
for param_name, cur_param_dict in self.reset_param_dict.items():
111+
if "sampler-type" not in cur_param_dict:
112+
raise SamplerException(
113+
"'sampler_type' argument hasn't been supplied for the {0} parameter".format(
114+
param_name
115+
)
116+
)
117+
sampler_name = cur_param_dict.pop("sampler-type")
118+
param_sampler = SamplerFactory.init_sampler_class(
119+
sampler_name, cur_param_dict
120+
)
121+
122+
self.samplers[param_name] = param_sampler
123+
124+
def is_empty(self) -> bool:
125+
"""
126+
Check for if sampler_manager is empty.
127+
"""
128+
return not bool(self.samplers)
129+
130+
def sample_all(self) -> Dict[str, float]:
131+
res = {}
132+
for param_name, param_sampler in list(self.samplers.items()):
133+
res[param_name] = param_sampler.sample_parameter()
134+
return res

0 commit comments

Comments
 (0)