Skip to content

Commit c7a1799

Browse files
Fix the reporting of histogram stats and adding a test (#5410)
* Fix the reporting of histogram stats and adding a test * Appending to the Changelog
1 parent 8b6dd3d commit c7a1799

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ and this project adheres to
1919
- The calculation of the target entropy of SAC with continuous actions was incorrect and has been fixed. (#5372)
2020
- RigidBodySensorComponent now displays a warning if it's used in a way that won't generate useful observations. (#5387)
2121
- Update the documentation with a note saying that `GridSensor` does not work in 2D environments. (#5396)
22+
- Fixed an issue where the histogram stats would not be reported correctly in TensorBoard. (#5410)
2223

2324

2425
## [2.0.0-exp.1] - 2021-04-22

ml-agents/mlagents/trainers/agent_processor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
StatsAggregationMethod,
1616
EnvironmentStats,
1717
)
18+
from mlagents.trainers.exception import UnityTrainerException
1819
from mlagents.trainers.trajectory import AgentStatus, Trajectory, AgentExperience
1920
from mlagents.trainers.policy import Policy
2021
from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs
@@ -438,8 +439,14 @@ def record_environment_stats(
438439
self._stats_reporter.add_stat(stat_name, val, agg_type)
439440
elif agg_type == StatsAggregationMethod.SUM:
440441
self._stats_reporter.add_stat(stat_name, val, agg_type)
442+
elif agg_type == StatsAggregationMethod.HISTOGRAM:
443+
self._stats_reporter.add_stat(stat_name, val, agg_type)
441444
elif agg_type == StatsAggregationMethod.MOST_RECENT:
442445
# In order to prevent conflicts between multiple environments,
443446
# only stats from the first environment are recorded.
444447
if worker_id == 0:
445448
self._stats_reporter.set_stat(stat_name, val)
449+
else:
450+
raise UnityTrainerException(
451+
f"Unknown StatsAggregationMethod encountered. {agg_type}"
452+
)

ml-agents/mlagents/trainers/tests/test_stats.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import unittest
66
import time
77

8+
89
from mlagents.trainers.stats import (
910
StatsReporter,
1011
TensorboardWriter,
@@ -15,6 +16,8 @@
1516
StatsAggregationMethod,
1617
)
1718

19+
from mlagents.trainers.env_manager import AgentManager
20+
1821

1922
def test_stat_reporter_add_summary_write():
2023
# Test add_writer
@@ -107,6 +110,27 @@ def test_tensorboard_writer(mock_summary):
107110
assert mock_summary.return_value.add_text.call_count >= 1
108111

109112

113+
@pytest.mark.parametrize("aggregation_type", list(StatsAggregationMethod))
114+
def test_agent_manager_stats_report(aggregation_type):
115+
stats_reporter = StatsReporter("recorder_name")
116+
manager = AgentManager(None, "behaviorName", stats_reporter)
117+
118+
values = range(5)
119+
120+
env_stats = {"stat": [(i, aggregation_type) for i in values]}
121+
manager.record_environment_stats(env_stats, 0)
122+
summary = stats_reporter.get_stats_summaries("stat")
123+
aggregation_result = {
124+
StatsAggregationMethod.AVERAGE: sum(values) / len(values),
125+
StatsAggregationMethod.MOST_RECENT: values[-1],
126+
StatsAggregationMethod.SUM: sum(values),
127+
StatsAggregationMethod.HISTOGRAM: sum(values) / len(values),
128+
}
129+
130+
assert summary.aggregated_value == aggregation_result[aggregation_type]
131+
stats_reporter.write_stats(0)
132+
133+
110134
def test_tensorboard_writer_clear(tmp_path):
111135
tb_writer = TensorboardWriter(tmp_path, clear_past_data=False)
112136
statssummary1 = StatsSummary(

0 commit comments

Comments
 (0)