Skip to content

Commit 807e260

Browse files
author
Ervin Teng
committed
Make sure all tests pass on BC
1 parent c4f87fe commit 807e260

File tree

2 files changed

+4
-14
lines changed

2 files changed

+4
-14
lines changed

ml-agents/mlagents/trainers/bc/policy.py

Lines changed: 3 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -79,24 +79,14 @@ def update(self, mini_batch, num_sequences):
7979
self.model.sequence_length: self.sequence_length,
8080
}
8181
if self.use_continuous_act:
82-
feed_dict[self.model.true_action] = mini_batch["actions"].reshape(
83-
[-1, self.brain.vector_action_space_size[0]]
84-
)
82+
feed_dict[self.model.true_action] = mini_batch["actions"]
8583
else:
86-
feed_dict[self.model.true_action] = mini_batch["actions"].reshape(
87-
[-1, len(self.brain.vector_action_space_size)]
88-
)
84+
feed_dict[self.model.true_action] = mini_batch["actions"]
8985
feed_dict[self.model.action_masks] = np.ones(
9086
(num_sequences, sum(self.brain.vector_action_space_size))
9187
)
9288
if self.use_vec_obs:
93-
apparent_obs_size = (
94-
self.brain.vector_observation_space_size
95-
* self.brain.num_stacked_vector_observations
96-
)
97-
feed_dict[self.model.vector_in] = mini_batch["vector_obs"].reshape(
98-
[-1, apparent_obs_size]
99-
)
89+
feed_dict[self.model.vector_in] = mini_batch["vector_obs"]
10090
for i, _ in enumerate(self.model.visual_in):
10191
visual_obs = mini_batch["visual_obs%d" % i]
10292
feed_dict[self.model.visual_in[i]] = visual_obs

ml-agents/mlagents/trainers/bc/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -124,7 +124,7 @@ def update_policy(self):
124124
"""
125125
Updates the policy.
126126
"""
127-
self.demonstration_buffer.update_buffer.shuffle()
127+
self.demonstration_buffer.update_buffer.shuffle(self.policy.sequence_length)
128128
batch_losses = []
129129
num_batches = min(
130130
len(self.demonstration_buffer.update_buffer["actions"]) // self.n_sequences,

0 commit comments

Comments
 (0)