@@ -20,36 +20,43 @@ def __init__ (self, jdata, descrpt):
2020 self .dim_descrpt = descrpt .get_dim_out ()
2121 args = ClassArg ()\
2222 .add ('numb_fparam' , int , default = 0 )\
23+ .add ('numb_aparam' , int , default = 0 )\
2324 .add ('neuron' , list , default = [120 ,120 ,120 ], alias = 'n_neuron' )\
2425 .add ('resnet_dt' , bool , default = True )\
2526 .add ('seed' , int )
2627 class_data = args .parse (jdata )
2728 self .numb_fparam = class_data ['numb_fparam' ]
29+ self .numb_aparam = class_data ['numb_aparam' ]
2830 self .n_neuron = class_data ['neuron' ]
2931 self .resnet_dt = class_data ['resnet_dt' ]
3032 self .seed = class_data ['seed' ]
3133 self .useBN = False
3234 # data requirement
3335 if self .numb_fparam > 0 :
34- add_data_requirement ('fparam' , self .numb_fparam , atomic = False , must = False , high_prec = False )
35- self .avg = None
36- self .std = None
37- self .inv_std = None
36+ add_data_requirement ('fparam' , self .numb_fparam , atomic = False , must = True , high_prec = False )
37+ self .fparam_avg = None
38+ self .fparam_std = None
39+ self .fparam_inv_std = None
40+ if self .numb_aparam > 0 :
41+ add_data_requirement ('aparam' , self .numb_aparam , atomic = True , must = True , high_prec = False )
42+ self .aparam_avg = None
43+ self .aparam_std = None
44+ self .aparam_inv_std = None
45+
3846
3947 def get_numb_fparam (self ) :
4048 return self .numb_fparam
4149
4250 def compute_dstats (self , all_stat , protection ):
4351 # stat fparam
4452 if self .numb_fparam > 0 :
45- stat = np .zeros ([self .numb_fparam ])
4653 cat_data = np .concatenate (all_stat ['fparam' ], axis = 0 )
47- self .avg = np .average (cat_data , axis = 0 )
48- self .std = np .std (cat_data , axis = 0 )
49- for ii in range (self .std .size ):
50- if self .std [ii ] < protection :
51- self .std [ii ] = protection
52- self .inv_std = 1. / self .std
54+ self .fparam_avg = np .average (cat_data , axis = 0 )
55+ self .fparam_std = np .std (cat_data , axis = 0 )
56+ for ii in range (self .fparam_std .size ):
57+ if self .fparam_std [ii ] < protection :
58+ self .fparam_std [ii ] = protection
59+ self .fparam_inv_std = 1. / self .fparam_std
5360
5461 def build (self ,
5562 inputs ,
@@ -58,7 +65,7 @@ def build (self,
5865 bias_atom_e = None ,
5966 reuse = None ,
6067 suffix = '' ) :
61- if self .numb_fparam > 0 and ( self .avg is None or self .inv_std is None ):
68+ if self .numb_fparam > 0 and ( self .fparam_avg is None or self .fparam_inv_std is None ):
6269 raise RuntimeError ('No data stat result. one should do data statisitic, before build' )
6370
6471 with tf .variable_scope ('fitting_attr' + suffix , reuse = reuse ) :
@@ -70,12 +77,12 @@ def build (self,
7077 self .numb_fparam ,
7178 dtype = global_tf_float_precision ,
7279 trainable = False ,
73- initializer = tf .constant_initializer (self .avg ))
80+ initializer = tf .constant_initializer (self .fparam_avg ))
7481 t_fparam_istd = tf .get_variable ('t_fparam_istd' ,
7582 self .numb_fparam ,
7683 dtype = global_tf_float_precision ,
7784 trainable = False ,
78- initializer = tf .constant_initializer (self .inv_std ))
85+ initializer = tf .constant_initializer (self .fparam_inv_std ))
7986
8087 start_index = 0
8188 inputs = tf .reshape (inputs , [- 1 , self .dim_descrpt * natoms [0 ]])
0 commit comments