@@ -196,32 +196,59 @@ def take_action(self, all_brain_info: AllBrainInfo):
         else:
             return run_out[self.model.output], None, None, run_out

-    def generate_intrinsic_rewards(self, curr_info, next_info):
+    def generate_intrinsic_rewards(self, next_info):
         """
         Generates intrinsic reward used for Curiosity-based training.
-        :param curr_info: Current BrainInfo.
         :param next_info: Next BrainInfo.
         :return: Intrinsic rewards for all agents.
         """
         if self.use_curiosity:
-            feed_dict = {self.model.batch_size: len(curr_info.vector_observations), self.model.sequence_length: 1}
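+            # Agents with no previous BrainInfo stored in the training buffer
+            # have nothing to compare against; remember their indices so their
+            # intrinsic reward can be zeroed out below.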
+            agent_index_to_ignore = []
+            for agent_index, agent_id in enumerate(next_info.agents):
+                if self.training_buffer[agent_id].last_brain_info is None:
+                    agent_index_to_ignore.append(agent_index)
+            feed_dict = {self.model.batch_size: len(next_info.vector_observations), self.model.sequence_length: 1}
             if self.is_continuous_action:
                 feed_dict[self.model.output] = next_info.previous_vector_actions
             else:
                 feed_dict[self.model.action_holder] = next_info.previous_vector_actions.flatten()
             if self.use_visual_obs:
-                for i in range(len(curr_info.visual_observations)):
-                    feed_dict[self.model.visual_in[i]] = curr_info.visual_observations[i]
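+                # Rebuild the batch of previous visual observations per agent
+                # from the training buffer, falling back to next_info for
+                # agents that have no stored previous BrainInfo.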
+                for i in range(len(next_info.visual_observations)):
+                    tmp = []
+                    for agent_id in next_info.agents:
+                        agent_brain_info = self.training_buffer[agent_id].last_brain_info
+                        if agent_brain_info is None:
+                            agent_brain_info = next_info
+                        agent_obs = agent_brain_info.visual_observations[i][agent_brain_info.agents.index(agent_id)]
+                        tmp += [agent_obs]
+                    feed_dict[self.model.visual_in[i]] = np.array(tmp)
                     feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
             if self.use_vector_obs:
-                feed_dict[self.model.vector_in] = curr_info.vector_observations
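+                # Same per-agent lookup for vector observations: use each
+                # agent's stored previous observation when it exists.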
+                tmp = []
+                for agent_id in next_info.agents:
+                    agent_brain_info = self.training_buffer[agent_id].last_brain_info
+                    if agent_brain_info is None:
+                        agent_brain_info = next_info
+                    agent_obs = agent_brain_info.vector_observations[agent_brain_info.agents.index(agent_id)]
+                    tmp += [agent_obs]
+                feed_dict[self.model.vector_in] = np.array(tmp)
                 feed_dict[self.model.next_vector_in] = next_info.vector_observations
             if self.use_recurrent:
-                if curr_info.memories.shape[1] == 0:
-                    curr_info.memories = np.zeros((len(curr_info.agents), self.m_size))
-                feed_dict[self.model.memory_in] = curr_info.memories
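+                # For recurrent models, gather each agent's previous memories,
+                # substituting zeros when none have been recorded yet.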
+                tmp = []
+                for agent_id in next_info.agents:
+                    agent_brain_info = self.training_buffer[agent_id].last_brain_info
+                    if agent_brain_info is None:
+                        agent_brain_info = next_info
+                    if agent_brain_info.memories.shape[1] == 0:
+                        agent_obs = np.zeros(self.m_size)
+                    else:
+                        agent_obs = agent_brain_info.memories[agent_brain_info.agents.index(agent_id)]
+                    tmp += [agent_obs]
+                feed_dict[self.model.memory_in] = np.array(tmp)
             intrinsic_rewards = self.sess.run(self.model.intrinsic_reward,
                                               feed_dict=feed_dict) * float(self.has_updated)
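+            # Agents that had no previous BrainInfo receive no intrinsic reward.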
+            for index in agent_index_to_ignore:
+                intrinsic_rewards[index] = 0
             return intrinsic_rewards
         else:
             return None
@@ -259,12 +286,14 @@ def add_experiences(self, curr_all_info: AllBrainInfo, next_all_info: AllBrainIn
         curr_info = curr_all_info[self.brain_name]
         next_info = next_all_info[self.brain_name]

-        intrinsic_rewards = self.generate_intrinsic_rewards(curr_info, next_info)
+        # intrinsic_rewards = self.generate_intrinsic_rewards(curr_info, next_info)

         for agent_id in curr_info.agents:
             self.training_buffer[agent_id].last_brain_info = curr_info
             self.training_buffer[agent_id].last_take_action_outputs = take_action_outputs

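+        # Compute intrinsic rewards only after last_brain_info has been stored,
+        # so generate_intrinsic_rewards can look up each agent's previous state.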
+        intrinsic_rewards = self.generate_intrinsic_rewards(next_info)
+
         for agent_id in next_info.agents:
             stored_info = self.training_buffer[agent_id].last_brain_info
             stored_take_action_outputs = self.training_buffer[agent_id].last_take_action_outputs