11import numpy as np
22from deepmd .env import tf
3- from deepmd .common import ClassArg , get_activation_func
3+ from deepmd .common import ClassArg , get_activation_func , get_precision_func
44from deepmd .RunOptions import global_tf_float_precision
55from deepmd .RunOptions import global_np_float_precision
66from deepmd .env import op_module
@@ -19,7 +19,8 @@ def __init__ (self, jdata):
1919 .add ('seed' , int ) \
2020 .add ('exclude_types' , list , default = []) \
2121 .add ('set_davg_zero' , bool , default = False ) \
22- .add ('activation_function' , str , default = 'tanh' )
22+ .add ('activation_function' , str , default = 'tanh' ) \
23+ .add ('precision' , int , default = 0 )
2324 class_data = args .parse (jdata )
2425 self .sel_a = class_data ['sel' ]
2526 self .rcut_r = class_data ['rcut' ]
@@ -30,6 +31,7 @@ def __init__ (self, jdata):
3031 self .seed = class_data ['seed' ]
3132 self .trainable = class_data ['trainable' ]
3233 self .filter_activation_fn = get_activation_func (class_data ['activation_function' ])
34+ self .filter_precision = get_precision_func (class_data ['precision' ])
3335 exclude_types = class_data ['exclude_types' ]
3436 self .exclude_types = set ()
3537 for tt in exclude_types :
@@ -322,6 +324,7 @@ def _filter(self,
322324 seed = None ,
323325 trainable = True ):
324326 # natom x (nei x 4)
327+ inputs = tf .cast (inputs , self .filter_precision )
325328 shape = inputs .get_shape ().as_list ()
326329 outputs_size = [1 ] + self .filter_neuron
327330 outputs_size_2 = self .n_axis_neuron
@@ -343,18 +346,18 @@ def _filter(self,
343346 for ii in range (1 , len (outputs_size )):
344347 w = tf .get_variable ('matrix_' + str (ii )+ '_' + str (type_i ),
345348 [outputs_size [ii - 1 ], outputs_size [ii ]],
346- global_tf_float_precision ,
349+ self . filter_precision ,
347350 tf .random_normal_initializer (stddev = stddev / np .sqrt (outputs_size [ii ]+ outputs_size [ii - 1 ]), seed = seed ),
348351 trainable = trainable )
349352 b = tf .get_variable ('bias_' + str (ii )+ '_' + str (type_i ),
350353 [1 , outputs_size [ii ]],
351- global_tf_float_precision ,
354+ self . filter_precision ,
352355 tf .random_normal_initializer (stddev = stddev , mean = bavg , seed = seed ),
353356 trainable = trainable )
354357 if self .filter_resnet_dt :
355358 idt = tf .get_variable ('idt_' + str (ii )+ '_' + str (type_i ),
356359 [1 , outputs_size [ii ]],
357- global_tf_float_precision ,
360+ self . filter_precision ,
358361 tf .random_normal_initializer (stddev = 0.001 , mean = 1.0 , seed = seed ),
359362 trainable = trainable )
360363 if outputs_size [ii ] == outputs_size [ii - 1 ]:
@@ -409,6 +412,7 @@ def _filter_type_ext(self,
409412 seed = None ,
410413 trainable = True ):
411414 # natom x (nei x 4)
415+ inputs = tf .cast (inputs , self .filter_precision )
412416 outputs_size = [1 ] + self .filter_neuron
413417 outputs_size_2 = self .n_axis_neuron
414418 with tf .variable_scope (name , reuse = reuse ):
@@ -430,18 +434,18 @@ def _filter_type_ext(self,
430434 for ii in range (1 , len (outputs_size )):
431435 w = tf .get_variable ('matrix_' + str (ii )+ '_' + str (type_i ),
432436 [outputs_size [ii - 1 ], outputs_size [ii ]],
433- global_tf_float_precision ,
437+ self . filter_precision ,
434438 tf .random_normal_initializer (stddev = stddev / np .sqrt (outputs_size [ii ]+ outputs_size [ii - 1 ]), seed = seed ),
435439 trainable = trainable )
436440 b = tf .get_variable ('bias_' + str (ii )+ '_' + str (type_i ),
437441 [1 , outputs_size [ii ]],
438- global_tf_float_precision ,
442+ self . filter_precision ,
439443 tf .random_normal_initializer (stddev = stddev , mean = bavg , seed = seed ),
440444 trainable = trainable )
441445 if self .filter_resnet_dt :
442446 idt = tf .get_variable ('idt_' + str (ii )+ '_' + str (type_i ),
443447 [1 , outputs_size [ii ]],
444- global_tf_float_precision ,
448+ self . filter_precision ,
445449 tf .random_normal_initializer (stddev = 0.001 , mean = 1.0 , seed = seed ),
446450 trainable = trainable )
447451 if outputs_size [ii ] == outputs_size [ii - 1 ]:
0 commit comments