@@ -91,18 +91,34 @@ def create_curiosity_encoders(self):
9191 encoded_next_state_list .append (hidden_next_visual )
9292
9393 if self .o_size > 0 :
94- # Create input op for next (t+1) vector observation.
95- self .next_vector_in = tf .placeholder (shape = [None , self .o_size ], dtype = tf .float32 ,
96- name = 'next_vector_observation' )
9794
9895 # Create the encoder ops for current and next vector input. Not that these encoders are siamese.
99- encoded_vector_obs = self .create_continuous_observation_encoder (self .vector_in ,
100- self .curiosity_enc_size ,
101- self .swish , 2 , "vector_obs_encoder" , False )
102- encoded_next_vector_obs = self .create_continuous_observation_encoder (self .next_vector_in ,
103- self .curiosity_enc_size ,
104- self .swish , 2 , "vector_obs_encoder" ,
105- True )
96+ if self .brain .vector_observation_space_type == "continuous" :
97+ # Create input op for next (t+1) vector observation.
98+ self .next_vector_in = tf .placeholder (shape = [None , self .o_size ], dtype = tf .float32 ,
99+ name = 'next_vector_observation' )
100+
101+ encoded_vector_obs = self .create_continuous_observation_encoder (self .vector_in ,
102+ self .curiosity_enc_size ,
103+ self .swish , 2 , "vector_obs_encoder" ,
104+ False )
105+ encoded_next_vector_obs = self .create_continuous_observation_encoder (self .next_vector_in ,
106+ self .curiosity_enc_size ,
107+ self .swish , 2 ,
108+ "vector_obs_encoder" ,
109+ True )
110+ else :
111+ self .next_vector_in = tf .placeholder (shape = [None , 1 ], dtype = tf .int32 ,
112+ name = 'next_vector_observation' )
113+
114+ encoded_vector_obs = self .create_discrete_observation_encoder (self .vector_in , self .o_size ,
115+ self .curiosity_enc_size ,
116+ self .swish , 2 , "vector_obs_encoder" ,
117+ False )
118+ encoded_next_vector_obs = self .create_discrete_observation_encoder (self .next_vector_in , self .o_size ,
119+ self .curiosity_enc_size ,
120+ self .swish , 2 , "vector_obs_encoder" ,
121+ True )
106122 encoded_state_list .append (encoded_vector_obs )
107123 encoded_next_state_list .append (encoded_next_vector_obs )
108124
0 commit comments