
Commit 33328c3

Simplify approach
1 parent 582432f

2 files changed: +69 -69 lines changed


python/unityagents/brain.py

Lines changed: 9 additions & 9 deletions
@@ -3,8 +3,8 @@
 
 class BrainInfo:
     def __init__(self, visual_observation, vector_observation, text_observations, memory=None,
-                 reward=None, agents=None, local_done=None,
-                 vector_action=None, text_action=None, max_reached=None):
+                 reward=None, agents=None, local_done=None,
+                 vector_action=None, text_action=None, max_reached=None):
         """
         Describes experience at current step of all agents linked to a brain.
         """
@@ -49,10 +49,10 @@ def __str__(self):
         Vector Action space type: {5}
         Vector Action space size (per agent): {6}
         Vector Action descriptions: {7}'''.format(self.brain_name,
-                                                   str(self.number_visual_observations),
-                                                   self.vector_observation_space_type,
-                                                   str(self.vector_observation_space_size),
-                                                   str(self.num_stacked_vector_observations),
-                                                   self.vector_action_space_type,
-                                                   str(self.vector_action_space_size),
-                                                   ', '.join(self.vector_action_descriptions))
+                                                   str(self.number_visual_observations),
+                                                   self.vector_observation_space_type,
+                                                   str(self.vector_observation_space_size),
+                                                   str(self.num_stacked_vector_observations),
+                                                   self.vector_action_space_type,
+                                                   str(self.vector_action_space_size),
+                                                   ', '.join(self.vector_action_descriptions))
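For orientation, the constructor whose continuation lines are re-indented above takes one per-agent entry for every field. A minimal sketch of building a BrainInfo by hand for a single agent, using only the keyword names shown in the diff (all field values are illustrative placeholders, not taken from this commit):

import numpy as np
from unityagents import BrainInfo

# One agent, no cameras, an 8-dimensional vector observation, no recurrent memories.
info = BrainInfo(
    visual_observation=[],                 # no visual observations configured
    vector_observation=np.zeros((1, 8)),   # shape: (num agents, observation size)
    text_observations=[""],
    memory=np.zeros((1, 0)),               # empty memories when recurrence is unused
    reward=[0.0],
    agents=[0],                            # agent ids
    local_done=[False],
    vector_action=np.zeros((1, 2)),        # previous vector action per agent
    text_action=[""],
    max_reached=[False],
)
print(info.agents, info.rewards, info.local_done)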

python/unitytrainers/ppo/trainer.py

Lines changed: 60 additions & 60 deletions
@@ -8,7 +8,7 @@
 import numpy as np
 import tensorflow as tf
 
-from unityagents import AllBrainInfo
+from unityagents import AllBrainInfo, BrainInfo
 from unitytrainers.buffer import Buffer
 from unitytrainers.ppo.models import PPOModel
 from unitytrainers.trainer import UnityTrainerException, Trainer
@@ -196,10 +196,51 @@ def take_action(self, all_brain_info: AllBrainInfo):
         else:
             return run_out[self.model.output], None, None, run_out
 
+    def construct_curr_info(self, next_info: BrainInfo) -> BrainInfo:
+        """
+        Constructs a BrainInfo which contains the most recent previous experiences for all agents
+        which correspond to the agents in a provided next_info.
+        :BrainInfo next_info: A t+1 BrainInfo.
+        :return:
+        """
+        visual_observations = [[]]
+        vector_observations = []
+        text_observations = []
+        memories = []
+        rewards = []
+        local_dones = []
+        max_reacheds = []
+        agents = []
+        prev_vector_actions = []
+        prev_text_actions = []
+        for agent_id in next_info.agents:
+            agent_brain_info = self.training_buffer[agent_id].last_brain_info
+            if agent_brain_info is None:
+                agent_brain_info = next_info
+            for i in range(len(next_info.visual_observations)):
+                visual_observations[i].append(
+                    agent_brain_info.visual_observations[i][agent_brain_info.agents.index(agent_id)])
+            vector_observations.append(agent_brain_info.vector_observations[agent_brain_info.agents.index(agent_id)])
+            text_observations.append(agent_brain_info.text_observations[agent_brain_info.agents.index(agent_id)])
+            if self.use_recurrent:
+                memories.append(agent_brain_info.memories[agent_brain_info.agents.index(agent_id)])
+            rewards.append(agent_brain_info.rewards[agent_brain_info.agents.index(agent_id)])
+            local_dones.append(agent_brain_info.local_done[agent_brain_info.agents.index(agent_id)])
+            max_reacheds.append(agent_brain_info.max_reached[agent_brain_info.agents.index(agent_id)])
+            agents.append(agent_brain_info.agents[agent_brain_info.agents.index(agent_id)])
+            prev_vector_actions.append(
+                agent_brain_info.previous_vector_actions[agent_brain_info.agents.index(agent_id)])
+            prev_text_actions.append(agent_brain_info.previous_text_actions[agent_brain_info.agents.index(agent_id)])
+        curr_info = BrainInfo(visual_observations, vector_observations, text_observations, memories, rewards,
+                              agents,
+                              local_dones, prev_vector_actions, prev_text_actions, max_reacheds)
+        return curr_info
+
     def generate_intrinsic_rewards(self, curr_info, next_info):
         """
         Generates intrinsic reward used for Curiosity-based training.
-        :param next_info: Next BrainInfo.
+        :BrainInfo curr_info: Current BrainInfo.
+        :BrainInfo next_info: Next BrainInfo.
         :return: Intrinsic rewards for all agents.
         """
         if self.use_curiosity:
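The core move in construct_curr_info above is index realignment: for each agent id present at t+1, look up the last BrainInfo that agent appeared in and pull out its row via agents.index(agent_id), falling back to the t+1 data when the agent has no stored history yet. A standalone sketch of that idea with plain lists standing in for BrainInfo (the helper name realign and its arguments are hypothetical, for illustration only):

def realign(next_agents, last_info_by_agent, fallback_agents, fallback_obs):
    """Pick each agent's most recent observation row, falling back to the t+1 data."""
    rows = []
    for agent_id in next_agents:
        last = last_info_by_agent.get(agent_id)   # (agent id list, observation rows) or None
        if last is None:                          # no stored history for this agent yet
            agents, obs = fallback_agents, fallback_obs
        else:
            agents, obs = last
        rows.append(obs[agents.index(agent_id)])  # select this agent's row by position
    return rows

# Agent 7 was last seen in an earlier step alongside agent 5; agent 9 has no history.
history = {7: ([5, 7], [[0.1, 0.2], [0.3, 0.4]]), 9: None}
print(realign([7, 9], history, [7, 9], [[1.0, 1.1], [1.2, 1.3]]))
# -> [[0.3, 0.4], [1.2, 1.3]]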
@@ -209,64 +250,23 @@ def generate_intrinsic_rewards(self, curr_info, next_info):
             else:
                 feed_dict[self.model.action_holder] = next_info.previous_vector_actions.flatten()
 
-            if curr_info.agents == next_info.agents:
-                if self.use_visual_obs:
-                    for i in range(len(curr_info.visual_observations)):
-                        feed_dict[self.model.visual_in[i]] = curr_info.visual_observations[i]
-                        feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
-                if self.use_vector_obs:
-                    feed_dict[self.model.vector_in] = curr_info.vector_observations
-                    feed_dict[self.model.next_vector_in] = next_info.vector_observations
-                if self.use_recurrent:
-                    if curr_info.memories.shape[1] == 0:
-                        curr_info.memories = np.zeros((len(curr_info.agents), self.m_size))
-                    feed_dict[self.model.memory_in] = curr_info.memories
-                intrinsic_rewards = self.sess.run(self.model.intrinsic_reward,
-                                                  feed_dict=feed_dict) * float(self.has_updated)
-                return intrinsic_rewards
-            else:
-                agent_index_to_ignore = []
-                for agent_index, agent_id in enumerate(next_info.agents):
-                    if self.training_buffer[agent_id].last_brain_info is None:
-                        agent_index_to_ignore.append(agent_index)
-                if self.use_visual_obs:
-                    for i in range(len(next_info.visual_observations)):
-                        tmp = []
-                        for agent_id in next_info.agents:
-                            agent_brain_info = self.training_buffer[agent_id].last_brain_info
-                            if agent_brain_info is None:
-                                agent_brain_info = next_info
-                            agent_obs = agent_brain_info.visual_observations[i][agent_brain_info.agents.index(agent_id)]
-                            tmp += [agent_obs]
-                        feed_dict[self.model.visual_in[i]] = np.array(tmp)
-                        feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
-                if self.use_vector_obs:
-                    tmp = []
-                    for agent_id in next_info.agents:
-                        agent_brain_info = self.training_buffer[agent_id].last_brain_info
-                        if agent_brain_info is None:
-                            agent_brain_info = next_info
-                        agent_obs = agent_brain_info.vector_observations[agent_brain_info.agents.index(agent_id)]
-                        tmp += [agent_obs]
-                    feed_dict[self.model.vector_in] = np.array(tmp)
-                    feed_dict[self.model.next_vector_in] = next_info.vector_observations
-                if self.use_recurrent:
-                    tmp = []
-                    for agent_id in next_info.agents:
-                        agent_brain_info = self.training_buffer[agent_id].last_brain_info
-                        if agent_brain_info is None:
-                            agent_brain_info = next_info
-                        if agent_brain_info.memories.shape[1] == 0:
-                            agent_obs = np.zeros(self.m_size)
-                        else:
-                            agent_obs = agent_brain_info.memories[agent_brain_info.agents.index(agent_id)]
-                        tmp += [agent_obs]
-                    feed_dict[self.model.memory_in] = np.array(tmp)
-                intrinsic_rewards = self.sess.run(self.model.intrinsic_reward,
-                                                  feed_dict=feed_dict) * float(self.has_updated)
-                for index in agent_index_to_ignore:
-                    intrinsic_rewards[index] = 0
-                return intrinsic_rewards
+            if curr_info.agents != next_info.agents:
+                curr_info = self.construct_curr_info(next_info)
+
+            if self.use_visual_obs:
+                for i in range(len(curr_info.visual_observations)):
+                    feed_dict[self.model.visual_in[i]] = curr_info.visual_observations[i]
+                    feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
+            if self.use_vector_obs:
+                feed_dict[self.model.vector_in] = curr_info.vector_observations
+                feed_dict[self.model.next_vector_in] = next_info.vector_observations
+            if self.use_recurrent:
+                if curr_info.memories.shape[1] == 0:
+                    curr_info.memories = np.zeros((len(curr_info.agents), self.m_size))
+                feed_dict[self.model.memory_in] = curr_info.memories
+            intrinsic_rewards = self.sess.run(self.model.intrinsic_reward,
+                                              feed_dict=feed_dict) * float(self.has_updated)
+            return intrinsic_rewards
         else:
             return None
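One detail carried over unchanged into the simplified branch is the memory padding: when no recurrent memories have been stored yet (the second dimension of curr_info.memories is 0), an all-zeros block of shape (number of agents, m_size) is fed instead. A small self-contained sketch of that check (pad_memories is a hypothetical helper, not part of the trainer):

import numpy as np

def pad_memories(memories, num_agents, m_size):
    """Return an all-zeros (num_agents, m_size) block if no memories were stored yet."""
    if memories.shape[1] == 0:      # nothing recorded so far
        return np.zeros((num_agents, m_size))
    return memories                 # already the right width; pass through unchanged

print(pad_memories(np.zeros((3, 0)), num_agents=3, m_size=4).shape)  # (3, 4)
print(pad_memories(np.ones((3, 4)), num_agents=3, m_size=4).shape)   # (3, 4)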
