@@ -1,8 +1,8 @@
 import sys
-from typing import List, Dict, Deque, TypeVar, Generic
+from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Set
 from collections import defaultdict, Counter, deque
 
-from mlagents_envs.base_env import BatchedStepResult
+from mlagents_envs.base_env import BatchedStepResult, StepResult
 from mlagents.trainers.trajectory import Trajectory, AgentExperience
 from mlagents.trainers.tf_policy import TFPolicy
 from mlagents.trainers.policy import Policy
@@ -36,7 +36,7 @@ def __init__(
         :param stats_category: The category under which to write the stats. Usually, this comes from the Trainer.
         """
         self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list)
-        self.last_step_result: Dict[str, BatchedStepResult] = {}
+        self.last_step_result: Dict[str, Tuple[StepResult, int]] = {}
         # last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while
         # grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1).
         self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {}
@@ -69,28 +69,27 @@ def add_experiences(
                 "Policy/Learning Rate", take_action_outputs["learning_rate"]
             )
 
-        terminated_agents: List[str] = []
+        terminated_agents: Set[str] = set()
         # Make unique agent_ids that are global across workers
         action_global_agent_ids = [
             get_global_agent_id(worker_id, ag_id) for ag_id in previous_action.agent_ids
         ]
         for global_id in action_global_agent_ids:
-            self.last_take_action_outputs[global_id] = take_action_outputs
+            if global_id in self.last_step_result:  # Don't store if agent just reset
+                self.last_take_action_outputs[global_id] = take_action_outputs
 
         for _id in batched_step_result.agent_id:  # Assume agent_id is 1-D
             local_id = int(
                 _id
             )  # Needed for mypy to pass since ndarray has no content type
             curr_agent_step = batched_step_result.get_agent_step_result(local_id)
             global_id = get_global_agent_id(worker_id, local_id)
-            stored_step = self.last_step_result.get(global_id, None)
+            stored_agent_step, idx = self.last_step_result.get(global_id, (None, None))
             stored_take_action_outputs = self.last_take_action_outputs.get(
                 global_id, None
             )
-            if stored_step is not None and stored_take_action_outputs is not None:
+            if stored_agent_step is not None and stored_take_action_outputs is not None:
                 # We know the step is from the same worker, so use the local agent id.
-                stored_agent_step = stored_step.get_agent_step_result(local_id)
-                idx = stored_step.agent_id_to_index[local_id]
                 obs = stored_agent_step.obs
                 if not stored_agent_step.done:
                     if self.policy.use_recurrent:
@@ -155,29 +154,37 @@ def add_experiences(
                             "Environment/Episode Length",
                             self.episode_steps.get(global_id, 0),
                         )
-                        terminated_agents += [global_id]
+                        terminated_agents.add(global_id)
             elif not curr_agent_step.done:
                 self.episode_steps[global_id] += 1
 
-            self.last_step_result[global_id] = batched_step_result
-
-        if "action" in take_action_outputs:
-            self.policy.save_previous_action(
-                previous_action.agent_ids, take_action_outputs["action"]
+            # Index is needed to grab from last_take_action_outputs
+            self.last_step_result[global_id] = (
+                curr_agent_step,
+                batched_step_result.agent_id_to_index[_id],
             )
 
         for terminated_id in terminated_agents:
             self._clean_agent_data(terminated_id)
 
+        for _gid in action_global_agent_ids:
+            # If the ID doesn't have a last step result, the agent just reset,
+            # don't store the action.
+            if _gid in self.last_step_result:
+                if "action" in take_action_outputs:
+                    self.policy.save_previous_action(
+                        [_gid], take_action_outputs["action"]
+                    )
+
     def _clean_agent_data(self, global_id: str) -> None:
         """
         Removes the data for an Agent.
         """
         del self.experience_buffers[global_id]
         del self.last_take_action_outputs[global_id]
+        del self.last_step_result[global_id]
         del self.episode_steps[global_id]
         del self.episode_rewards[global_id]
-        del self.last_step_result[global_id]
         self.policy.remove_previous_action([global_id])
         self.policy.remove_memories([global_id])
 
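
To make the intent of the change easier to follow, here is a minimal, self-contained sketch of the bookkeeping pattern the diff moves to: keep only each agent's own step result plus its index into the batch, keyed by a global agent id, and delete every per-agent entry as soon as that agent terminates so a respawned agent with the same id starts from a clean slate. The StepResult and AgentCache classes and the add_step method below are hypothetical simplifications for illustration, not the real mlagents types or methods.

from collections import defaultdict
from typing import Dict, List, NamedTuple, Set, Tuple


class StepResult(NamedTuple):
    # Simplified stand-in for a single agent's step result; not the real mlagents_envs class.
    obs: List[float]
    reward: float
    done: bool


class AgentCache:
    """Hypothetical sketch of the per-agent bookkeeping pattern used in the diff."""

    def __init__(self) -> None:
        # global agent id -> (most recent StepResult for that agent, its index in the batch)
        self.last_step_result: Dict[str, Tuple[StepResult, int]] = {}
        self.episode_steps: Dict[str, int] = defaultdict(int)
        self.episode_rewards: Dict[str, float] = defaultdict(float)

    def add_step(self, global_id: str, step: StepResult, batch_index: int) -> None:
        terminated: Set[str] = set()
        self.episode_rewards[global_id] += step.reward
        if step.done:
            terminated.add(global_id)
        else:
            self.episode_steps[global_id] += 1
        # Keep only this agent's own step plus its batch index, not the whole batch.
        self.last_step_result[global_id] = (step, batch_index)
        for agent_id in terminated:
            self._clean_agent_data(agent_id)

    def _clean_agent_data(self, global_id: str) -> None:
        # Drop every per-agent entry so a respawned agent with the same id starts fresh.
        self.last_step_result.pop(global_id, None)
        self.episode_steps.pop(global_id, None)
        self.episode_rewards.pop(global_id, None)


cache = AgentCache()
cache.add_step("worker0-agent3", StepResult(obs=[0.0], reward=1.0, done=False), batch_index=0)
cache.add_step("worker0-agent3", StepResult(obs=[1.0], reward=0.5, done=True), batch_index=0)
assert "worker0-agent3" not in cache.last_step_result  # cleaned up at termination

Keying all of the per-agent state by the same global id is what lets a _clean_agent_data-style cleanup free everything with a few deletions, which is also why the change adds del self.last_step_result[global_id] to that method rather than keeping whole BatchedStepResult objects alive after their agents are gone.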