Commit 348d1d6

Author: Ervin T
Change AgentProcessor logic to fix memory leak (#3383)
1 parent 3c940e8 commit 348d1d6

3 files changed: +88 additions, -17 deletions

ml-agents/mlagents/trainers/agent_processor.py

Lines changed: 23 additions & 16 deletions
@@ -1,8 +1,8 @@
 import sys
-from typing import List, Dict, Deque, TypeVar, Generic
+from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Set
 from collections import defaultdict, Counter, deque
 
-from mlagents_envs.base_env import BatchedStepResult
+from mlagents_envs.base_env import BatchedStepResult, StepResult
 from mlagents.trainers.trajectory import Trajectory, AgentExperience
 from mlagents.trainers.tf_policy import TFPolicy
 from mlagents.trainers.policy import Policy
@@ -36,7 +36,7 @@ def __init__(
         :param stats_category: The category under which to write the stats. Usually, this comes from the Trainer.
         """
         self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list)
-        self.last_step_result: Dict[str, BatchedStepResult] = {}
+        self.last_step_result: Dict[str, Tuple[StepResult, int]] = {}
         # last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while
         # grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1).
         self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {}
@@ -69,28 +69,27 @@ def add_experiences(
                 "Policy/Learning Rate", take_action_outputs["learning_rate"]
             )
 
-        terminated_agents: List[str] = []
+        terminated_agents: Set[str] = set()
         # Make unique agent_ids that are global across workers
         action_global_agent_ids = [
            get_global_agent_id(worker_id, ag_id) for ag_id in previous_action.agent_ids
        ]
        for global_id in action_global_agent_ids:
-            self.last_take_action_outputs[global_id] = take_action_outputs
+            if global_id in self.last_step_result:  # Don't store if agent just reset
+                self.last_take_action_outputs[global_id] = take_action_outputs
 
         for _id in batched_step_result.agent_id:  # Assume agent_id is 1-D
             local_id = int(
                 _id
             )  # Needed for mypy to pass since ndarray has no content type
             curr_agent_step = batched_step_result.get_agent_step_result(local_id)
             global_id = get_global_agent_id(worker_id, local_id)
-            stored_step = self.last_step_result.get(global_id, None)
+            stored_agent_step, idx = self.last_step_result.get(global_id, (None, None))
             stored_take_action_outputs = self.last_take_action_outputs.get(
                 global_id, None
             )
-            if stored_step is not None and stored_take_action_outputs is not None:
+            if stored_agent_step is not None and stored_take_action_outputs is not None:
                 # We know the step is from the same worker, so use the local agent id.
-                stored_agent_step = stored_step.get_agent_step_result(local_id)
-                idx = stored_step.agent_id_to_index[local_id]
                 obs = stored_agent_step.obs
                 if not stored_agent_step.done:
                     if self.policy.use_recurrent:
@@ -155,29 +154,37 @@ def add_experiences(
                            "Environment/Episode Length",
                            self.episode_steps.get(global_id, 0),
                        )
-                        terminated_agents += [global_id]
+                        terminated_agents.add(global_id)
                elif not curr_agent_step.done:
                    self.episode_steps[global_id] += 1
 
-            self.last_step_result[global_id] = batched_step_result
-
-        if "action" in take_action_outputs:
-            self.policy.save_previous_action(
-                previous_action.agent_ids, take_action_outputs["action"]
+            # Index is needed to grab from last_take_action_outputs
+            self.last_step_result[global_id] = (
+                curr_agent_step,
+                batched_step_result.agent_id_to_index[_id],
            )
 
        for terminated_id in terminated_agents:
            self._clean_agent_data(terminated_id)
 
+        for _gid in action_global_agent_ids:
+            # If the ID doesn't have a last step result, the agent just reset,
+            # don't store the action.
+            if _gid in self.last_step_result:
+                if "action" in take_action_outputs:
+                    self.policy.save_previous_action(
+                        [_gid], take_action_outputs["action"]
+                    )
+
    def _clean_agent_data(self, global_id: str) -> None:
        """
        Removes the data for an Agent.
        """
        del self.experience_buffers[global_id]
        del self.last_take_action_outputs[global_id]
+        del self.last_step_result[global_id]
        del self.episode_steps[global_id]
        del self.episode_rewards[global_id]
-        del self.last_step_result[global_id]
        self.policy.remove_previous_action([global_id])
        self.policy.remove_memories([global_id])

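Taken together, the agent_processor.py changes above do two things: last_step_result now caches a (StepResult, index) pair per agent instead of the entire BatchedStepResult, and actions/outputs are no longer stored for an agent that just reset (one whose global id has no last step result), so entries removed by _clean_agent_data are not immediately re-created. Below is a minimal, self-contained sketch of the storage difference only; the Toy* classes are stand-ins for the mlagents_envs types, not the real API, and the sizes are invented for illustration.

from typing import Dict, NamedTuple, Tuple

import numpy as np


class ToyAgentStep(NamedTuple):
    # One agent's slice of a step (stand-in for StepResult).
    obs: np.ndarray
    done: bool


class ToyBatchedStep(NamedTuple):
    # A whole batched step (stand-in for BatchedStepResult).
    obs: np.ndarray       # shape (num_agents, obs_size) -- the large array
    done: np.ndarray      # shape (num_agents,)
    agent_id: np.ndarray  # shape (num_agents,)

    def get_agent_step_result(self, local_id: int) -> ToyAgentStep:
        idx = int(np.where(self.agent_id == local_id)[0][0])
        # Copy so the per-agent slice does not keep the whole batch alive.
        return ToyAgentStep(obs=self.obs[idx].copy(), done=bool(self.done[idx]))


# Old storage: each agent's entry references the entire batch it last appeared in.
last_step_result_old: Dict[str, ToyBatchedStep] = {}
# New storage: only that agent's own row plus its index into the batch.
last_step_result_new: Dict[str, Tuple[ToyAgentStep, int]] = {}

for step in range(3):
    batch = ToyBatchedStep(
        obs=np.zeros((512, 128), dtype=np.float32),  # ~256 KB per batch
        done=np.zeros(512, dtype=bool),
        agent_id=np.arange(512, dtype=np.int32),
    )
    local_id = step  # pretend a different agent last appeared in each batch
    global_id = "worker0-agent%d" % local_id
    last_step_result_old[global_id] = batch                 # pins the full batch
    last_step_result_new[global_id] = (
        batch.get_agent_step_result(local_id),
        local_id,  # in this toy, index == local_id because agent_id is arange
    )

# After the loop the old dict keeps three full batches reachable (one per stale
# agent), while the new dict keeps three 128-float rows plus an int index each.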
ml-agents/mlagents/trainers/tests/mock_brain.py

Lines changed: 2 additions & 1 deletion
@@ -39,6 +39,7 @@ def create_mock_batchedstep(
     num_vis_observations: int = 0,
     action_shape: List[int] = None,
     discrete: bool = False,
+    done: bool = False,
 ) -> BatchedStepResult:
     """
     Creates a mock BatchedStepResult with observations. Imitates constant
@@ -68,7 +69,7 @@ def create_mock_batchedstep(
     ]
 
     reward = np.array(num_agents * [1.0], dtype=np.float32)
-    done = np.array(num_agents * [False], dtype=np.bool)
+    done = np.array(num_agents * [done], dtype=np.bool)
     max_step = np.array(num_agents * [False], dtype=np.bool)
     agent_id = np.arange(num_agents, dtype=np.int32)

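The new done flag lets a test construct a terminal step in which every agent in the mock batch reports done. The test added below builds one exactly like this (mb is the alias the tests use for this mock_brain module):

mock_done_step = mb.create_mock_batchedstep(
    num_agents=1,
    num_vector_observations=8,
    action_shape=[2],
    num_vis_observations=0,
    done=True,  # every agent in the mock batch reports done
)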
ml-agents/mlagents/trainers/tests/test_agent_processor.py

Lines changed: 63 additions & 0 deletions
@@ -10,6 +10,7 @@
 from mlagents.trainers.action_info import ActionInfo
 from mlagents.trainers.trajectory import Trajectory
 from mlagents.trainers.stats import StatsReporter
+from mlagents.trainers.brain_conversion_utils import get_global_agent_id
 
 
 def create_mock_brain():
@@ -91,6 +92,68 @@ def test_agentprocessor(num_vis_obs):
     assert len(processor.experience_buffers[0]) == 0
 
 
+def test_agent_deletion():
+    policy = create_mock_policy()
+    tqueue = mock.Mock()
+    name_behavior_id = "test_brain_name"
+    processor = AgentProcessor(
+        policy,
+        name_behavior_id,
+        max_trajectory_length=5,
+        stats_reporter=StatsReporter("testcat"),
+    )
+
+    fake_action_outputs = {
+        "action": [0.1],
+        "entropy": np.array([1.0], dtype=np.float32),
+        "learning_rate": 1.0,
+        "pre_action": [0.1],
+        "log_probs": [0.1],
+    }
+    mock_step = mb.create_mock_batchedstep(
+        num_agents=1,
+        num_vector_observations=8,
+        action_shape=[2],
+        num_vis_observations=0,
+    )
+    mock_done_step = mb.create_mock_batchedstep(
+        num_agents=1,
+        num_vector_observations=8,
+        action_shape=[2],
+        num_vis_observations=0,
+        done=True,
+    )
+    fake_action_info = ActionInfo(
+        action=[0.1],
+        value=[0.1],
+        outputs=fake_action_outputs,
+        agent_ids=mock_step.agent_id,
+    )
+
+    processor.publish_trajectory_queue(tqueue)
+    # This is like the initial state after the env reset
+    processor.add_experiences(mock_step, 0, ActionInfo.empty())
+
+    # Run 3 trajectories, with different workers (to simulate different agents)
+    add_calls = []
+    remove_calls = []
+    for _ep in range(3):
+        for _ in range(5):
+            processor.add_experiences(mock_step, _ep, fake_action_info)
+            add_calls.append(mock.call([get_global_agent_id(_ep, 0)], [0.1]))
+        processor.add_experiences(mock_done_step, _ep, fake_action_info)
+        # Make sure we don't add experiences from the prior agents after the done
+        remove_calls.append(mock.call([get_global_agent_id(_ep, 0)]))
+
+    policy.save_previous_action.assert_has_calls(add_calls)
+    policy.remove_previous_action.assert_has_calls(remove_calls)
+    # Check that there are no experiences left
+    assert len(processor.experience_buffers.keys()) == 0
+    assert len(processor.last_take_action_outputs.keys()) == 0
+    assert len(processor.episode_steps.keys()) == 0
+    assert len(processor.episode_rewards.keys()) == 0
+
+
 def test_agent_manager():
     policy = create_mock_policy()
     name_behavior_id = "test_brain_name"

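The new test asserts that four of the processor's per-agent dictionaries are empty once every agent has hit a done step. As an illustration only (this helper is not part of the commit), the same check could cover every per-agent mapping the processor keeps, including last_step_result:

def assert_no_agent_state_left(processor) -> None:
    # Hypothetical helper: every per-agent dict in AgentProcessor should be
    # empty after all agents have terminated and been cleaned up.
    assert len(processor.experience_buffers) == 0
    assert len(processor.last_take_action_outputs) == 0
    assert len(processor.last_step_result) == 0
    assert len(processor.episode_steps) == 0
    assert len(processor.episode_rewards) == 0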