@@ -1,4 +1,5 @@
 import os
+import unittest
 from unittest import mock
 import pytest
 import mlagents.trainers.tests.mock_brain as mb
@@ -178,3 +179,34 @@ def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary):
         for step in checkpoint_range
     ]
     mock_add_checkpoint.assert_has_calls(add_checkpoint_calls)
+
+
+class RLTrainerWarningTest(unittest.TestCase):
+    def test_warning_group_reward(self):
+        with self.assertLogs("mlagents.trainers", level="WARN") as cm:
+            rl_trainer = create_rl_trainer()
+            # This one should warn
+            trajectory = mb.make_fake_trajectory(
+                length=10,
+                observation_specs=create_observation_specs_with_shapes([(1,)]),
+                max_step_complete=True,
+                action_spec=ActionSpec.create_discrete((2,)),
+                group_reward=1.0,
+            )
+            buff = trajectory.to_agentbuffer()
+            rl_trainer._warn_if_group_reward(buff)
+            assert len(cm.output) > 0
+            len_of_first_warning = len(cm.output)
+
+            rl_trainer = create_rl_trainer()
+            # This one shouldn't
+            trajectory = mb.make_fake_trajectory(
+                length=10,
+                observation_specs=create_observation_specs_with_shapes([(1,)]),
+                max_step_complete=True,
+                action_spec=ActionSpec.create_discrete((2,)),
+            )
+            buff = trajectory.to_agentbuffer()
+            rl_trainer._warn_if_group_reward(buff)
+            # Make sure warnings don't get bigger
+            assert len(cm.output) == len_of_first_warning
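For readers outside the PR: the assertions above only pass if _warn_if_group_reward logs its group-reward warning at most once per trainer instance and stays silent when every group reward in the buffer is zero. Below is a minimal sketch of that warn-once pattern; the RLTrainerSketch class, the _has_warned_group_rewards flag, and the "group_reward" buffer key are illustrative assumptions, not the actual ml-agents implementation (which is not part of this diff).

import logging

import numpy as np

logger = logging.getLogger("mlagents.trainers")


class RLTrainerSketch:
    """Hypothetical fragment showing the warn-once behavior the test exercises."""

    def __init__(self):
        # Assumed guard: once the warning fires, it is suppressed for the
        # lifetime of this trainer instance.
        self._has_warned_group_rewards = False

    def _warn_if_group_reward(self, buffer):
        # `buffer` is assumed to map field names to per-step values, loosely
        # mirroring the AgentBuffer returned by trajectory.to_agentbuffer().
        if self._has_warned_group_rewards:
            return
        if np.any(np.asarray(buffer.get("group_reward", []))):
            logger.warning(
                "Agents received a group reward, but this trainer does not "
                "optimize group rewards; they will be ignored."
            )
            self._has_warned_group_rewards = True

Under that sketch, the test's two phases line up: the first buffer carries group_reward=1.0, so cm.output gains a record; the second trainer sees an all-zero group reward, so cm.output keeps its original length.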