
Commit e4d43a0

Author: Ervin T

Remove unnecessary feed_dicts for GAIL and Curiosity (#2348)

1 parent: c18f55c
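
The change rests on a property of TensorFlow 1.x sessions: `sess.run` only needs `feed_dict` entries for placeholders that the fetched tensors actually depend on, so building and feeding recurrent memories that the intrinsic-reward fetch never reads is pure overhead. Below is a minimal, self-contained sketch of that behaviour; `obs_in`, `memory_in`, and `reward` are illustrative names, not ml-agents identifiers.

```python
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

obs_in = tf.compat.v1.placeholder(tf.float32, [None, 4], name="obs")
# Never consumed by `reward` below -- stands in for the removed memory_in feed.
memory_in = tf.compat.v1.placeholder(tf.float32, [None, 8], name="memory")
reward = tf.reduce_sum(tf.square(obs_in), axis=1)

with tf.compat.v1.Session() as sess:
    obs = np.ones((2, 4), dtype=np.float32)
    # Both calls fetch the same tensor; the extra memory feed changes nothing.
    r1 = sess.run(reward, feed_dict={obs_in: obs})
    r2 = sess.run(
        reward, feed_dict={obs_in: obs, memory_in: np.zeros((2, 8), np.float32)}
    )
    assert np.allclose(r1, r2)
```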

2 files changed: 1 addition, 25 deletions

ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py

Lines changed: 1 addition & 19 deletions
```diff
@@ -70,12 +70,6 @@ def evaluate(
                 feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
         if self.policy.use_vec_obs:
             feed_dict[self.model.next_vector_in] = next_info.vector_observations
-        if self.policy.use_recurrent:
-            if current_info.memories.shape[1] == 0:
-                current_info.memories = self.policy.make_empty_memory(
-                    len(current_info.agents)
-                )
-            feed_dict[self.policy.model.memory_in] = current_info.memories
         unscaled_reward = self.policy.sess.run(
             self.model.intrinsic_reward, feed_dict=feed_dict
         )
@@ -145,20 +139,10 @@ def _update_batch(
             feed_dict[self.policy.model.output_pre] = mini_batch["actions_pre"].reshape(
                 [-1, self.policy.model.act_size[0]]
             )
-            feed_dict[self.policy.model.epsilon] = mini_batch[
-                "random_normal_epsilon"
-            ].reshape([-1, self.policy.model.act_size[0]])
         else:
             feed_dict[self.policy.model.action_holder] = mini_batch["actions"].reshape(
                 [-1, len(self.policy.model.act_size)]
             )
-            if self.policy.use_recurrent:
-                feed_dict[self.policy.model.prev_action] = mini_batch[
-                    "prev_action"
-                ].reshape([-1, len(self.policy.model.act_size)])
-            feed_dict[self.policy.model.action_masks] = mini_batch[
-                "action_mask"
-            ].reshape([-1, sum(self.policy.brain.vector_action_space_size)])
         if self.policy.use_vec_obs:
             feed_dict[self.policy.model.vector_in] = mini_batch["vector_obs"].reshape(
                 [-1, self.policy.vec_obs_size]
@@ -185,9 +169,7 @@ def _update_batch(
                 )
             else:
                 feed_dict[self.model.next_visual_in[i]] = _obs
-        if self.policy.use_recurrent:
-            mem_in = mini_batch["memory"][:, 0, :]
-            feed_dict[self.policy.model.memory_in] = mem_in
+
         self.has_updated = True
         run_out = self.policy._execute_model(feed_dict, self.update_dict)
         return run_out
```
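
Before deleting a feed like the `memory_in` entries above, one way to confirm that the fetched tensor never reads it is to walk the graph backwards from the fetch and check whether the placeholder is reachable. The helper below is a hypothetical sketch, not part of ml-agents; it assumes only the standard TF1 graph attributes `Tensor.op`, `Operation.inputs`, and `Operation.control_inputs`.

```python
def fetch_depends_on(fetch, placeholder):
    """Return True if `placeholder` feeds into `fetch` anywhere in the graph."""
    seen = set()
    stack = [fetch.op]
    while stack:
        op = stack.pop()
        if op in seen:
            continue
        seen.add(op)
        if op is placeholder.op:
            return True
        # Walk data inputs and control dependencies back toward the graph's sources.
        stack.extend(inp.op for inp in op.inputs)
        stack.extend(op.control_inputs)
    return False
```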

ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -75,12 +75,6 @@ def evaluate(
             feed_dict[
                 self.policy.model.action_holder
             ] = next_info.previous_vector_actions
-        if self.policy.use_recurrent:
-            if current_info.memories.shape[1] == 0:
-                current_info.memories = self.policy.make_empty_memory(
-                    len(current_info.agents)
-                )
-            feed_dict[self.policy.model.memory_in] = current_info.memories
         unscaled_reward = self.policy.sess.run(
             self.model.intrinsic_reward, feed_dict=feed_dict
         )
```
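
The GAIL signal follows the same pattern: the only fetch is `self.model.intrinsic_reward`, which (per this commit) is not wired to the recurrent memory input. With the hypothetical `fetch_depends_on` helper sketched above, and `gail_signal` standing in for an instantiated GAIL reward signal object, the check would be expected to come back negative.

```python
# Hypothetical verification; `gail_signal` is illustrative, not a variable
# that exists in the ml-agents codebase.
needs_memory = fetch_depends_on(
    gail_signal.model.intrinsic_reward, gail_signal.policy.model.memory_in
)
assert not needs_memory  # the feed removed in this commit was never consumed
```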
