 import numpy as np
 import tensorflow as tf
 
-from unityagents import AllBrainInfo
+from unityagents import AllBrainInfo, BrainInfo
 from unitytrainers.buffer import Buffer
 from unitytrainers.ppo.models import PPOModel
 from unitytrainers.trainer import UnityTrainerException, Trainer
@@ -196,10 +196,51 @@ def take_action(self, all_brain_info: AllBrainInfo):
         else:
             return run_out[self.model.output], None, None, run_out
 
+    def construct_curr_info(self, next_info: BrainInfo) -> BrainInfo:
+        """
+        Constructs a BrainInfo which contains the most recent previous experiences for all agents
+        which correspond to the agents in a provided next_info.
+        :BrainInfo next_info: A t+1 BrainInfo.
+        :return: curr_info: Constructed BrainInfo whose agents match those of next_info.
+        """
+        visual_observations = [[]]
+        vector_observations = []
+        text_observations = []
+        memories = []
+        rewards = []
+        local_dones = []
+        max_reacheds = []
+        agents = []
+        prev_vector_actions = []
+        prev_text_actions = []
+        for agent_id in next_info.agents:
+            agent_brain_info = self.training_buffer[agent_id].last_brain_info
+            if agent_brain_info is None:
+                agent_brain_info = next_info
+            for i in range(len(next_info.visual_observations)):
+                visual_observations[i].append(
+                    agent_brain_info.visual_observations[i][agent_brain_info.agents.index(agent_id)])
+            vector_observations.append(agent_brain_info.vector_observations[agent_brain_info.agents.index(agent_id)])
+            text_observations.append(agent_brain_info.text_observations[agent_brain_info.agents.index(agent_id)])
+            if self.use_recurrent:
+                memories.append(agent_brain_info.memories[agent_brain_info.agents.index(agent_id)])
+            rewards.append(agent_brain_info.rewards[agent_brain_info.agents.index(agent_id)])
+            local_dones.append(agent_brain_info.local_done[agent_brain_info.agents.index(agent_id)])
+            max_reacheds.append(agent_brain_info.max_reached[agent_brain_info.agents.index(agent_id)])
+            agents.append(agent_brain_info.agents[agent_brain_info.agents.index(agent_id)])
+            prev_vector_actions.append(
+                agent_brain_info.previous_vector_actions[agent_brain_info.agents.index(agent_id)])
+            prev_text_actions.append(agent_brain_info.previous_text_actions[agent_brain_info.agents.index(agent_id)])
+        curr_info = BrainInfo(visual_observations, vector_observations, text_observations, memories, rewards,
+                              agents,
+                              local_dones, prev_vector_actions, prev_text_actions, max_reacheds)
+        return curr_info
+
     def generate_intrinsic_rewards(self, curr_info, next_info):
         """
         Generates intrinsic reward used for Curiosity-based training.
-        :param next_info: Next BrainInfo.
+        :BrainInfo curr_info: Current BrainInfo.
+        :BrainInfo next_info: Next BrainInfo.
         :return: Intrinsic rewards for all agents.
         """
         if self.use_curiosity:
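The new `construct_curr_info` above realigns each agent's most recent stored experience with the agent ordering of `next_info`, falling back to `next_info` itself when an agent has no stored history yet. Below is a minimal sketch of that realignment idea, using plain dictionaries and illustrative names (`realign_previous_obs`, `last_info_by_agent`) rather than the real `BrainInfo` class:

```python
# Minimal sketch (illustrative names, not the trainer's API) of the realignment
# performed by construct_curr_info: for every agent listed in the new step, fetch
# that agent's last stored observation, or reuse its current one if none exists.

def realign_previous_obs(next_agents, last_info_by_agent, next_obs_by_agent):
    """Return previous observations ordered to match next_agents."""
    realigned = []
    for agent_id in next_agents:
        last = last_info_by_agent.get(agent_id)  # last stored snapshot, or None
        if last is None:
            # No history yet for this agent: fall back to its current observation.
            realigned.append(next_obs_by_agent[agent_id])
        else:
            realigned.append(last["obs"])
    return realigned

# Example: agent "b" is new this step, so its "previous" observation is its current one.
last_info_by_agent = {"a": {"obs": [0.1, 0.2]}}
next_agents = ["a", "b"]
next_obs_by_agent = {"a": [0.3, 0.4], "b": [0.5, 0.6]}
print(realign_previous_obs(next_agents, last_info_by_agent, next_obs_by_agent))
# [[0.1, 0.2], [0.5, 0.6]]
```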
@@ -209,64 +250,23 @@ def generate_intrinsic_rewards(self, curr_info, next_info):
             else:
                 feed_dict[self.model.action_holder] = next_info.previous_vector_actions.flatten()
 
-            if curr_info.agents == next_info.agents:
-                if self.use_visual_obs:
-                    for i in range(len(curr_info.visual_observations)):
-                        feed_dict[self.model.visual_in[i]] = curr_info.visual_observations[i]
-                        feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
-                if self.use_vector_obs:
-                    feed_dict[self.model.vector_in] = curr_info.vector_observations
-                    feed_dict[self.model.next_vector_in] = next_info.vector_observations
-                if self.use_recurrent:
-                    if curr_info.memories.shape[1] == 0:
-                        curr_info.memories = np.zeros((len(curr_info.agents), self.m_size))
-                    feed_dict[self.model.memory_in] = curr_info.memories
-                intrinsic_rewards = self.sess.run(self.model.intrinsic_reward,
-                                                  feed_dict=feed_dict) * float(self.has_updated)
-                return intrinsic_rewards
-            else:
-                agent_index_to_ignore = []
-                for agent_index, agent_id in enumerate(next_info.agents):
-                    if self.training_buffer[agent_id].last_brain_info is None:
-                        agent_index_to_ignore.append(agent_index)
-                if self.use_visual_obs:
-                    for i in range(len(next_info.visual_observations)):
-                        tmp = []
-                        for agent_id in next_info.agents:
-                            agent_brain_info = self.training_buffer[agent_id].last_brain_info
-                            if agent_brain_info is None:
-                                agent_brain_info = next_info
-                            agent_obs = agent_brain_info.visual_observations[i][agent_brain_info.agents.index(agent_id)]
-                            tmp += [agent_obs]
-                        feed_dict[self.model.visual_in[i]] = np.array(tmp)
-                        feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
-                if self.use_vector_obs:
-                    tmp = []
-                    for agent_id in next_info.agents:
-                        agent_brain_info = self.training_buffer[agent_id].last_brain_info
-                        if agent_brain_info is None:
-                            agent_brain_info = next_info
-                        agent_obs = agent_brain_info.vector_observations[agent_brain_info.agents.index(agent_id)]
-                        tmp += [agent_obs]
-                    feed_dict[self.model.vector_in] = np.array(tmp)
-                    feed_dict[self.model.next_vector_in] = next_info.vector_observations
-                if self.use_recurrent:
-                    tmp = []
-                    for agent_id in next_info.agents:
-                        agent_brain_info = self.training_buffer[agent_id].last_brain_info
-                        if agent_brain_info is None:
-                            agent_brain_info = next_info
-                        if agent_brain_info.memories.shape[1] == 0:
-                            agent_obs = np.zeros(self.m_size)
-                        else:
-                            agent_obs = agent_brain_info.memories[agent_brain_info.agents.index(agent_id)]
-                        tmp += [agent_obs]
-                    feed_dict[self.model.memory_in] = np.array(tmp)
-                intrinsic_rewards = self.sess.run(self.model.intrinsic_reward,
-                                                  feed_dict=feed_dict) * float(self.has_updated)
-                for index in agent_index_to_ignore:
-                    intrinsic_rewards[index] = 0
-                return intrinsic_rewards
+            if curr_info.agents != next_info.agents:
+                curr_info = self.construct_curr_info(next_info)
+
+            if self.use_visual_obs:
+                for i in range(len(curr_info.visual_observations)):
+                    feed_dict[self.model.visual_in[i]] = curr_info.visual_observations[i]
+                    feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
+            if self.use_vector_obs:
+                feed_dict[self.model.vector_in] = curr_info.vector_observations
+                feed_dict[self.model.next_vector_in] = next_info.vector_observations
+            if self.use_recurrent:
+                if curr_info.memories.shape[1] == 0:
+                    curr_info.memories = np.zeros((len(curr_info.agents), self.m_size))
+                feed_dict[self.model.memory_in] = curr_info.memories
+            intrinsic_rewards = self.sess.run(self.model.intrinsic_reward,
+                                              feed_dict=feed_dict) * float(self.has_updated)
+            return intrinsic_rewards
         else:
             return None
 
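For context, the intrinsic rewards returned above are per-agent values that a caller can add to the environment rewards before storing experiences. The sketch below is an assumption about how such a combination might look, not the trainer's actual bookkeeping; names like `combine_rewards` and `env_rewards` are hypothetical:

```python
# Hedged sketch (not the trainer's exact code) of folding per-agent intrinsic
# rewards into the environment rewards. `env_rewards` is illustrative.
import numpy as np

def combine_rewards(env_rewards, intrinsic_rewards):
    """Element-wise sum of extrinsic and curiosity-based rewards per agent."""
    env_rewards = np.asarray(env_rewards, dtype=np.float32)
    if intrinsic_rewards is None:  # curiosity disabled: use env rewards as-is
        return env_rewards
    return env_rewards + np.asarray(intrinsic_rewards, dtype=np.float32)

# Example: two agents, curiosity enabled.
print(combine_rewards([1.0, 0.0], [0.05, 0.10]))  # -> array like [1.05 0.1]
```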