11import numpy as np
22from deepmd .env import tf
3- from deepmd .common import ClassArg , get_activation_func , get_precision_func
3+ from deepmd .common import ClassArg , get_activation_func , get_precision
44from deepmd .RunOptions import global_tf_float_precision
55from deepmd .RunOptions import global_np_float_precision
66from deepmd .env import op_module
@@ -20,7 +20,7 @@ def __init__ (self, jdata):
2020 .add ('exclude_types' , list , default = []) \
2121 .add ('set_davg_zero' , bool , default = False ) \
2222 .add ('activation_function' , str , default = 'tanh' ) \
23- .add ('precision' , int , default = 0 )
23+ .add ('precision' , str , default = "default" )
2424 class_data = args .parse (jdata )
2525 self .sel_a = class_data ['sel' ]
2626 self .rcut_r = class_data ['rcut' ]
@@ -31,7 +31,7 @@ def __init__ (self, jdata):
3131 self .seed = class_data ['seed' ]
3232 self .trainable = class_data ['trainable' ]
3333 self .filter_activation_fn = get_activation_func (class_data ['activation_function' ])
34- self .filter_precision = get_precision_func (class_data ['precision' ])
34+ self .filter_precision = get_precision (class_data ['precision' ])
3535 exclude_types = class_data ['exclude_types' ]
3636 self .exclude_types = set ()
3737 for tt in exclude_types :
0 commit comments