@@ -23,7 +23,8 @@ def __init__(self, jdata, descrpt):
                .add('seed', int) \
                .add('atom_ener', list, default = [])\
                .add("activation_function", str, default = "tanh")\
-               .add("precision", str, default = "default")
+               .add("precision", str, default = "default")\
+               .add("trainable", [list, bool], default = True)
         class_data = args.parse(jdata)
         self.numb_fparam = class_data['numb_fparam']
         self.numb_aparam = class_data['numb_aparam']
@@ -32,7 +33,11 @@ def __init__(self, jdata, descrpt):
         self.rcond = class_data['rcond']
         self.seed = class_data['seed']
         self.fitting_activation_fn = get_activation_func(class_data["activation_function"])
-        self.fitting_precision = get_precision(class_data['precision'])
+        self.fitting_precision = get_precision(class_data['precision'])
+        self.trainable = class_data['trainable']
+        if type(self.trainable) is bool:
+            self.trainable = [self.trainable] * (len(self.n_neuron) + 1)
+        assert(len(self.trainable) == len(self.n_neuron) + 1), 'length of trainable should be that of n_neuron + 1'
         self.atom_ener = []
         for at, ae in enumerate(class_data['atom_ener']):
             if ae is not None:
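The new "trainable" key accepts either a single bool (applied to every layer) or a per-layer list whose length must equal len(n_neuron) + 1: one flag per hidden layer plus one for the final output layer. A minimal sketch of the normalization performed above; the dict shape of jdata here is a hypothetical illustration, as only the keys and the assert message are confirmed by the diff:

    # Hypothetical fitting-net config; the surrounding input format is assumed.
    jdata = {
        "neuron": [240, 240, 240],                 # three hidden layers
        "trainable": [False, False, False, True],  # freeze hidden layers, train only the output
    }

    n_neuron = jdata["neuron"]
    trainable = jdata.get("trainable", True)
    if type(trainable) is bool:
        # A bare bool is broadcast to every layer, including the final one.
        trainable = [trainable] * (len(n_neuron) + 1)
    assert len(trainable) == len(n_neuron) + 1, \
        'length of trainable should be that of n_neuron + 1'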
@@ -205,10 +210,10 @@ def build(self,

         for ii in range(0, len(self.n_neuron)):
             if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1]:
-                layer += one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed=self.seed, use_timestep=self.resnet_dt, activation_fn=self.fitting_activation_fn, precision=self.fitting_precision)
+                layer += one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed=self.seed, use_timestep=self.resnet_dt, activation_fn=self.fitting_activation_fn, precision=self.fitting_precision, trainable=self.trainable[ii])
             else:
-                layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed=self.seed, precision=self.fitting_precision)
-        final_layer = one_layer(layer, 1, activation_fn=None, bavg=type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed=self.seed, precision=self.fitting_precision)
+                layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed=self.seed, precision=self.fitting_precision, trainable=self.trainable[ii])
+        final_layer = one_layer(layer, 1, activation_fn=None, bavg=type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed=self.seed, precision=self.fitting_precision, trainable=self.trainable[-1])

         if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None:
             inputs_zero = tf.zeros_like(inputs_i, dtype=global_tf_float_precision)
@@ -219,10 +224,10 @@ def build(self,
             layer = tf.concat([layer, ext_aparam], axis=1)
             for ii in range(0, len(self.n_neuron)):
                 if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1]:
-                    layer += one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=True, seed=self.seed, use_timestep=self.resnet_dt, activation_fn=self.fitting_activation_fn, precision=self.fitting_precision)
+                    layer += one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=True, seed=self.seed, use_timestep=self.resnet_dt, activation_fn=self.fitting_activation_fn, precision=self.fitting_precision, trainable=self.trainable[ii])
                 else:
-                    layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=True, seed=self.seed, activation_fn=self.fitting_activation_fn, precision=self.fitting_precision)
-            zero_layer = one_layer(layer, 1, activation_fn=None, bavg=type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=True, seed=self.seed, precision=self.fitting_precision)
+                    layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=True, seed=self.seed, activation_fn=self.fitting_activation_fn, precision=self.fitting_precision, trainable=self.trainable[ii])
+            zero_layer = one_layer(layer, 1, activation_fn=None, bavg=type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=True, seed=self.seed, precision=self.fitting_precision, trainable=self.trainable[-1])
             final_layer += self.atom_ener[type_i] - zero_layer

             final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i]])
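The trainable flags are threaded through every one_layer call, including the reused zero-point pass for atom_ener, so frozen layers stay frozen in both branches. one_layer itself is not shown in this diff; a minimal sketch of the presumed mechanism, assuming the TF1-style variable API, where trainable is forwarded to tf.get_variable so frozen weights are excluded from TRAINABLE_VARIABLES and thus from optimizer updates:

    import tensorflow as tf

    def one_layer_sketch(inputs, outputs_size, name, reuse=None, trainable=True):
        # Hypothetical reduction of one_layer: only the trainable plumbing is
        # shown here, not the real initialization, precision, or resnet
        # timestep handling.
        with tf.variable_scope(name, reuse=reuse):
            w = tf.get_variable('matrix', [inputs.get_shape()[1], outputs_size],
                                trainable=trainable)
            b = tf.get_variable('bias', [outputs_size], trainable=trainable)
            return tf.matmul(inputs, w) + b

With trainable=False the variables are still created and restored from checkpoints as usual; they simply receive no gradient updates, which is what makes per-layer freezing useful when fine-tuning a pretrained fitting net.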