Skip to content

Commit 4f0f54d

Browse files
authored
Update Fitting.py
1 parent 03bc29a commit 4f0f54d

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

source/train/Fitting.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33

44
from deepmd.env import tf
5-
from deepmd.common import ClassArg, add_data_requirement, get_activation_func, get_precision_func
5+
from deepmd.common import ClassArg, add_data_requirement, get_activation_func, get_precision
66
from deepmd.Network import one_layer
77
from deepmd.DescrptLocFrame import DescrptLocFrame
88
from deepmd.DescrptSeA import DescrptSeA
@@ -23,7 +23,7 @@ def __init__ (self, jdata, descrpt):
2323
.add('seed', int) \
2424
.add('atom_ener', list, default = [])\
2525
.add("activation_function", str, default = "tanh")\
26-
.add("precision", int, default = 0)
26+
.add("precision", str, default = "default")
2727
class_data = args.parse(jdata)
2828
self.numb_fparam = class_data['numb_fparam']
2929
self.numb_aparam = class_data['numb_aparam']
@@ -32,7 +32,7 @@ def __init__ (self, jdata, descrpt):
3232
self.rcond = class_data['rcond']
3333
self.seed = class_data['seed']
3434
self.fitting_activation_fn = get_activation_func(class_data["activation_function"])
35-
self.fitting_precision = get_precision_func(class_data['precision'])
35+
self.fitting_precision = get_precision(class_data['precision'])
3636
self.atom_ener = []
3737
for at, ae in enumerate(class_data['atom_ener']):
3838
if ae is not None:
@@ -249,15 +249,15 @@ def __init__ (self, jdata, descrpt) :
249249
.add('sel_type', [list,int], default = [ii for ii in range(self.ntypes)], alias = 'wfc_type')\
250250
.add('seed', int)\
251251
.add("activation_function", str, default = "tanh")\
252-
.add('precision', int, default = 0)
252+
.add('precision', str, default = "default")
253253
class_data = args.parse(jdata)
254254
self.n_neuron = class_data['neuron']
255255
self.resnet_dt = class_data['resnet_dt']
256256
self.wfc_numb = class_data['wfc_numb']
257257
self.sel_type = class_data['sel_type']
258258
self.seed = class_data['seed']
259259
self.fitting_activation_fn = get_activation_func(class_data["activation_function"])
260-
self.fitting_precision = get_precision_func(class_data['precision'])
260+
self.fitting_precision = get_precision(class_data['precision'])
261261
self.useBN = False
262262

263263

@@ -332,14 +332,14 @@ def __init__ (self, jdata, descrpt) :
332332
.add('sel_type', [list,int], default = [ii for ii in range(self.ntypes)], alias = 'pol_type')\
333333
.add('seed', int)\
334334
.add("activation_function", str, default = "tanh")\
335-
.add('precision', int, default = 0)
335+
.add('precision', str, default = "default")
336336
class_data = args.parse(jdata)
337337
self.n_neuron = class_data['neuron']
338338
self.resnet_dt = class_data['resnet_dt']
339339
self.sel_type = class_data['sel_type']
340340
self.seed = class_data['seed']
341341
self.fitting_activation_fn = get_activation_func(class_data["activation_function"])
342-
self.fitting_precision = get_precision_func(class_data['precision'])
342+
self.fitting_precision = get_precision(class_data['precision'])
343343
self.useBN = False
344344

345345
def get_sel_type(self):
@@ -416,7 +416,7 @@ def __init__ (self, jdata, descrpt) :
416416
.add('sel_type', [list,int], default = [ii for ii in range(self.ntypes)], alias = 'pol_type')\
417417
.add('seed', int)\
418418
.add("activation_function", str , default = "tanh")\
419-
.add('precision', int, default = 0)
419+
.add('precision', str, default = "default")
420420
class_data = args.parse(jdata)
421421
self.n_neuron = class_data['neuron']
422422
self.resnet_dt = class_data['resnet_dt']
@@ -426,7 +426,7 @@ def __init__ (self, jdata, descrpt) :
426426
self.diag_shift = class_data['diag_shift']
427427
self.scale = class_data['scale']
428428
self.fitting_activation_fn = get_activation_func(class_data["activation_function"])
429-
self.fitting_precision = get_precision_func(class_data['precision'])
429+
self.fitting_precision = get_precision(class_data['precision'])
430430
if type(self.sel_type) is not list:
431431
self.sel_type = [self.sel_type]
432432
if type(self.diag_shift) is not list:
@@ -573,14 +573,14 @@ def __init__ (self, jdata, descrpt) :
573573
.add('sel_type', [list,int], default = [ii for ii in range(self.ntypes)], alias = 'dipole_type')\
574574
.add('seed', int)\
575575
.add("activation_function", str, default = "tanh")\
576-
.add('precision', int, default = 0)
576+
.add('precision', str, default = "default")
577577
class_data = args.parse(jdata)
578578
self.n_neuron = class_data['neuron']
579579
self.resnet_dt = class_data['resnet_dt']
580580
self.sel_type = class_data['sel_type']
581581
self.seed = class_data['seed']
582582
self.fitting_activation_fn = get_activation_func(class_data["activation_function"])
583-
self.fitting_precision = get_precision_func(class_data['precision'])
583+
self.fitting_precision = get_precision(class_data['precision'])
584584
self.dim_rot_mat_1 = descrpt.get_dim_rot_mat_1()
585585
self.dim_rot_mat = self.dim_rot_mat_1 * 3
586586
self.useBN = False

0 commit comments

Comments
 (0)