@@ -55,10 +55,10 @@ def __init__(self, sess, env, brain_name, trainer_parameters, training, seed):
         self.training_buffer = Buffer()
         self.is_continuous_action = (env.brains[brain_name].vector_action_space_type == "continuous")
         self.is_continuous_observation = (env.brains[brain_name].vector_observation_space_type == "continuous")
-        self.use_observations = (env.brains[brain_name].number_visual_observations > 0)
-        if self.use_observations:
+        self.use_visual_observations = (env.brains[brain_name].number_visual_observations > 0)
+        if self.use_visual_observations:
             logger.info('Cannot use observations with imitation learning')
-        self.use_states = (env.brains[brain_name].vector_observation_space_size > 0)
+        self.use_vector_observations = (env.brains[brain_name].vector_observation_space_size > 0)
         self.summary_path = trainer_parameters['summary_path']
         if not os.path.exists(self.summary_path):
             os.makedirs(self.summary_path)
@@ -144,16 +144,15 @@ def take_action(self, all_brain_info: AllBrainInfo):
         agent_brain = all_brain_info[self.brain_name]
         feed_dict = {self.model.dropout_rate: 1.0, self.model.sequence_length: 1}

-        if self.use_observations:
+        if self.use_visual_observations:
             for i, _ in enumerate(agent_brain.visual_observations):
                 feed_dict[self.model.visual_in[i]] = agent_brain.visual_observations[i]
-        if self.use_states:
+        if self.use_vector_observations:
             feed_dict[self.model.vector_in] = agent_brain.vector_observations
         if self.use_recurrent:
             if agent_brain.memories.shape[1] == 0:
                 agent_brain.memories = np.zeros((len(agent_brain.agents), self.m_size))
             feed_dict[self.model.memory_in] = agent_brain.memories
-        if self.use_recurrent:
             agent_action, memories = self.sess.run(self.inference_run_list, feed_dict)
             return agent_action, memories, None, None
         else:
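
For context on how take_action assembles its inputs, here is a minimal TF 1.x sketch of conditionally feeding only the placeholders a brain actually uses; the placeholder names and sizes below are made up for illustration and are not the actual BC model graph.

import numpy as np
import tensorflow as tf  # TF 1.x API, matching the sess.run style used above

# Hypothetical placeholders standing in for model.vector_in and model.memory_in.
vector_in = tf.placeholder(tf.float32, [None, 8], name="vector_in")
memory_in = tf.placeholder(tf.float32, [None, 16], name="memory_in")
policy = tf.layers.dense(vector_in, 4)

use_vector_observations, use_recurrent = True, False

# Only feed the placeholders this brain uses, as in take_action above.
feed_dict = {}
if use_vector_observations:
    feed_dict[vector_in] = np.zeros((1, 8), dtype=np.float32)
if use_recurrent:
    feed_dict[memory_in] = np.zeros((1, 16), dtype=np.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    action = sess.run(policy, feed_dict)
    print(action.shape)  # (1, 4)
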
@@ -192,11 +191,11 @@ def add_experiences(self, curr_info: AllBrainInfo, next_info: AllBrainInfo, take
                     info_teacher_record, next_info_teacher_record = "true", "true"
                 if info_teacher_record == "true" and next_info_teacher_record == "true":
                     if not stored_info_teacher.local_done[idx]:
-                        if self.use_observations:
+                        if self.use_visual_observations:
                             for i, _ in enumerate(stored_info_teacher.visual_observations):
                                 self.training_buffer[agent_id]['visual_observations%d' % i]\
                                     .append(stored_info_teacher.visual_observations[i][idx])
-                        if self.use_states:
+                        if self.use_vector_observations:
                             self.training_buffer[agent_id]['vector_observations']\
                                 .append(stored_info_teacher.vector_observations[idx])
                         if self.use_recurrent:
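
The appends above store one entry per teacher step under keys that update_model later reads back. Below is a toy dict-of-lists stand-in for the ML-Agents Buffer class (which additionally handles sequence padding, shuffling, and the shared update buffer); only the key layout is taken from the diff, the agent id and observation shapes are made up.

from collections import defaultdict

import numpy as np

# Toy stand-in: one dict of lists per agent id, keyed the same way as above.
training_buffer = defaultdict(lambda: defaultdict(list))

agent_id = 'teacher-0'                                                 # hypothetical id
visual_observations = [np.zeros((84, 84, 3)), np.zeros((32, 32, 1))]  # made-up shapes
vector_observation = np.zeros(8)                                       # made-up size

for i, _ in enumerate(visual_observations):
    training_buffer[agent_id]['visual_observations%d' % i].append(visual_observations[i])
training_buffer[agent_id]['vector_observations'].append(vector_observation)

print(sorted(training_buffer[agent_id].keys()))
# ['vector_observations', 'visual_observations0', 'visual_observations1']
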
@@ -276,41 +275,38 @@ def update_model(self):
276275 """
277276 Uses training_buffer to update model.
278277 """
279-
280278 self .training_buffer .update_buffer .shuffle ()
281279 batch_losses = []
282280 for j in range (
283281 min (len (self .training_buffer .update_buffer ['actions' ]) // self .n_sequences , self .batches_per_epoch )):
284282 _buffer = self .training_buffer .update_buffer
285283 start = j * self .n_sequences
286284 end = (j + 1 ) * self .n_sequences
287- batch_states = np .array (_buffer ['vector_observations' ][start :end ])
288- batch_actions = np .array (_buffer ['actions' ][start :end ])
289285
290286 feed_dict = {self .model .dropout_rate : 0.5 ,
291287 self .model .batch_size : self .n_sequences ,
292288 self .model .sequence_length : self .sequence_length }
293289 if self .is_continuous_action :
294- feed_dict [self .model .true_action ] = batch_actions .reshape ([- 1 , self .brain .vector_action_space_size ])
295- else :
296- feed_dict [self .model .true_action ] = batch_actions .reshape ([- 1 ])
297- if not self .is_continuous_observation :
298- feed_dict [self .model .vector_in ] = batch_states .reshape ([- 1 , self .brain .num_stacked_vector_observations ])
290+ feed_dict [self .model .true_action ] = np .array (_buffer ['actions' ][start :end ]).\
291+ reshape ([- 1 , self .brain .vector_action_space_size ])
299292 else :
300- feed_dict [self .model .vector_in ] = batch_states .reshape ([- 1 , self .brain .vector_observation_space_size *
301- self .brain .num_stacked_vector_observations ])
302- if self .use_observations :
293+ feed_dict [self .model .true_action ] = np .array (_buffer ['actions' ][start :end ]).reshape ([- 1 ])
294+ if self .use_vector_observations :
295+ if not self .is_continuous_observation :
296+ feed_dict [self .model .vector_in ] = np .array (_buffer ['vector_observations' ][start :end ])\
297+ .reshape ([- 1 , self .brain .num_stacked_vector_observations ])
298+ else :
299+ feed_dict [self .model .vector_in ] = np .array (_buffer ['vector_observations' ][start :end ])\
300+ .reshape ([- 1 , self .brain .vector_observation_space_size * self .brain .num_stacked_vector_observations ])
301+ if self .use_visual_observations :
303302 for i , _ in enumerate (self .model .visual_in ):
304303 _obs = np .array (_buffer ['visual_observations%d' % i ][start :end ])
305- (_batch , _seq , _w , _h , _c ) = _obs .shape
306- feed_dict [self .model .visual_in [i ]] = _obs .reshape ([- 1 , _w , _h , _c ])
304+ feed_dict [self .model .visual_in [i ]] = _obs
307305 if self .use_recurrent :
308306 feed_dict [self .model .memory_in ] = np .zeros ([self .n_sequences , self .m_size ])
309-
310307 loss , _ = self .sess .run ([self .model .loss , self .model .update ], feed_dict = feed_dict )
311308 batch_losses .append (loss )
312309 if len (batch_losses ) > 0 :
313310 self .stats ['losses' ].append (np .mean (batch_losses ))
314311 else :
315312 self .stats ['losses' ].append (0 )
316-
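
To make the vector-observation reshape used above concrete, here is a small numpy-only sketch; the sizes (2 sequences of length 4, 8 observations stacked 3 times) are arbitrary example values standing in for n_sequences, sequence_length, vector_observation_space_size, and num_stacked_vector_observations.

import numpy as np

n_sequences = 2          # stand-in for self.n_sequences
sequence_length = 4      # stand-in for self.sequence_length
obs_size = 8             # stand-in for brain.vector_observation_space_size
num_stacked = 3          # stand-in for brain.num_stacked_vector_observations

# A slice of the update buffer: one stacked observation vector per step in each sequence.
batch = np.zeros((n_sequences, sequence_length, obs_size * num_stacked))

# Same reshape as the continuous-observation branch above: flatten the sequences into a
# (batch * sequence_length, stacked observation size) matrix for the model input.
flat = batch.reshape([-1, obs_size * num_stacked])
print(flat.shape)  # (8, 24)
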