Skip to content

Commit 0c4ea37

Browse files
authored
Update DescrptSeA.py
1 parent 7165038 commit 0c4ea37

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

source/train/DescrptSeA.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def _pass_filter(self,
249249
[ 0, start_index* self.ndescrpt],
250250
[-1, natoms[2+type_i]* self.ndescrpt] )
251251
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
252-
layer, qmat = self._filter(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn)
252+
layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn)
253253
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
254254
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3])
255255
output.append(layer)
@@ -324,7 +324,6 @@ def _filter(self,
324324
seed=None,
325325
trainable = True):
326326
# natom x (nei x 4)
327-
inputs = tf.cast(inputs, self.filter_precision)
328327
shape = inputs.get_shape().as_list()
329328
outputs_size = [1] + self.filter_neuron
330329
outputs_size_2 = self.n_axis_neuron
@@ -412,7 +411,6 @@ def _filter_type_ext(self,
412411
seed=None,
413412
trainable = True):
414413
# natom x (nei x 4)
415-
inputs = tf.cast(inputs, self.filter_precision)
416414
outputs_size = [1] + self.filter_neuron
417415
outputs_size_2 = self.n_axis_neuron
418416
with tf.variable_scope(name, reuse=reuse):

0 commit comments

Comments
 (0)