
Commit ba3ed02

Author: Han Wang · committed

set trainable of fitting layers

1 parent 7918287 · commit ba3ed02

File tree

2 files changed (+20 −11 lines)

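The commit introduces a `trainable` option for the fitting network: either a single bool applied to every layer, or one bool per layer (hidden layers first, then the final one-output layer). As a rough sketch of how the option might look in the fitting-net section of the training input — the `neuron` key name and the layer sizes here are illustrative assumptions, not taken from this diff:

# Form 1: a single bool freezes (or trains) every fitting layer at once.
fitting_net = {"neuron": [240, 240, 240], "trainable": False}

# Form 2: one flag per layer, length len(neuron) + 1; the hidden layers
# come first and the final one-output layer takes the last slot.
fitting_net = {"neuron": [240, 240, 240], "trainable": [False, False, False, True]}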

source/train/Fitting.py

Lines changed: 13 additions & 8 deletions
@@ -23,7 +23,8 @@ def __init__ (self, jdata, descrpt):
            .add('seed', int) \
            .add('atom_ener', list, default = [])\
            .add("activation_function", str, default = "tanh")\
-           .add("precision", str, default = "default")
+           .add("precision", str, default = "default")\
+           .add("trainable", [list, bool], default = True)
     class_data = args.parse(jdata)
     self.numb_fparam = class_data['numb_fparam']
     self.numb_aparam = class_data['numb_aparam']
@@ -32,7 +33,11 @@ def __init__ (self, jdata, descrpt):
     self.rcond = class_data['rcond']
     self.seed = class_data['seed']
     self.fitting_activation_fn = get_activation_func(class_data["activation_function"])
-    self.fitting_precision = get_precision(class_data['precision'])
+    self.fitting_precision = get_precision(class_data['precision'])
+    self.trainable = class_data['trainable']
+    if type(self.trainable) is bool:
+        self.trainable = [self.trainable] * (len(self.n_neuron)+1)
+    assert(len(self.trainable) == len(self.n_neuron) + 1), 'length of trainable should be that of n_neuron + 1'
     self.atom_ener = []
     for at, ae in enumerate(class_data['atom_ener']):
         if ae is not None:
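The bool form is broadcast in `__init__`. A minimal standalone sketch of that normalization, with a hypothetical `n_neuron`, showing why the list form must have length `len(n_neuron) + 1` (one flag per hidden layer plus one for the final output layer):

n_neuron = [240, 240, 240]   # hypothetical hidden-layer sizes
trainable = True             # value as read from the input script

if type(trainable) is bool:
    # broadcast: one flag per hidden layer, plus one for the final layer
    trainable = [trainable] * (len(n_neuron) + 1)
assert len(trainable) == len(n_neuron) + 1, \
    'length of trainable should be that of n_neuron + 1'
print(trainable)   # [True, True, True, True]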
@@ -205,10 +210,10 @@ def build (self,

         for ii in range(0,len(self.n_neuron)) :
             if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] :
-                layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision)
+                layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii])
             else :
-                layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision)
-        final_layer = one_layer(layer, 1, activation_fn = None, bavg = type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision)
+                layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision, trainable = self.trainable[ii])
+        final_layer = one_layer(layer, 1, activation_fn = None, bavg = type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision, trainable = self.trainable[-1])

         if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None:
             inputs_zero = tf.zeros_like(inputs_i, dtype=global_tf_float_precision)
@@ -219,10 +224,10 @@ def build (self,
                 layer = tf.concat([layer, ext_aparam], axis = 1)
             for ii in range(0,len(self.n_neuron)) :
                 if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] :
-                    layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=True, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision)
+                    layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=True, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii])
                 else :
-                    layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=True, seed = self.seed, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision)
-            zero_layer = one_layer(layer, 1, activation_fn = None, bavg = type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=True, seed = self.seed, precision = self.fitting_precision)
+                    layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=True, seed = self.seed, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii])
+            zero_layer = one_layer(layer, 1, activation_fn = None, bavg = type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=True, seed = self.seed, precision = self.fitting_precision, trainable = self.trainable[-1])
             final_layer += self.atom_ener[type_i] - zero_layer

     final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i]])
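In both hunks of `build()`, hidden layer `ii` receives `self.trainable[ii]`, while the final one-output layer (and the matching `zero_layer` used for the `atom_ener` correction, which reuses the same variables) receives `self.trainable[-1]`. A sketch of the resulting mapping for a hypothetical setting that refits only the output layer:

trainable = [False, False, False, True]   # freeze hidden layers, refit output
for ii in range(len(trainable) - 1):      # hidden layers
    print('layer_%d -> trainable = %s' % (ii, trainable[ii]))
print('final_layer -> trainable = %s' % trainable[-1])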

source/train/Network.py

Lines changed: 7 additions & 3 deletions
@@ -13,23 +13,27 @@ def one_layer(inputs,
               reuse=None,
               seed=None,
               use_timestep = False,
+              trainable = True,
               useBN = False):
     with tf.variable_scope(name, reuse=reuse):
         shape = inputs.get_shape().as_list()
         w = tf.get_variable('matrix',
                             [shape[1], outputs_size],
                             precision,
-                            tf.random_normal_initializer(stddev=stddev/np.sqrt(shape[1]+outputs_size), seed = seed))
+                            tf.random_normal_initializer(stddev=stddev/np.sqrt(shape[1]+outputs_size), seed = seed),
+                            trainable = trainable)
         b = tf.get_variable('bias',
                             [outputs_size],
                             precision,
-                            tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed))
+                            tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed),
+                            trainable = trainable)
         hidden = tf.matmul(inputs, w) + b
         if activation_fn != None and use_timestep :
             idt = tf.get_variable('idt',
                                   [outputs_size],
                                   precision,
-                                  tf.random_normal_initializer(stddev=0.001, mean = 0.1, seed = seed))
+                                  tf.random_normal_initializer(stddev=0.001, mean = 0.1, seed = seed),
+                                  trainable = trainable)
         if activation_fn != None:
             if useBN:
                 None
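Passing `trainable=False` to `tf.get_variable` keeps the variable out of TensorFlow 1's `TRAINABLE_VARIABLES` collection, so optimizers built on `tf.trainable_variables()` never compute gradients for it. A minimal check, assuming a TF1 environment:

import tensorflow as tf

with tf.variable_scope('demo'):
    w = tf.get_variable('matrix', [4, 2], trainable=True)
    b = tf.get_variable('bias', [2], trainable=False)

# Only the trainable variable is collected for optimization.
print([v.name for v in tf.trainable_variables()])   # ['demo/matrix:0']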
