11import numpy as np
22from deepmd .env import tf
3- from deepmd .common import ClassArg , get_activation_func
3+ from deepmd .common import ClassArg , get_activation_func , get_precision_func
44from deepmd .RunOptions import global_tf_float_precision
55from deepmd .RunOptions import global_np_float_precision
66from deepmd .env import op_module
@@ -18,7 +18,8 @@ def __init__ (self, jdata):
1818 .add ('seed' , int ) \
1919 .add ('exclude_types' , list , default = []) \
2020 .add ('set_davg_zero' , bool , default = False ) \
21- .add ("activation_function" , str , default = "tanh" )
21+ .add ("activation_function" , str , default = "tanh" ) \
22+ .add ("precision" , int , default = 0 )
2223 class_data = args .parse (jdata )
2324 self .sel_r = class_data ['sel' ]
2425 self .rcut = class_data ['rcut' ]
@@ -27,7 +28,8 @@ def __init__ (self, jdata):
2728 self .filter_resnet_dt = class_data ['resnet_dt' ]
2829 self .seed = class_data ['seed' ]
2930 self .trainable = class_data ['trainable' ]
30- self .filter_activation_fn = get_activation_func (class_data ["activation_function" ])
31+ self .filter_activation_fn = get_activation_func (class_data ["activation_function" ])
32+ self .filter_precision = get_precision_func (class_data ['precision' ])
3133 exclude_types = class_data ['exclude_types' ]
3234 self .exclude_types = set ()
3335 for tt in exclude_types :
@@ -206,7 +208,7 @@ def _pass_filter(self,
206208 [ 0 , start_index * self .ndescrpt ],
207209 [- 1 , natoms [2 + type_i ]* self .ndescrpt ] )
208210 inputs_i = tf .reshape (inputs_i , [- 1 , self .ndescrpt ])
209- layer = self ._filter_r (inputs_i , type_i , name = 'filter_type_' + str (type_i )+ suffix , natoms = natoms , reuse = reuse , seed = self .seed , trainable = trainable , activation_fn = self .filter_activation_fn )
211+ layer = self ._filter_r (tf . cast ( inputs_i , self . filter_precision ) , type_i , name = 'filter_type_' + str (type_i )+ suffix , natoms = natoms , reuse = reuse , seed = self .seed , trainable = trainable , activation_fn = self .filter_activation_fn )
210212 layer = tf .reshape (layer , [tf .shape (inputs )[0 ], natoms [2 + type_i ] * self .get_dim_out ()])
211213 output .append (layer )
212214 start_index += natoms [2 + type_i ]
@@ -288,18 +290,18 @@ def _filter_r(self,
288290 for ii in range (1 , len (outputs_size )):
289291 w = tf .get_variable ('matrix_' + str (ii )+ '_' + str (type_i ),
290292 [outputs_size [ii - 1 ], outputs_size [ii ]],
291- global_tf_float_precision ,
293+ self . filter_precision ,
292294 tf .random_normal_initializer (stddev = stddev / np .sqrt (outputs_size [ii ]+ outputs_size [ii - 1 ]), seed = seed ),
293295 trainable = trainable )
294296 b = tf .get_variable ('bias_' + str (ii )+ '_' + str (type_i ),
295297 [1 , outputs_size [ii ]],
296- global_tf_float_precision ,
298+ self . filter_precision ,
297299 tf .random_normal_initializer (stddev = stddev , mean = bavg , seed = seed ),
298300 trainable = trainable )
299301 if self .filter_resnet_dt :
300302 idt = tf .get_variable ('idt_' + str (ii )+ '_' + str (type_i ),
301303 [1 , outputs_size [ii ]],
302- global_tf_float_precision ,
304+ self . filter_precision ,
303305 tf .random_normal_initializer (stddev = 0.001 , mean = 1.0 , seed = seed ),
304306 trainable = trainable )
305307 if outputs_size [ii ] == outputs_size [ii - 1 ]:
0 commit comments