Skip to content

Commit 7165038

Browse files
authored
Update DescrptSeR.py
1 parent edd4155 commit 7165038

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

source/train/DescrptSeR.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
from deepmd.env import tf
3-
from deepmd.common import ClassArg, get_activation_func
3+
from deepmd.common import ClassArg, get_activation_func, get_precision_func
44
from deepmd.RunOptions import global_tf_float_precision
55
from deepmd.RunOptions import global_np_float_precision
66
from deepmd.env import op_module
@@ -18,7 +18,8 @@ def __init__ (self, jdata):
1818
.add('seed', int) \
1919
.add('exclude_types', list, default = []) \
2020
.add('set_davg_zero', bool, default = False) \
21-
.add("activation_function", str, default = "tanh")
21+
.add("activation_function", str, default = "tanh") \
22+
.add("precision", int, default = 0)
2223
class_data = args.parse(jdata)
2324
self.sel_r = class_data['sel']
2425
self.rcut = class_data['rcut']
@@ -27,7 +28,8 @@ def __init__ (self, jdata):
2728
self.filter_resnet_dt = class_data['resnet_dt']
2829
self.seed = class_data['seed']
2930
self.trainable = class_data['trainable']
30-
self.filter_activation_fn = get_activation_func(class_data["activation_function"])
31+
self.filter_activation_fn = get_activation_func(class_data["activation_function"])
32+
self.filter_precision = get_precision_func(class_data['precision'])
3133
exclude_types = class_data['exclude_types']
3234
self.exclude_types = set()
3335
for tt in exclude_types:
@@ -206,7 +208,7 @@ def _pass_filter(self,
206208
[ 0, start_index* self.ndescrpt],
207209
[-1, natoms[2+type_i]* self.ndescrpt] )
208210
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
209-
layer = self._filter_r(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn)
211+
layer = self._filter_r(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn)
210212
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
211213
output.append(layer)
212214
start_index += natoms[2+type_i]
@@ -288,18 +290,18 @@ def _filter_r(self,
288290
for ii in range(1, len(outputs_size)):
289291
w = tf.get_variable('matrix_'+str(ii)+'_'+str(type_i),
290292
[outputs_size[ii - 1], outputs_size[ii]],
291-
global_tf_float_precision,
293+
self.filter_precision,
292294
tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed),
293295
trainable = trainable)
294296
b = tf.get_variable('bias_'+str(ii)+'_'+str(type_i),
295297
[1, outputs_size[ii]],
296-
global_tf_float_precision,
298+
self.filter_precision,
297299
tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed),
298300
trainable = trainable)
299301
if self.filter_resnet_dt :
300302
idt = tf.get_variable('idt_'+str(ii)+'_'+str(type_i),
301303
[1, outputs_size[ii]],
302-
global_tf_float_precision,
304+
self.filter_precision,
303305
tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed),
304306
trainable = trainable)
305307
if outputs_size[ii] == outputs_size[ii-1]:

0 commit comments

Comments
 (0)