@@ -196,60 +196,77 @@ def take_action(self, all_brain_info: AllBrainInfo):
             else:
                 return run_out[self.model.output], None, None, run_out
 
-    def generate_intrinsic_rewards(self, next_info):
+    def generate_intrinsic_rewards(self, curr_info, next_info):
         """
         Generates intrinsic reward used for Curiosity-based training.
         :param next_info: Next BrainInfo.
         :return: Intrinsic rewards for all agents.
         """
         if self.use_curiosity:
-            agent_index_to_ignore = []
-            for agent_index, agent_id in enumerate(next_info.agents):
-                if self.training_buffer[agent_id].last_brain_info is None:
-                    agent_index_to_ignore.append(agent_index)
             feed_dict = {self.model.batch_size: len(next_info.vector_observations), self.model.sequence_length: 1}
             if self.is_continuous_action:
                 feed_dict[self.model.output] = next_info.previous_vector_actions
             else:
                 feed_dict[self.model.action_holder] = next_info.previous_vector_actions.flatten()
-            if self.use_visual_obs:
-                for i in range(len(next_info.visual_observations)):
+
+            if curr_info.agents == next_info.agents:
+                if self.use_visual_obs:
+                    for i in range(len(curr_info.visual_observations)):
+                        feed_dict[self.model.visual_in[i]] = curr_info.visual_observations[i]
+                        feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
+                if self.use_vector_obs:
+                    feed_dict[self.model.vector_in] = curr_info.vector_observations
+                    feed_dict[self.model.next_vector_in] = next_info.vector_observations
+                if self.use_recurrent:
+                    if curr_info.memories.shape[1] == 0:
+                        curr_info.memories = np.zeros((len(curr_info.agents), self.m_size))
+                    feed_dict[self.model.memory_in] = curr_info.memories
+                intrinsic_rewards = self.sess.run(self.model.intrinsic_reward,
+                                                  feed_dict=feed_dict) * float(self.has_updated)
+                return intrinsic_rewards
+            else:
+                agent_index_to_ignore = []
+                for agent_index, agent_id in enumerate(next_info.agents):
+                    if self.training_buffer[agent_id].last_brain_info is None:
+                        agent_index_to_ignore.append(agent_index)
+                if self.use_visual_obs:
+                    for i in range(len(next_info.visual_observations)):
+                        tmp = []
+                        for agent_id in next_info.agents:
+                            agent_brain_info = self.training_buffer[agent_id].last_brain_info
+                            if agent_brain_info is None:
+                                agent_brain_info = next_info
+                            agent_obs = agent_brain_info.visual_observations[i][agent_brain_info.agents.index(agent_id)]
+                            tmp += [agent_obs]
+                        feed_dict[self.model.visual_in[i]] = np.array(tmp)
+                        feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
+                if self.use_vector_obs:
                     tmp = []
                     for agent_id in next_info.agents:
                         agent_brain_info = self.training_buffer[agent_id].last_brain_info
                         if agent_brain_info is None:
                             agent_brain_info = next_info
-                        agent_obs = agent_brain_info.visual_observations[i][agent_brain_info.agents.index(agent_id)]
+                        agent_obs = agent_brain_info.vector_observations[agent_brain_info.agents.index(agent_id)]
                         tmp += [agent_obs]
-                    feed_dict[self.model.visual_in[i]] = np.array(tmp)
-                    feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
-            if self.use_vector_obs:
-                tmp = []
-                for agent_id in next_info.agents:
-                    agent_brain_info = self.training_buffer[agent_id].last_brain_info
-                    if agent_brain_info is None:
-                        agent_brain_info = next_info
-                    agent_obs = agent_brain_info.vector_observations[agent_brain_info.agents.index(agent_id)]
-                    tmp += [agent_obs]
-                feed_dict[self.model.vector_in] = np.array(tmp)
-                feed_dict[self.model.next_vector_in] = next_info.vector_observations
-            if self.use_recurrent:
-                tmp = []
-                for agent_id in next_info.agents:
-                    agent_brain_info = self.training_buffer[agent_id].last_brain_info
-                    if agent_brain_info is None:
-                        agent_brain_info = next_info
-                    if agent_brain_info.memories.shape[1] == 0:
-                        agent_obs = np.zeros(self.m_size)
-                    else:
-                        agent_obs = agent_brain_info.memories[agent_brain_info.agents.index(agent_id)]
-                    tmp += [agent_obs]
-                feed_dict[self.model.memory_in] = np.array(tmp)
-            intrinsic_rewards = self.sess.run(self.model.intrinsic_reward,
-                                              feed_dict=feed_dict) * float(self.has_updated)
-            for index in agent_index_to_ignore:
-                intrinsic_rewards[index] = 0
-            return intrinsic_rewards
+                    feed_dict[self.model.vector_in] = np.array(tmp)
+                    feed_dict[self.model.next_vector_in] = next_info.vector_observations
+                if self.use_recurrent:
+                    tmp = []
+                    for agent_id in next_info.agents:
+                        agent_brain_info = self.training_buffer[agent_id].last_brain_info
+                        if agent_brain_info is None:
+                            agent_brain_info = next_info
+                        if agent_brain_info.memories.shape[1] == 0:
+                            agent_obs = np.zeros(self.m_size)
+                        else:
+                            agent_obs = agent_brain_info.memories[agent_brain_info.agents.index(agent_id)]
+                        tmp += [agent_obs]
+                    feed_dict[self.model.memory_in] = np.array(tmp)
+                intrinsic_rewards = self.sess.run(self.model.intrinsic_reward,
+                                                  feed_dict=feed_dict) * float(self.has_updated)
+                for index in agent_index_to_ignore:
+                    intrinsic_rewards[index] = 0
+                return intrinsic_rewards
         else:
             return None
 
@@ -286,13 +303,11 @@ def add_experiences(self, curr_all_info: AllBrainInfo, next_all_info: AllBrainIn
         curr_info = curr_all_info[self.brain_name]
         next_info = next_all_info[self.brain_name]
 
-        # intrinsic_rewards = self.generate_intrinsic_rewards(curr_info, next_info)
-
         for agent_id in curr_info.agents:
             self.training_buffer[agent_id].last_brain_info = curr_info
             self.training_buffer[agent_id].last_take_action_outputs = take_action_outputs
 
-        intrinsic_rewards = self.generate_intrinsic_rewards(next_info)
+        intrinsic_rewards = self.generate_intrinsic_rewards(curr_info, next_info)
 
         for agent_id in next_info.agents:
             stored_info = self.training_buffer[agent_id].last_brain_info
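The new control flow in generate_intrinsic_rewards comes down to an agent-alignment check: when curr_info and next_info list the same agents in the same order, the current observations can be fed to the curiosity model directly; otherwise each agent's last stored BrainInfo is looked up, and the intrinsic reward is zeroed for agents that have no history yet. Below is a minimal, self-contained sketch of that pattern using plain NumPy arrays; intrinsic_rewards_sketch, curiosity_forward, and last_obs_by_agent are hypothetical names used only for illustration and are not part of the trainer's API.

import numpy as np

def intrinsic_rewards_sketch(curr_agents, next_agents,
                             curr_obs, next_obs,
                             last_obs_by_agent, curiosity_forward):
    # Fast path: both steps cover the same agents in the same order,
    # so the observation batch is already aligned agent-by-agent.
    if curr_agents == next_agents:
        return curiosity_forward(np.array(curr_obs), np.array(next_obs))

    # Slow path: rebuild the "current" batch from each agent's last stored
    # observation, remembering which agents have no stored step yet.
    batch, ignored = [], []
    for index, agent_id in enumerate(next_agents):
        stored = last_obs_by_agent.get(agent_id)
        if stored is None:
            ignored.append(index)      # reward is zeroed below
            stored = next_obs[index]   # placeholder, mirroring the fallback above
        batch.append(stored)
    rewards = curiosity_forward(np.array(batch), np.array(next_obs))
    for index in ignored:
        rewards[index] = 0.0
    return rewards

# Example with a dummy forward pass (squared prediction error as a stand-in):
rewards = intrinsic_rewards_sketch(
    ["a", "b"], ["a", "b"],
    [[0.0, 1.0], [1.0, 0.0]], [[0.1, 1.0], [1.0, 0.2]],
    {}, lambda cur, nxt: np.sum((nxt - cur) ** 2, axis=1))
print(rewards)  # one intrinsic reward per agent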