 import numpy as np
 import tensorflow as tf

-from unityagents import AllBrainInfo
+from unityagents import AllBrainInfo, BrainInfo
 from unitytrainers.buffer import Buffer
 from unitytrainers.ppo.models import PPOModel
 from unitytrainers.trainer import UnityTrainerException, Trainer
@@ -196,22 +196,61 @@ def take_action(self, all_brain_info: AllBrainInfo):
         else:
             return run_out[self.model.output], None, None, run_out

+    def construct_curr_info(self, next_info: BrainInfo) -> BrainInfo:
+        """
+        Constructs a BrainInfo which contains the most recent previous experiences for all agents
+        that correspond to the agents in a provided next_info.
+        :BrainInfo next_info: A t+1 BrainInfo.
+        :return: curr_info: Reconstructed BrainInfo to match agents of next_info.
+        """
+        visual_observations = [[] for _ in next_info.visual_observations]
+        vector_observations = []
+        text_observations = []
+        memories = []
+        rewards = []
+        local_dones = []
+        max_reacheds = []
+        agents = []
+        prev_vector_actions = []
+        prev_text_actions = []
+        for agent_id in next_info.agents:
+            agent_brain_info = self.training_buffer[agent_id].last_brain_info
+            if agent_brain_info is None:
+                agent_brain_info = next_info
+            agent_index = agent_brain_info.agents.index(agent_id)
+            for i in range(len(next_info.visual_observations)):
+                visual_observations[i].append(agent_brain_info.visual_observations[i][agent_index])
+            vector_observations.append(agent_brain_info.vector_observations[agent_index])
+            text_observations.append(agent_brain_info.text_observations[agent_index])
+            if self.use_recurrent:
+                memories.append(agent_brain_info.memories[agent_index])
+            rewards.append(agent_brain_info.rewards[agent_index])
+            local_dones.append(agent_brain_info.local_done[agent_index])
+            max_reacheds.append(agent_brain_info.max_reached[agent_index])
+            agents.append(agent_brain_info.agents[agent_index])
+            prev_vector_actions.append(agent_brain_info.previous_vector_actions[agent_index])
+            prev_text_actions.append(agent_brain_info.previous_text_actions[agent_index])
+        curr_info = BrainInfo(visual_observations, vector_observations, text_observations, memories, rewards,
+                              agents, local_dones, prev_vector_actions, prev_text_actions, max_reacheds)
+        return curr_info
+
     def generate_intrinsic_rewards(self, curr_info, next_info):
         """
         Generates intrinsic reward used for Curiosity-based training.
-        :param curr_info: Current BrainInfo.
-        :param next_info: Next BrainInfo.
+        :BrainInfo curr_info: Current BrainInfo.
+        :BrainInfo next_info: Next BrainInfo.
         :return: Intrinsic rewards for all agents.
         """
         if self.use_curiosity:
-            if curr_info.agents != next_info.agents:
-                raise UnityTrainerException("Training with Curiosity-driven exploration"
-                                            " and On-Demand Decision making is currently not supported.")
-            feed_dict = {self.model.batch_size: len(curr_info.vector_observations), self.model.sequence_length: 1}
+            feed_dict = {self.model.batch_size: len(next_info.vector_observations), self.model.sequence_length: 1}
             if self.is_continuous_action:
                 feed_dict[self.model.output] = next_info.previous_vector_actions
             else:
                 feed_dict[self.model.action_holder] = next_info.previous_vector_actions.flatten()
+
+            if curr_info.agents != next_info.agents:
+                curr_info = self.construct_curr_info(next_info)
+
             if self.use_visual_obs:
                 for i in range(len(curr_info.visual_observations)):
                     feed_dict[self.model.visual_in[i]] = curr_info.visual_observations[i]
@@ -262,12 +301,12 @@ def add_experiences(self, curr_all_info: AllBrainInfo, next_all_info: AllBrainIn
         curr_info = curr_all_info[self.brain_name]
         next_info = next_all_info[self.brain_name]

-        intrinsic_rewards = self.generate_intrinsic_rewards(curr_info, next_info)
-
         for agent_id in curr_info.agents:
             self.training_buffer[agent_id].last_brain_info = curr_info
             self.training_buffer[agent_id].last_take_action_outputs = take_action_outputs

+        intrinsic_rewards = self.generate_intrinsic_rewards(curr_info, next_info)
+
         for agent_id in next_info.agents:
             stored_info = self.training_buffer[agent_id].last_brain_info
             stored_take_action_outputs = self.training_buffer[agent_id].last_take_action_outputs
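The heart of `construct_curr_info` is re-indexing each agent's most recently stored `BrainInfo` so its rows line up with the agent ordering of `next_info`, falling back to `next_info` itself when no previous step has been stored yet. Below is a minimal sketch of that alignment idea, using a plain dict and lists in place of the training buffer and `BrainInfo`; the helper name and variables are illustrative only, not part of the trainer API.

```python
# Illustrative sketch only: toy stand-ins for BrainInfo fields and the training buffer.
# 'last_infos' maps agent_id -> (agents, vector_observations) captured at that agent's
# previous decision step; 'next_agents' is the agent ordering of the t+1 BrainInfo.

def align_previous_observations(last_infos, next_agents, fallback):
    """Rebuild the 'current' observations in the same agent order as next_agents."""
    aligned = []
    for agent_id in next_agents:
        prev = last_infos.get(agent_id)
        if prev is None:
            # No stored step yet for this agent: fall back to the t+1 info itself.
            prev = fallback
        agents, vector_obs = prev
        idx = agents.index(agent_id)      # position of this agent in its stored info
        aligned.append(vector_obs[idx])   # pick out that agent's row
    return aligned

# Example: agent "B" decided last step together with "A"; agent "C" is new this step.
last_infos = {"A": (["A", "B"], [[0.1], [0.2]]), "B": (["A", "B"], [[0.1], [0.2]])}
next_agents = ["B", "C"]
fallback = (["B", "C"], [[0.3], [0.4]])
print(align_previous_observations(last_infos, next_agents, fallback))  # [[0.2], [0.4]]
```

In the real method the same per-agent index is used to pick out every stored field (visual and vector observations, rewards, dones, previous actions) before packing them back into a `BrainInfo`.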
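A note on the reordering in `add_experiences`: storing `last_brain_info` and `last_take_action_outputs` before calling `generate_intrinsic_rewards` appears to be what allows `construct_curr_info` to find a stored `BrainInfo` for each agent when `curr_info.agents != next_info.agents` (e.g. with on-demand decision making), rather than raising the previous `UnityTrainerException`.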