Commit 582432f

Use switch between old and new behavior
1 parent 5d07398 commit 582432f

1 file changed: 55 additions, 40 deletions

python/unitytrainers/ppo/trainer.py

Lines changed: 55 additions & 40 deletions
@@ -196,60 +196,77 @@ def take_action(self, all_brain_info: AllBrainInfo):
         else:
             return run_out[self.model.output], None, None, run_out
 
-    def generate_intrinsic_rewards(self, next_info):
+    def generate_intrinsic_rewards(self, curr_info, next_info):
         """
         Generates intrinsic reward used for Curiosity-based training.
         :param next_info: Next BrainInfo.
         :return: Intrinsic rewards for all agents.
         """
         if self.use_curiosity:
-            agent_index_to_ignore = []
-            for agent_index, agent_id in enumerate(next_info.agents):
-                if self.training_buffer[agent_id].last_brain_info is None:
-                    agent_index_to_ignore.append(agent_index)
             feed_dict = {self.model.batch_size: len(next_info.vector_observations), self.model.sequence_length: 1}
             if self.is_continuous_action:
                 feed_dict[self.model.output] = next_info.previous_vector_actions
             else:
                 feed_dict[self.model.action_holder] = next_info.previous_vector_actions.flatten()
-            if self.use_visual_obs:
-                for i in range(len(next_info.visual_observations)):
+
+            if curr_info.agents == next_info.agents:
+                if self.use_visual_obs:
+                    for i in range(len(curr_info.visual_observations)):
+                        feed_dict[self.model.visual_in[i]] = curr_info.visual_observations[i]
+                        feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
+                if self.use_vector_obs:
+                    feed_dict[self.model.vector_in] = curr_info.vector_observations
+                    feed_dict[self.model.next_vector_in] = next_info.vector_observations
+                if self.use_recurrent:
+                    if curr_info.memories.shape[1] == 0:
+                        curr_info.memories = np.zeros((len(curr_info.agents), self.m_size))
+                    feed_dict[self.model.memory_in] = curr_info.memories
+                intrinsic_rewards = self.sess.run(self.model.intrinsic_reward,
+                                                  feed_dict=feed_dict) * float(self.has_updated)
+                return intrinsic_rewards
+            else:
+                agent_index_to_ignore = []
+                for agent_index, agent_id in enumerate(next_info.agents):
+                    if self.training_buffer[agent_id].last_brain_info is None:
+                        agent_index_to_ignore.append(agent_index)
+                if self.use_visual_obs:
+                    for i in range(len(next_info.visual_observations)):
+                        tmp = []
+                        for agent_id in next_info.agents:
+                            agent_brain_info = self.training_buffer[agent_id].last_brain_info
+                            if agent_brain_info is None:
+                                agent_brain_info = next_info
+                            agent_obs = agent_brain_info.visual_observations[i][agent_brain_info.agents.index(agent_id)]
+                            tmp += [agent_obs]
+                        feed_dict[self.model.visual_in[i]] = np.array(tmp)
+                        feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
+                if self.use_vector_obs:
                     tmp = []
                     for agent_id in next_info.agents:
                         agent_brain_info = self.training_buffer[agent_id].last_brain_info
                         if agent_brain_info is None:
                             agent_brain_info = next_info
-                        agent_obs = agent_brain_info.visual_observations[i][agent_brain_info.agents.index(agent_id)]
+                        agent_obs = agent_brain_info.vector_observations[agent_brain_info.agents.index(agent_id)]
                         tmp += [agent_obs]
-                    feed_dict[self.model.visual_in[i]] = np.array(tmp)
-                    feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
-            if self.use_vector_obs:
-                tmp = []
-                for agent_id in next_info.agents:
-                    agent_brain_info = self.training_buffer[agent_id].last_brain_info
-                    if agent_brain_info is None:
-                        agent_brain_info = next_info
-                    agent_obs = agent_brain_info.vector_observations[agent_brain_info.agents.index(agent_id)]
-                    tmp += [agent_obs]
-                feed_dict[self.model.vector_in] = np.array(tmp)
-                feed_dict[self.model.next_vector_in] = next_info.vector_observations
-            if self.use_recurrent:
-                tmp = []
-                for agent_id in next_info.agents:
-                    agent_brain_info = self.training_buffer[agent_id].last_brain_info
-                    if agent_brain_info is None:
-                        agent_brain_info = next_info
-                    if agent_brain_info.memories.shape[1] == 0:
-                        agent_obs = np.zeros(self.m_size)
-                    else:
-                        agent_obs = agent_brain_info.memories[agent_brain_info.agents.index(agent_id)]
-                    tmp += [agent_obs]
-                feed_dict[self.model.memory_in] = np.array(tmp)
-            intrinsic_rewards = self.sess.run(self.model.intrinsic_reward,
-                                              feed_dict=feed_dict) * float(self.has_updated)
-            for index in agent_index_to_ignore:
-                intrinsic_rewards[index] = 0
-            return intrinsic_rewards
+                    feed_dict[self.model.vector_in] = np.array(tmp)
+                    feed_dict[self.model.next_vector_in] = next_info.vector_observations
+                if self.use_recurrent:
+                    tmp = []
+                    for agent_id in next_info.agents:
+                        agent_brain_info = self.training_buffer[agent_id].last_brain_info
+                        if agent_brain_info is None:
+                            agent_brain_info = next_info
+                        if agent_brain_info.memories.shape[1] == 0:
+                            agent_obs = np.zeros(self.m_size)
+                        else:
+                            agent_obs = agent_brain_info.memories[agent_brain_info.agents.index(agent_id)]
+                        tmp += [agent_obs]
+                    feed_dict[self.model.memory_in] = np.array(tmp)
+                intrinsic_rewards = self.sess.run(self.model.intrinsic_reward,
+                                                  feed_dict=feed_dict) * float(self.has_updated)
+                for index in agent_index_to_ignore:
+                    intrinsic_rewards[index] = 0
+                return intrinsic_rewards
         else:
             return None
 
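In plain terms, this hunk introduces the switch named in the commit message: when the agent list is unchanged between the current and the next step, the current and next observations are already aligned row by row and can be fed to the curiosity model directly (the new behavior); only when agents appear, disappear, or reorder does the code fall back to the old per-agent lookup through the training buffer. The sketch below illustrates that switch in isolation with plain NumPy. It is not the trainer code itself; `Step`, `pair_observations`, and `last_step` are hypothetical names introduced only for this illustration.

    # A toy illustration of the commit's fast/slow switch; not ML-Agents code.
    import numpy as np
    from typing import Dict, List, Optional


    class Step:
        """Stand-in for a BrainInfo: agent ids plus one vector observation per agent."""
        def __init__(self, agents: List[str], vector_observations: np.ndarray):
            self.agents = agents
            self.vector_observations = vector_observations


    def pair_observations(curr: Step, nxt: Step,
                          last_step: Dict[str, Optional[Step]]) -> np.ndarray:
        """Return rows of (previous_obs, next_obs) aligned to nxt.agents."""
        if curr.agents == nxt.agents:
            # New behavior: same agents in the same order, so the two batches
            # already line up row by row and can be used directly.
            return np.hstack([curr.vector_observations, nxt.vector_observations])
        # Old behavior: the agent set changed, so each agent's previous observation
        # is looked up by id, falling back to the next step for unseen agents.
        rows = []
        for i, agent_id in enumerate(nxt.agents):
            prev = last_step.get(agent_id) or nxt
            prev_obs = prev.vector_observations[prev.agents.index(agent_id)]
            rows.append(np.concatenate([prev_obs, nxt.vector_observations[i]]))
        return np.array(rows)


    curr = Step(["a", "b"], np.array([[0.1, 0.2], [0.3, 0.4]]))
    nxt = Step(["a", "b", "c"], np.array([[0.5, 0.6], [0.7, 0.8], [0.9, 1.0]]))
    print(pair_observations(curr, nxt, {"a": curr, "b": curr}))  # slow path: "c" is new

With identical agent lists the fast path skips the per-agent index() lookups entirely, which is the point of the switch.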

@@ -286,13 +303,11 @@ def add_experiences(self, curr_all_info: AllBrainInfo, next_all_info: AllBrainIn
         curr_info = curr_all_info[self.brain_name]
         next_info = next_all_info[self.brain_name]
 
-        # intrinsic_rewards = self.generate_intrinsic_rewards(curr_info, next_info)
-
         for agent_id in curr_info.agents:
             self.training_buffer[agent_id].last_brain_info = curr_info
             self.training_buffer[agent_id].last_take_action_outputs = take_action_outputs
 
-        intrinsic_rewards = self.generate_intrinsic_rewards(next_info)
+        intrinsic_rewards = self.generate_intrinsic_rewards(curr_info, next_info)
 
         for agent_id in next_info.agents:
             stored_info = self.training_buffer[agent_id].last_brain_info
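This second hunk is the matching call-site change: generate_intrinsic_rewards now receives both the current and the next BrainInfo, and it is called only after last_brain_info has been recorded for every current agent. Continuing the toy sketch above, the equivalent call pattern would look roughly like this (again using the hypothetical Step, pair_observations, and last_step helpers from the previous sketch, not the ML-Agents API):

    last_step: Dict[str, Step] = {}

    def add_experiences_sketch(curr: Step, nxt: Step) -> np.ndarray:
        for agent_id in curr.agents:
            last_step[agent_id] = curr      # record each current agent's step first
        # Both steps are passed in, so the fast path can trigger when the agent lists
        # match; the recorded steps are only consulted on the slow path.
        return pair_observations(curr, nxt, last_step)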
