Skip to content

Commit ed3689c

Browse files
andrewcohChris Elion
andauthored
add to_string for samplers (#4484)
Co-authored-by: Chris Elion <[email protected]>
1 parent 53c27ee commit ed3689c

File tree

3 files changed

+54
-4
lines changed

3 files changed

+54
-4
lines changed

ml-agents/mlagents/trainers/environment_parameter_manager.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,11 @@ def update_lessons(
128128
lesson_num = GlobalTrainingStatus.get_parameter_state(
129129
param_name, StatusType.LESSON_NUM
130130
)
131+
next_lesson_num = lesson_num + 1
131132
lesson = settings.curriculum[lesson_num]
132133
if (
133134
lesson.completion_criteria is not None
134-
and len(settings.curriculum) > lesson_num + 1
135+
and len(settings.curriculum) > next_lesson_num
135136
):
136137
behavior_to_consider = lesson.completion_criteria.behavior
137138
if behavior_to_consider in trainer_steps:
@@ -144,11 +145,14 @@ def update_lessons(
144145
self._smoothed_values[param_name] = new_smoothing
145146
if must_increment:
146147
GlobalTrainingStatus.set_parameter_state(
147-
param_name, StatusType.LESSON_NUM, lesson_num + 1
148+
param_name, StatusType.LESSON_NUM, next_lesson_num
148149
)
149-
new_lesson_name = settings.curriculum[lesson_num + 1].name
150+
new_lesson_name = settings.curriculum[next_lesson_num].name
151+
new_lesson_value = settings.curriculum[next_lesson_num].value
152+
150153
logger.info(
151-
f"Parameter '{param_name}' has changed. Now in lesson '{new_lesson_name}'"
154+
f"Parameter '{param_name}' has been updated to {new_lesson_value}."
155+
+ f" Now in lesson '{new_lesson_name}'"
152156
)
153157
updated = True
154158
if lesson.completion_criteria.require_reset:

ml-agents/mlagents/trainers/settings.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,12 @@ def to_settings(self) -> type:
244244
class ParameterRandomizationSettings(abc.ABC):
245245
seed: int = parser.get_default("seed")
246246

247+
def __str__(self) -> str:
248+
"""
249+
Helper method to output sampler stats to console.
250+
"""
251+
raise TrainerConfigError(f"__str__ not implemented for type {self.__class__}.")
252+
247253
@staticmethod
248254
def structure(
249255
d: Union[Mapping, float], t: type
@@ -305,6 +311,12 @@ def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None:
305311
class ConstantSettings(ParameterRandomizationSettings):
306312
value: float = 0.0
307313

314+
def __str__(self) -> str:
315+
"""
316+
Helper method to output sampler stats to console.
317+
"""
318+
return f"Float: value={self.value}"
319+
308320
def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None:
309321
"""
310322
Helper method to send sampler settings over EnvironmentParametersChannel
@@ -320,6 +332,12 @@ class UniformSettings(ParameterRandomizationSettings):
320332
min_value: float = attr.ib()
321333
max_value: float = 1.0
322334

335+
def __str__(self) -> str:
336+
"""
337+
Helper method to output sampler stats to console.
338+
"""
339+
return f"Uniform sampler: min={self.min_value}, max={self.max_value}"
340+
323341
@min_value.default
324342
def _min_value_default(self):
325343
return 0.0
@@ -348,6 +366,12 @@ class GaussianSettings(ParameterRandomizationSettings):
348366
mean: float = 1.0
349367
st_dev: float = 1.0
350368

369+
def __str__(self) -> str:
370+
"""
371+
Helper method to output sampler stats to console.
372+
"""
373+
return f"Gaussian sampler: mean={self.mean}, stddev={self.st_dev}"
374+
351375
def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None:
352376
"""
353377
Helper method to send sampler settings over EnvironmentParametersChannel
@@ -364,6 +388,12 @@ def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None:
364388
class MultiRangeUniformSettings(ParameterRandomizationSettings):
365389
intervals: List[Tuple[float, float]] = attr.ib()
366390

391+
def __str__(self) -> str:
392+
"""
393+
Helper method to output sampler stats to console.
394+
"""
395+
return f"MultiRangeUniform sampler: intervals={self.intervals}"
396+
367397
@intervals.default
368398
def _intervals_default(self):
369399
return [[0.0, 1.0]]

ml-agents/mlagents/trainers/tests/test_settings.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,22 @@ def test_env_parameter_structure():
244244
assert isinstance(
245245
env_param_settings["length"].curriculum[0].value, MultiRangeUniformSettings
246246
)
247+
248+
# Check __str__ is correct
249+
assert (
250+
str(env_param_settings["mass"].curriculum[0].value)
251+
== "Uniform sampler: min=1.0, max=2.0"
252+
)
253+
assert (
254+
str(env_param_settings["scale"].curriculum[0].value)
255+
== "Gaussian sampler: mean=1.0, stddev=2.0"
256+
)
257+
assert (
258+
str(env_param_settings["length"].curriculum[0].value)
259+
== "MultiRangeUniform sampler: intervals=[(1.0, 2.0), (3.0, 4.0)]"
260+
)
261+
assert str(env_param_settings["gravity"].curriculum[0].value) == "Float: value=1"
262+
247263
assert isinstance(
248264
env_param_settings["wall_height"].curriculum[0].value, ConstantSettings
249265
)

0 commit comments

Comments
 (0)