Skip to content

Commit d230591

Browse files
author
Ervin T
authored
Fix clear update buffer when trainer stops training, add test (#3422)
* Fix clear update buffer when trainer stops training, add test * Fix buffer changing types when truncated
1 parent ccd6144 commit d230591

File tree

4 files changed

+34
-2
lines changed

4 files changed

+34
-2
lines changed

ml-agents/mlagents/trainers/buffer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def truncate(self, max_length: int, sequence_length: int = 1) -> None:
253253
max_length -= max_length % sequence_length
254254
if current_length > max_length:
255255
for _key in self.keys():
256-
self[_key] = self[_key][current_length - max_length :]
256+
self[_key][:] = self[_key][current_length - max_length :]
257257

258258
def resequence_and_append(
259259
self,

ml-agents/mlagents/trainers/rl_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,5 +73,5 @@ def advance(self) -> None:
7373
Steps the trainer, taking in trajectories and updates if ready
7474
"""
7575
super().advance()
76-
if not self.is_training:
76+
if not self.should_still_train:
7777
self.clear_update_buffer()

ml-agents/mlagents/trainers/tests/test_buffer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,5 @@ def test_buffer_truncate():
152152
# Test LSTM, truncate should be some multiple of sequence_length
153153
update_buffer.truncate(4, sequence_length=3)
154154
assert update_buffer.num_experiences == 3
155+
for buffer_field in update_buffer.values():
156+
assert isinstance(buffer_field, AgentBuffer.AgentBufferField)

ml-agents/mlagents/trainers/tests/test_rl_trainer.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import mlagents.trainers.tests.mock_brain as mb
44
from mlagents.trainers.rl_trainer import RLTrainer
55
from mlagents.trainers.tests.test_buffer import construct_fake_buffer
6+
from mlagents.trainers.agent_processor import AgentManagerQueue
67

78

89
def dummy_config():
910
return yaml.safe_load(
1011
"""
1112
summary_path: "test/"
1213
summary_freq: 1000
14+
max_steps: 100
1315
reward_signals:
1416
extrinsic:
1517
strength: 1.0
@@ -75,3 +77,31 @@ def test_clear_update_buffer():
7577
trainer.clear_update_buffer()
7678
for _, arr in trainer.update_buffer.items():
7779
assert len(arr) == 0
80+
81+
82+
@mock.patch("mlagents.trainers.rl_trainer.RLTrainer.clear_update_buffer")
83+
def test_advance(mocked_clear_update_buffer):
84+
trainer = create_rl_trainer()
85+
trajectory_queue = AgentManagerQueue("testbrain")
86+
trainer.subscribe_trajectory_queue(trajectory_queue)
87+
time_horizon = 15
88+
trajectory = mb.make_fake_trajectory(
89+
length=time_horizon,
90+
max_step_complete=True,
91+
vec_obs_size=1,
92+
num_vis_obs=0,
93+
action_space=[2],
94+
)
95+
trajectory_queue.put(trajectory)
96+
97+
trainer.advance()
98+
# Check that get_step is correct
99+
assert trainer.get_step == time_horizon
100+
# Check that we can turn off the trainer and that the buffer is cleared
101+
for _ in range(0, 10):
102+
trajectory_queue.put(trajectory)
103+
trainer.advance()
104+
105+
# Check that the buffer has been cleared
106+
assert not trainer.should_still_train
107+
assert mocked_clear_update_buffer.call_count > 0

0 commit comments

Comments
 (0)