22import numpy as np
33
44from deepmd .env import tf
5- from deepmd .common import ClassArg , add_data_requirement , get_activation_func , get_precision_func
5+ from deepmd .common import ClassArg , add_data_requirement , get_activation_func , get_precision
66from deepmd .Network import one_layer
77from deepmd .DescrptLocFrame import DescrptLocFrame
88from deepmd .DescrptSeA import DescrptSeA
@@ -23,7 +23,7 @@ def __init__ (self, jdata, descrpt):
2323 .add ('seed' , int ) \
2424 .add ('atom_ener' , list , default = [])\
2525 .add ("activation_function" , str , default = "tanh" )\
26- .add ("precision" , int , default = 0 )
26+ .add ("precision" , str , default = "default" )
2727 class_data = args .parse (jdata )
2828 self .numb_fparam = class_data ['numb_fparam' ]
2929 self .numb_aparam = class_data ['numb_aparam' ]
@@ -32,7 +32,7 @@ def __init__ (self, jdata, descrpt):
3232 self .rcond = class_data ['rcond' ]
3333 self .seed = class_data ['seed' ]
3434 self .fitting_activation_fn = get_activation_func (class_data ["activation_function" ])
35- self .fitting_precision = get_precision_func (class_data ['precision' ])
35+ self .fitting_precision = get_precision (class_data ['precision' ])
3636 self .atom_ener = []
3737 for at , ae in enumerate (class_data ['atom_ener' ]):
3838 if ae is not None :
@@ -249,15 +249,15 @@ def __init__ (self, jdata, descrpt) :
249249 .add ('sel_type' , [list ,int ], default = [ii for ii in range (self .ntypes )], alias = 'wfc_type' )\
250250 .add ('seed' , int )\
251251 .add ("activation_function" , str , default = "tanh" )\
252- .add ('precision' , int , default = 0 )
252+ .add ('precision' , str , default = "default" )
253253 class_data = args .parse (jdata )
254254 self .n_neuron = class_data ['neuron' ]
255255 self .resnet_dt = class_data ['resnet_dt' ]
256256 self .wfc_numb = class_data ['wfc_numb' ]
257257 self .sel_type = class_data ['sel_type' ]
258258 self .seed = class_data ['seed' ]
259259 self .fitting_activation_fn = get_activation_func (class_data ["activation_function" ])
260- self .fitting_precision = get_precision_func (class_data ['precision' ])
260+ self .fitting_precision = get_precision (class_data ['precision' ])
261261 self .useBN = False
262262
263263
@@ -332,14 +332,14 @@ def __init__ (self, jdata, descrpt) :
332332 .add ('sel_type' , [list ,int ], default = [ii for ii in range (self .ntypes )], alias = 'pol_type' )\
333333 .add ('seed' , int )\
334334 .add ("activation_function" , str , default = "tanh" )\
335- .add ('precision' , int , default = 0 )
335+ .add ('precision' , str , default = "default" )
336336 class_data = args .parse (jdata )
337337 self .n_neuron = class_data ['neuron' ]
338338 self .resnet_dt = class_data ['resnet_dt' ]
339339 self .sel_type = class_data ['sel_type' ]
340340 self .seed = class_data ['seed' ]
341341 self .fitting_activation_fn = get_activation_func (class_data ["activation_function" ])
342- self .fitting_precision = get_precision_func (class_data ['precision' ])
342+ self .fitting_precision = get_precision (class_data ['precision' ])
343343 self .useBN = False
344344
345345 def get_sel_type (self ):
@@ -416,7 +416,7 @@ def __init__ (self, jdata, descrpt) :
416416 .add ('sel_type' , [list ,int ], default = [ii for ii in range (self .ntypes )], alias = 'pol_type' )\
417417 .add ('seed' , int )\
418418 .add ("activation_function" , str , default = "tanh" )\
419- .add ('precision' , int , default = 0 )
419+ .add ('precision' , str , default = "default" )
420420 class_data = args .parse (jdata )
421421 self .n_neuron = class_data ['neuron' ]
422422 self .resnet_dt = class_data ['resnet_dt' ]
@@ -426,7 +426,7 @@ def __init__ (self, jdata, descrpt) :
426426 self .diag_shift = class_data ['diag_shift' ]
427427 self .scale = class_data ['scale' ]
428428 self .fitting_activation_fn = get_activation_func (class_data ["activation_function" ])
429- self .fitting_precision = get_precision_func (class_data ['precision' ])
429+ self .fitting_precision = get_precision (class_data ['precision' ])
430430 if type (self .sel_type ) is not list :
431431 self .sel_type = [self .sel_type ]
432432 if type (self .diag_shift ) is not list :
@@ -573,14 +573,14 @@ def __init__ (self, jdata, descrpt) :
573573 .add ('sel_type' , [list ,int ], default = [ii for ii in range (self .ntypes )], alias = 'dipole_type' )\
574574 .add ('seed' , int )\
575575 .add ("activation_function" , str , default = "tanh" )\
576- .add ('precision' , int , default = 0 )
576+ .add ('precision' , str , default = "default" )
577577 class_data = args .parse (jdata )
578578 self .n_neuron = class_data ['neuron' ]
579579 self .resnet_dt = class_data ['resnet_dt' ]
580580 self .sel_type = class_data ['sel_type' ]
581581 self .seed = class_data ['seed' ]
582582 self .fitting_activation_fn = get_activation_func (class_data ["activation_function" ])
583- self .fitting_precision = get_precision_func (class_data ['precision' ])
583+ self .fitting_precision = get_precision (class_data ['precision' ])
584584 self .dim_rot_mat_1 = descrpt .get_dim_rot_mat_1 ()
585585 self .dim_rot_mat = self .dim_rot_mat_1 * 3
586586 self .useBN = False
0 commit comments