Skip to content

Commit edd4155

Browse files
authored
Update DescrptSeA.py
1 parent 1dae1bc commit edd4155

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

source/train/DescrptSeA.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
from deepmd.env import tf
3-
from deepmd.common import ClassArg, get_activation_func
3+
from deepmd.common import ClassArg, get_activation_funcget_precision_func
44
from deepmd.RunOptions import global_tf_float_precision
55
from deepmd.RunOptions import global_np_float_precision
66
from deepmd.env import op_module
@@ -19,7 +19,8 @@ def __init__ (self, jdata):
1919
.add('seed', int) \
2020
.add('exclude_types', list, default = []) \
2121
.add('set_davg_zero', bool, default = False) \
22-
.add('activation_function', str, default = 'tanh')
22+
.add('activation_function', str, default = 'tanh') \
23+
.add('precision', int, default = 0)
2324
class_data = args.parse(jdata)
2425
self.sel_a = class_data['sel']
2526
self.rcut_r = class_data['rcut']
@@ -30,6 +31,7 @@ def __init__ (self, jdata):
3031
self.seed = class_data['seed']
3132
self.trainable = class_data['trainable']
3233
self.filter_activation_fn = get_activation_func(class_data['activation_function'])
34+
self.filter_precision = get_precision_func(class_data['precision'])
3335
exclude_types = class_data['exclude_types']
3436
self.exclude_types = set()
3537
for tt in exclude_types:
@@ -322,6 +324,7 @@ def _filter(self,
322324
seed=None,
323325
trainable = True):
324326
# natom x (nei x 4)
327+
inputs = tf.cast(inputs, self.filter_precision)
325328
shape = inputs.get_shape().as_list()
326329
outputs_size = [1] + self.filter_neuron
327330
outputs_size_2 = self.n_axis_neuron
@@ -343,18 +346,18 @@ def _filter(self,
343346
for ii in range(1, len(outputs_size)):
344347
w = tf.get_variable('matrix_'+str(ii)+'_'+str(type_i),
345348
[outputs_size[ii - 1], outputs_size[ii]],
346-
global_tf_float_precision,
349+
self.filter_precision,
347350
tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed),
348351
trainable = trainable)
349352
b = tf.get_variable('bias_'+str(ii)+'_'+str(type_i),
350353
[1, outputs_size[ii]],
351-
global_tf_float_precision,
354+
self.filter_precision,
352355
tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed),
353356
trainable = trainable)
354357
if self.filter_resnet_dt :
355358
idt = tf.get_variable('idt_'+str(ii)+'_'+str(type_i),
356359
[1, outputs_size[ii]],
357-
global_tf_float_precision,
360+
self.filter_precision,
358361
tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed),
359362
trainable = trainable)
360363
if outputs_size[ii] == outputs_size[ii-1]:
@@ -409,6 +412,7 @@ def _filter_type_ext(self,
409412
seed=None,
410413
trainable = True):
411414
# natom x (nei x 4)
415+
inputs = tf.cast(inputs, self.filter_precision)
412416
outputs_size = [1] + self.filter_neuron
413417
outputs_size_2 = self.n_axis_neuron
414418
with tf.variable_scope(name, reuse=reuse):
@@ -430,18 +434,18 @@ def _filter_type_ext(self,
430434
for ii in range(1, len(outputs_size)):
431435
w = tf.get_variable('matrix_'+str(ii)+'_'+str(type_i),
432436
[outputs_size[ii - 1], outputs_size[ii]],
433-
global_tf_float_precision,
437+
self.filter_precision,
434438
tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed),
435439
trainable = trainable)
436440
b = tf.get_variable('bias_'+str(ii)+'_'+str(type_i),
437441
[1, outputs_size[ii]],
438-
global_tf_float_precision,
442+
self.filter_precision,
439443
tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed),
440444
trainable = trainable)
441445
if self.filter_resnet_dt :
442446
idt = tf.get_variable('idt_'+str(ii)+'_'+str(type_i),
443447
[1, outputs_size[ii]],
444-
global_tf_float_precision,
448+
self.filter_precision,
445449
tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed),
446450
trainable = trainable)
447451
if outputs_size[ii] == outputs_size[ii-1]:

0 commit comments

Comments
 (0)