@@ -249,7 +249,7 @@ def _pass_filter(self,
249249 [ 0 , start_index * self .ndescrpt ],
250250 [- 1 , natoms [2 + type_i ]* self .ndescrpt ] )
251251 inputs_i = tf .reshape (inputs_i , [- 1 , self .ndescrpt ])
252- layer , qmat = self ._filter (inputs_i , type_i , name = 'filter_type_' + str (type_i )+ suffix , natoms = natoms , reuse = reuse , seed = self .seed , trainable = trainable , activation_fn = self .filter_activation_fn )
252+ layer , qmat = self ._filter (tf . cast ( inputs_i , self . filter_precision ) , type_i , name = 'filter_type_' + str (type_i )+ suffix , natoms = natoms , reuse = reuse , seed = self .seed , trainable = trainable , activation_fn = self .filter_activation_fn )
253253 layer = tf .reshape (layer , [tf .shape (inputs )[0 ], natoms [2 + type_i ] * self .get_dim_out ()])
254254 qmat = tf .reshape (qmat , [tf .shape (inputs )[0 ], natoms [2 + type_i ] * self .get_dim_rot_mat_1 () * 3 ])
255255 output .append (layer )
@@ -324,7 +324,6 @@ def _filter(self,
324324 seed = None ,
325325 trainable = True ):
326326 # natom x (nei x 4)
327- inputs = tf .cast (inputs , self .filter_precision )
328327 shape = inputs .get_shape ().as_list ()
329328 outputs_size = [1 ] + self .filter_neuron
330329 outputs_size_2 = self .n_axis_neuron
@@ -412,7 +411,6 @@ def _filter_type_ext(self,
412411 seed = None ,
413412 trainable = True ):
414413 # natom x (nei x 4)
415- inputs = tf .cast (inputs , self .filter_precision )
416414 outputs_size = [1 ] + self .filter_neuron
417415 outputs_size_2 = self .n_axis_neuron
418416 with tf .variable_scope (name , reuse = reuse ):
0 commit comments