
Commit 4ee265a

Fix for visual-only imitation learning
1 parent 7c16ce9 commit 4ee265a


python/unitytrainers/bc/trainer.py

Lines changed: 19 additions & 23 deletions
@@ -55,10 +55,10 @@ def __init__(self, sess, env, brain_name, trainer_parameters, training, seed):
         self.training_buffer = Buffer()
         self.is_continuous_action = (env.brains[brain_name].vector_action_space_type == "continuous")
         self.is_continuous_observation = (env.brains[brain_name].vector_observation_space_type == "continuous")
-        self.use_observations = (env.brains[brain_name].number_visual_observations > 0)
-        if self.use_observations:
+        self.use_visual_observations = (env.brains[brain_name].number_visual_observations > 0)
+        if self.use_visual_observations:
             logger.info('Cannot use observations with imitation learning')
-        self.use_states = (env.brains[brain_name].vector_observation_space_size > 0)
+        self.use_vector_observations = (env.brains[brain_name].vector_observation_space_size > 0)
         self.summary_path = trainer_parameters['summary_path']
         if not os.path.exists(self.summary_path):
             os.makedirs(self.summary_path)
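
The renamed flags make the two observation channels independent, so a brain can drive imitation learning from cameras alone. As a rough illustration (not part of the commit), the same checks can be reproduced against a stand-in object exposing only the two fields read above:

from types import SimpleNamespace

# Hypothetical stand-in for env.brains[brain_name], with only the fields used above.
brain = SimpleNamespace(number_visual_observations=1,    # one camera attached
                        vector_observation_space_size=0)  # no vector observations

use_visual_observations = brain.number_visual_observations > 0
use_vector_observations = brain.vector_observation_space_size > 0
print(use_visual_observations, use_vector_observations)  # True False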
@@ -144,16 +144,15 @@ def take_action(self, all_brain_info: AllBrainInfo):
         agent_brain = all_brain_info[self.brain_name]
         feed_dict = {self.model.dropout_rate: 1.0, self.model.sequence_length: 1}

-        if self.use_observations:
+        if self.use_visual_observations:
             for i, _ in enumerate(agent_brain.visual_observations):
                 feed_dict[self.model.visual_in[i]] = agent_brain.visual_observations[i]
-        if self.use_states:
+        if self.use_vector_observations:
             feed_dict[self.model.vector_in] = agent_brain.vector_observations
         if self.use_recurrent:
             if agent_brain.memories.shape[1] == 0:
                 agent_brain.memories = np.zeros((len(agent_brain.agents), self.m_size))
             feed_dict[self.model.memory_in] = agent_brain.memories
-        if self.use_recurrent:
             agent_action, memories = self.sess.run(self.inference_run_list, feed_dict)
             return agent_action, memories, None, None
         else:
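
take_action now feeds visual and vector inputs under separate flags, so a visual-only brain never touches model.vector_in. A minimal sketch of that branching, with plain string keys standing in for the model's placeholder tensors (names assumed, no TensorFlow session):

import numpy as np

use_visual_observations, use_vector_observations = True, False

# One agent with a single 84x84 RGB camera; dummy arrays stand in for BrainInfo fields.
visual_observations = [np.zeros((1, 84, 84, 3), dtype=np.float32)]
vector_observations = np.zeros((1, 0), dtype=np.float32)

feed_dict = {"dropout_rate": 1.0, "sequence_length": 1}
if use_visual_observations:
    for i, obs in enumerate(visual_observations):
        feed_dict["visual_in_%d" % i] = obs
if use_vector_observations:
    feed_dict["vector_in"] = vector_observations

print(sorted(feed_dict))  # ['dropout_rate', 'sequence_length', 'visual_in_0']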
@@ -192,11 +191,11 @@ def add_experiences(self, curr_info: AllBrainInfo, next_info: AllBrainInfo, take
                     info_teacher_record, next_info_teacher_record = "true", "true"
                 if info_teacher_record == "true" and next_info_teacher_record == "true":
                     if not stored_info_teacher.local_done[idx]:
-                        if self.use_observations:
+                        if self.use_visual_observations:
                             for i, _ in enumerate(stored_info_teacher.visual_observations):
                                 self.training_buffer[agent_id]['visual_observations%d' % i]\
                                     .append(stored_info_teacher.visual_observations[i][idx])
-                        if self.use_states:
+                        if self.use_vector_observations:
                             self.training_buffer[agent_id]['vector_observations']\
                                 .append(stored_info_teacher.vector_observations[idx])
                         if self.use_recurrent:
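
add_experiences applies the same split when recording the teacher's demonstrations, appending each camera's frame under its own 'visual_observations%d' key. A toy version of that bookkeeping, using plain dicts instead of the trainer's Buffer class (structure assumed):

from collections import defaultdict

training_buffer = defaultdict(lambda: defaultdict(list))
agent_id, idx = "teacher-0", 0

# Two cameras' worth of frames for the current step (dummy payloads).
stored_visual_observations = [["frame_cam0"], ["frame_cam1"]]

for i, _ in enumerate(stored_visual_observations):
    training_buffer[agent_id]['visual_observations%d' % i].append(
        stored_visual_observations[i][idx])

print(dict(training_buffer[agent_id]))
# {'visual_observations0': ['frame_cam0'], 'visual_observations1': ['frame_cam1']}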
@@ -276,41 +275,38 @@ def update_model(self):
         """
         Uses training_buffer to update model.
         """
-
         self.training_buffer.update_buffer.shuffle()
         batch_losses = []
         for j in range(
                 min(len(self.training_buffer.update_buffer['actions']) // self.n_sequences, self.batches_per_epoch)):
             _buffer = self.training_buffer.update_buffer
             start = j * self.n_sequences
             end = (j + 1) * self.n_sequences
-            batch_states = np.array(_buffer['vector_observations'][start:end])
-            batch_actions = np.array(_buffer['actions'][start:end])

             feed_dict = {self.model.dropout_rate: 0.5,
                          self.model.batch_size: self.n_sequences,
                          self.model.sequence_length: self.sequence_length}
             if self.is_continuous_action:
-                feed_dict[self.model.true_action] = batch_actions.reshape([-1, self.brain.vector_action_space_size])
-            else:
-                feed_dict[self.model.true_action] = batch_actions.reshape([-1])
-            if not self.is_continuous_observation:
-                feed_dict[self.model.vector_in] = batch_states.reshape([-1, self.brain.num_stacked_vector_observations])
+                feed_dict[self.model.true_action] = np.array(_buffer['actions'][start:end]).\
+                    reshape([-1, self.brain.vector_action_space_size])
             else:
-                feed_dict[self.model.vector_in] = batch_states.reshape([-1, self.brain.vector_observation_space_size *
-                                                                        self.brain.num_stacked_vector_observations])
-            if self.use_observations:
+                feed_dict[self.model.true_action] = np.array(_buffer['actions'][start:end]).reshape([-1])
+            if self.use_vector_observations:
+                if not self.is_continuous_observation:
+                    feed_dict[self.model.vector_in] = np.array(_buffer['vector_observations'][start:end])\
+                        .reshape([-1, self.brain.num_stacked_vector_observations])
+                else:
+                    feed_dict[self.model.vector_in] = np.array(_buffer['vector_observations'][start:end])\
+                        .reshape([-1, self.brain.vector_observation_space_size * self.brain.num_stacked_vector_observations])
+            if self.use_visual_observations:
                 for i, _ in enumerate(self.model.visual_in):
                     _obs = np.array(_buffer['visual_observations%d' % i][start:end])
-                    (_batch, _seq, _w, _h, _c) = _obs.shape
-                    feed_dict[self.model.visual_in[i]] = _obs.reshape([-1, _w, _h, _c])
+                    feed_dict[self.model.visual_in[i]] = _obs
             if self.use_recurrent:
                 feed_dict[self.model.memory_in] = np.zeros([self.n_sequences, self.m_size])
-
             loss, _ = self.sess.run([self.model.loss, self.model.update], feed_dict=feed_dict)
             batch_losses.append(loss)
         if len(batch_losses) > 0:
             self.stats['losses'].append(np.mean(batch_losses))
         else:
             self.stats['losses'].append(0)
-
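
update_model now slices each batch straight out of the update buffer, reads vector observations only when the brain actually has them, and feeds visual batches without the extra reshape. A NumPy-only sketch of that slicing for a discrete-action, single-camera buffer (keys and shapes assumed):

import numpy as np

n_sequences, j = 2, 0
update_buffer = {
    "actions": [[0], [1], [2], [3]],                      # one discrete action per step
    "visual_observations0": [np.zeros((84, 84, 3))] * 4,  # one camera's frames
}

start, end = j * n_sequences, (j + 1) * n_sequences
true_action = np.array(update_buffer["actions"][start:end]).reshape([-1])
visual_batch = np.array(update_buffer["visual_observations0"][start:end])

print(true_action.shape, visual_batch.shape)  # (2,) (2, 84, 84, 3)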
