Skip to content

Commit 18600f0

Browse files
author
Han Wang
committed
rename avg -> fparam_avg, inv_std -> fparam_inv_std
1 parent 990ce1b commit 18600f0

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

source/train/Fitting.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,36 +20,43 @@ def __init__ (self, jdata, descrpt):
2020
self.dim_descrpt = descrpt.get_dim_out()
2121
args = ClassArg()\
2222
.add('numb_fparam', int, default = 0)\
23+
.add('numb_aparam', int, default = 0)\
2324
.add('neuron', list, default = [120,120,120], alias = 'n_neuron')\
2425
.add('resnet_dt', bool, default = True)\
2526
.add('seed', int)
2627
class_data = args.parse(jdata)
2728
self.numb_fparam = class_data['numb_fparam']
29+
self.numb_aparam = class_data['numb_aparam']
2830
self.n_neuron = class_data['neuron']
2931
self.resnet_dt = class_data['resnet_dt']
3032
self.seed = class_data['seed']
3133
self.useBN = False
3234
# data requirement
3335
if self.numb_fparam > 0 :
34-
add_data_requirement('fparam', self.numb_fparam, atomic=False, must=False, high_prec=False)
35-
self.avg = None
36-
self.std = None
37-
self.inv_std = None
36+
add_data_requirement('fparam', self.numb_fparam, atomic=False, must=True, high_prec=False)
37+
self.fparam_avg = None
38+
self.fparam_std = None
39+
self.fparam_inv_std = None
40+
if self.numb_aparam > 0:
41+
add_data_requirement('aparam', self.numb_aparam, atomic=True, must=True, high_prec=False)
42+
self.aparam_avg = None
43+
self.aparam_std = None
44+
self.aparam_inv_std = None
45+
3846

3947
def get_numb_fparam(self) :
4048
return self.numb_fparam
4149

4250
def compute_dstats(self, all_stat, protection):
4351
# stat fparam
4452
if self.numb_fparam > 0:
45-
stat = np.zeros([self.numb_fparam])
4653
cat_data = np.concatenate(all_stat['fparam'], axis = 0)
47-
self.avg = np.average(cat_data, axis = 0)
48-
self.std = np.std(cat_data, axis = 0)
49-
for ii in range(self.std.size):
50-
if self.std[ii] < protection:
51-
self.std[ii] = protection
52-
self.inv_std = 1./self.std
54+
self.fparam_avg = np.average(cat_data, axis = 0)
55+
self.fparam_std = np.std(cat_data, axis = 0)
56+
for ii in range(self.fparam_std.size):
57+
if self.fparam_std[ii] < protection:
58+
self.fparam_std[ii] = protection
59+
self.fparam_inv_std = 1./self.fparam_std
5360

5461
def build (self,
5562
inputs,
@@ -58,7 +65,7 @@ def build (self,
5865
bias_atom_e = None,
5966
reuse = None,
6067
suffix = '') :
61-
if self.numb_fparam > 0 and ( self.avg is None or self.inv_std is None ):
68+
if self.numb_fparam > 0 and ( self.fparam_avg is None or self.fparam_inv_std is None ):
6269
raise RuntimeError('No data stat result. one should do data statisitic, before build')
6370

6471
with tf.variable_scope('fitting_attr' + suffix, reuse = reuse) :
@@ -70,12 +77,12 @@ def build (self,
7077
self.numb_fparam,
7178
dtype = global_tf_float_precision,
7279
trainable = False,
73-
initializer = tf.constant_initializer(self.avg))
80+
initializer = tf.constant_initializer(self.fparam_avg))
7481
t_fparam_istd = tf.get_variable('t_fparam_istd',
7582
self.numb_fparam,
7683
dtype = global_tf_float_precision,
7784
trainable = False,
78-
initializer = tf.constant_initializer(self.inv_std))
85+
initializer = tf.constant_initializer(self.fparam_inv_std))
7986

8087
start_index = 0
8188
inputs = tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]])

0 commit comments

Comments
 (0)