Skip to content

Commit 1f39ddf

Browse files
authored
Fix for Discrete observations + Curiosity (#866)
1 parent 3b06c61 commit 1f39ddf

File tree

3 files changed

+30
-11
lines changed

3 files changed

+30
-11
lines changed

python/unitytrainers/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def create_discrete_observation_encoder(observation_input, s_size, h_size, activ
150150
:param num_layers: number of hidden layers to create.
151151
:return: List of hidden layer tensors.
152152
"""
153-
with tf.name_scope(scope):
153+
with tf.variable_scope(scope):
154154
vector_in = tf.reshape(observation_input, [-1])
155155
state_onehot = tf.one_hot(vector_in, s_size)
156156
hidden = state_onehot

python/unitytrainers/ppo/models.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,18 +91,34 @@ def create_curiosity_encoders(self):
9191
encoded_next_state_list.append(hidden_next_visual)
9292

9393
if self.o_size > 0:
94-
# Create input op for next (t+1) vector observation.
95-
self.next_vector_in = tf.placeholder(shape=[None, self.o_size], dtype=tf.float32,
96-
name='next_vector_observation')
9794

9895
# Create the encoder ops for current and next vector input. Not that these encoders are siamese.
99-
encoded_vector_obs = self.create_continuous_observation_encoder(self.vector_in,
100-
self.curiosity_enc_size,
101-
self.swish, 2, "vector_obs_encoder", False)
102-
encoded_next_vector_obs = self.create_continuous_observation_encoder(self.next_vector_in,
103-
self.curiosity_enc_size,
104-
self.swish, 2, "vector_obs_encoder",
105-
True)
96+
if self.brain.vector_observation_space_type == "continuous":
97+
# Create input op for next (t+1) vector observation.
98+
self.next_vector_in = tf.placeholder(shape=[None, self.o_size], dtype=tf.float32,
99+
name='next_vector_observation')
100+
101+
encoded_vector_obs = self.create_continuous_observation_encoder(self.vector_in,
102+
self.curiosity_enc_size,
103+
self.swish, 2, "vector_obs_encoder",
104+
False)
105+
encoded_next_vector_obs = self.create_continuous_observation_encoder(self.next_vector_in,
106+
self.curiosity_enc_size,
107+
self.swish, 2,
108+
"vector_obs_encoder",
109+
True)
110+
else:
111+
self.next_vector_in = tf.placeholder(shape=[None, 1], dtype=tf.int32,
112+
name='next_vector_observation')
113+
114+
encoded_vector_obs = self.create_discrete_observation_encoder(self.vector_in, self.o_size,
115+
self.curiosity_enc_size,
116+
self.swish, 2, "vector_obs_encoder",
117+
False)
118+
encoded_next_vector_obs = self.create_discrete_observation_encoder(self.next_vector_in, self.o_size,
119+
self.curiosity_enc_size,
120+
self.swish, 2, "vector_obs_encoder",
121+
True)
106122
encoded_state_list.append(encoded_vector_obs)
107123
encoded_next_state_list.append(encoded_next_vector_obs)
108124

python/unitytrainers/ppo/trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,9 @@ def update_model(self):
432432
else:
433433
feed_dict[self.model.vector_in] = np.array(buffer['vector_obs'][start:end]).reshape(
434434
[-1, self.brain.num_stacked_vector_observations])
435+
if self.use_curiosity:
436+
feed_dict[self.model.next_vector_in] = np.array(buffer['next_vector_in'][start:end]) \
437+
.reshape([-1, self.brain.num_stacked_vector_observations])
435438
if self.use_visual_obs:
436439
for i, _ in enumerate(self.model.visual_in):
437440
_obs = np.array(buffer['visual_obs%d' % i][start:end])

0 commit comments

Comments
 (0)