11import numpy as np
22from deepmd .env import tf
3- from deepmd .common import ClassArg , get_activation_func
3+ from deepmd .common import ClassArg , get_activation_func , get_precision
44from deepmd .RunOptions import global_tf_float_precision
55from deepmd .RunOptions import global_np_float_precision
66from deepmd .env import op_module
@@ -19,7 +19,8 @@ def __init__ (self, jdata):
1919 .add ('seed' , int ) \
2020 .add ('exclude_types' , list , default = []) \
2121 .add ('set_davg_zero' , bool , default = False ) \
22- .add ('activation_function' , str , default = 'tanh' )
22+ .add ('activation_function' , str , default = 'tanh' ) \
23+ .add ('precision' , str , default = "default" )
2324 class_data = args .parse (jdata )
2425 self .sel_a = class_data ['sel' ]
2526 self .rcut_r = class_data ['rcut' ]
@@ -30,6 +31,7 @@ def __init__ (self, jdata):
3031 self .seed = class_data ['seed' ]
3132 self .trainable = class_data ['trainable' ]
3233 self .filter_activation_fn = get_activation_func (class_data ['activation_function' ])
34+ self .filter_precision = get_precision (class_data ['precision' ])
3335 exclude_types = class_data ['exclude_types' ]
3436 self .exclude_types = set ()
3537 for tt in exclude_types :
@@ -247,7 +249,7 @@ def _pass_filter(self,
247249 [ 0 , start_index * self .ndescrpt ],
248250 [- 1 , natoms [2 + type_i ]* self .ndescrpt ] )
249251 inputs_i = tf .reshape (inputs_i , [- 1 , self .ndescrpt ])
250- layer , qmat = self ._filter (inputs_i , type_i , name = 'filter_type_' + str (type_i )+ suffix , natoms = natoms , reuse = reuse , seed = self .seed , trainable = trainable , activation_fn = self .filter_activation_fn )
252+ layer , qmat = self ._filter (tf . cast ( inputs_i , self . filter_precision ) , type_i , name = 'filter_type_' + str (type_i )+ suffix , natoms = natoms , reuse = reuse , seed = self .seed , trainable = trainable , activation_fn = self .filter_activation_fn )
251253 layer = tf .reshape (layer , [tf .shape (inputs )[0 ], natoms [2 + type_i ] * self .get_dim_out ()])
252254 qmat = tf .reshape (qmat , [tf .shape (inputs )[0 ], natoms [2 + type_i ] * self .get_dim_rot_mat_1 () * 3 ])
253255 output .append (layer )
@@ -343,18 +345,18 @@ def _filter(self,
343345 for ii in range (1 , len (outputs_size )):
344346 w = tf .get_variable ('matrix_' + str (ii )+ '_' + str (type_i ),
345347 [outputs_size [ii - 1 ], outputs_size [ii ]],
346- global_tf_float_precision ,
348+ self . filter_precision ,
347349 tf .random_normal_initializer (stddev = stddev / np .sqrt (outputs_size [ii ]+ outputs_size [ii - 1 ]), seed = seed ),
348350 trainable = trainable )
349351 b = tf .get_variable ('bias_' + str (ii )+ '_' + str (type_i ),
350352 [1 , outputs_size [ii ]],
351- global_tf_float_precision ,
353+ self . filter_precision ,
352354 tf .random_normal_initializer (stddev = stddev , mean = bavg , seed = seed ),
353355 trainable = trainable )
354356 if self .filter_resnet_dt :
355357 idt = tf .get_variable ('idt_' + str (ii )+ '_' + str (type_i ),
356358 [1 , outputs_size [ii ]],
357- global_tf_float_precision ,
359+ self . filter_precision ,
358360 tf .random_normal_initializer (stddev = 0.001 , mean = 1.0 , seed = seed ),
359361 trainable = trainable )
360362 if outputs_size [ii ] == outputs_size [ii - 1 ]:
@@ -430,18 +432,18 @@ def _filter_type_ext(self,
430432 for ii in range (1 , len (outputs_size )):
431433 w = tf .get_variable ('matrix_' + str (ii )+ '_' + str (type_i ),
432434 [outputs_size [ii - 1 ], outputs_size [ii ]],
433- global_tf_float_precision ,
435+ self . filter_precision ,
434436 tf .random_normal_initializer (stddev = stddev / np .sqrt (outputs_size [ii ]+ outputs_size [ii - 1 ]), seed = seed ),
435437 trainable = trainable )
436438 b = tf .get_variable ('bias_' + str (ii )+ '_' + str (type_i ),
437439 [1 , outputs_size [ii ]],
438- global_tf_float_precision ,
440+ self . filter_precision ,
439441 tf .random_normal_initializer (stddev = stddev , mean = bavg , seed = seed ),
440442 trainable = trainable )
441443 if self .filter_resnet_dt :
442444 idt = tf .get_variable ('idt_' + str (ii )+ '_' + str (type_i ),
443445 [1 , outputs_size [ii ]],
444- global_tf_float_precision ,
446+ self . filter_precision ,
445447 tf .random_normal_initializer (stddev = 0.001 , mean = 1.0 , seed = seed ),
446448 trainable = trainable )
447449 if outputs_size [ii ] == outputs_size [ii - 1 ]:
0 commit comments