11import numpy as np
22from deepmd .env import tf
3- from deepmd .common import ClassArg , get_activation_func , get_precision_func
3+ from deepmd .common import ClassArg , get_activation_func , get_precision
44from deepmd .RunOptions import global_tf_float_precision
55from deepmd .RunOptions import global_np_float_precision
66from deepmd .env import op_module
@@ -19,7 +19,7 @@ def __init__ (self, jdata):
1919 .add ('exclude_types' , list , default = []) \
2020 .add ('set_davg_zero' , bool , default = False ) \
2121 .add ("activation_function" , str , default = "tanh" ) \
22- .add ("precision" , int , default = 0 )
22+ .add ("precision" , str , default = "default" )
2323 class_data = args .parse (jdata )
2424 self .sel_r = class_data ['sel' ]
2525 self .rcut = class_data ['rcut' ]
@@ -29,7 +29,7 @@ def __init__ (self, jdata):
2929 self .seed = class_data ['seed' ]
3030 self .trainable = class_data ['trainable' ]
3131 self .filter_activation_fn = get_activation_func (class_data ["activation_function" ])
32- self .filter_precision = get_precision_func (class_data ['precision' ])
32+ self .filter_precision = get_precision (class_data ['precision' ])
3333 exclude_types = class_data ['exclude_types' ]
3434 self .exclude_types = set ()
3535 for tt in exclude_types :
0 commit comments