Commit 5fb098c

Merge pull request #135 from amcadmus/devel

Devel

2 parents 968d8a0 + 90a5ded

File tree

3 files changed: +39 −19

source/train/DescrptSeA.py (22 additions, 11 deletions)
@@ -28,6 +28,7 @@ def __init__ (self, jdata):
                .add('neuron', list, default = [10, 20, 40]) \
                .add('axis_neuron', int, default = 4, alias = 'n_axis_neuron') \
                .add('resnet_dt',bool, default = False) \
+               .add('trainable',bool, default = True) \
                .add('seed', int)
         class_data = args.parse(jdata)
         self.sel_a = class_data['sel']
@@ -37,6 +38,7 @@ def __init__ (self, jdata):
         self.n_axis_neuron = class_data['axis_neuron']
         self.filter_resnet_dt = class_data['resnet_dt']
         self.seed = class_data['seed']
+        self.trainable = class_data['trainable']
 
         # descrpt config
         self.sel_r = [ 0 for ii in range(len(self.sel_a)) ]
@@ -167,7 +169,7 @@ def build (self,
 
         self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt])
 
-        self.dout, self.qmat = self._pass_filter(self.descrpt_reshape, natoms, suffix = suffix, reuse = reuse)
+        self.dout, self.qmat = self._pass_filter(self.descrpt_reshape, natoms, suffix = suffix, reuse = reuse, trainable = self.trainable)
 
         return self.dout
 
@@ -201,7 +203,8 @@ def _pass_filter(self,
                      inputs,
                      natoms,
                      reuse = None,
-                     suffix = '') :
+                     suffix = '',
+                     trainable = True) :
         start_index = 0
         inputs = tf.reshape(inputs, [-1, self.ndescrpt * natoms[0]])
         shape = inputs.get_shape().as_list()
@@ -212,7 +215,7 @@ def _pass_filter(self,
                                   [ 0, start_index* self.ndescrpt],
                                   [-1, natoms[2+type_i]* self.ndescrpt] )
             inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
-            layer, qmat = self._filter(inputs_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed)
+            layer, qmat = self._filter(inputs_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable)
             layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
             qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3])
             output.append(layer)
@@ -297,7 +300,8 @@ def _filter(self,
                 bavg=0.0,
                 name='linear',
                 reuse=None,
-                seed=None):
+                seed=None,
+                trainable = True):
         # natom x (nei x 4)
         shape = inputs.get_shape().as_list()
         outputs_size = [1] + self.filter_neuron
@@ -320,16 +324,19 @@ def _filter(self,
                   w = tf.get_variable('matrix_'+str(ii)+'_'+str(type_i),
                                       [outputs_size[ii - 1], outputs_size[ii]],
                                       global_tf_float_precision,
-                                      tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed))
+                                      tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed),
+                                      trainable = trainable)
                   b = tf.get_variable('bias_'+str(ii)+'_'+str(type_i),
                                       [1, outputs_size[ii]],
                                       global_tf_float_precision,
-                                      tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed))
+                                      tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed),
+                                      trainable = trainable)
                   if self.filter_resnet_dt :
                       idt = tf.get_variable('idt_'+str(ii)+'_'+str(type_i),
                                             [1, outputs_size[ii]],
                                             global_tf_float_precision,
-                                            tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed))
+                                            tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed),
+                                            trainable = trainable)
                   if outputs_size[ii] == outputs_size[ii-1]:
                       if self.filter_resnet_dt :
                           xyz_scatter += activation_fn(tf.matmul(xyz_scatter, w) + b) * idt
@@ -376,7 +383,8 @@ def _filter_type_ext(self,
                 bavg=0.0,
                 name='linear',
                 reuse=None,
-                seed=None):
+                seed=None,
+                trainable = True):
         # natom x (nei x 4)
         shape = inputs.get_shape().as_list()
         outputs_size = [1] + self.filter_neuron
@@ -401,16 +409,19 @@ def _filter_type_ext(self,
                   w = tf.get_variable('matrix_'+str(ii)+'_'+str(type_i),
                                       [outputs_size[ii - 1], outputs_size[ii]],
                                       global_tf_float_precision,
-                                      tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed))
+                                      tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed),
+                                      trainable = trainable)
                   b = tf.get_variable('bias_'+str(ii)+'_'+str(type_i),
                                       [1, outputs_size[ii]],
                                       global_tf_float_precision,
-                                      tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed))
+                                      tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed),
+                                      trainable = trainable)
                   if self.filter_resnet_dt :
                       idt = tf.get_variable('idt_'+str(ii)+'_'+str(type_i),
                                             [1, outputs_size[ii]],
                                             global_tf_float_precision,
-                                            tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed))
+                                            tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed),
+                                            trainable = trainable)
                   if outputs_size[ii] == outputs_size[ii-1]:
                       if self.filter_resnet_dt :
                           xyz_scatter += activation_fn(tf.matmul(xyz_scatter, w) + b) * idt
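
Note: threading 'trainable' through _pass_filter and _filter means every matrix_*, bias_* and idt_* variable of the embedding (filter) networks is created with that flag, so a descriptor configured with 'trainable': False is frozen while the rest of the model keeps training. A minimal TF1-style sketch (hypothetical variable names and a toy loss, not deepmd-kit code) of how trainable=False behaves:

# Minimal sketch: a variable created with trainable=False is excluded from
# tf.trainable_variables(), so optimizer.minimize() leaves it unchanged.
import numpy as np
import tensorflow as tf

w_frozen = tf.get_variable('w_frozen', [2, 2], tf.float64,
                           tf.random_normal_initializer(stddev=1.0, seed=1),
                           trainable=False)
w_free = tf.get_variable('w_free', [2, 2], tf.float64,
                         tf.random_normal_initializer(stddev=1.0, seed=1),
                         trainable=True)
loss = tf.reduce_sum(tf.square(w_frozen)) + tf.reduce_sum(tf.square(w_free))
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    before_frozen, before_free = sess.run([w_frozen, w_free])
    sess.run(train_op)
    after_frozen, after_free = sess.run([w_frozen, w_free])
    assert np.allclose(before_frozen, after_frozen)      # frozen: unchanged
    assert not np.allclose(before_free, after_free)      # free: updated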

source/train/DescrptSeR.py (14 additions, 7 deletions)
@@ -27,6 +27,7 @@ def __init__ (self, jdata):
                .add('rcut_smth',float, default = 5.5) \
                .add('neuron', list, default = [10, 20, 40]) \
                .add('resnet_dt',bool, default = False) \
+               .add('trainable',bool, default = True) \
                .add('seed', int)
         class_data = args.parse(jdata)
         self.sel_r = class_data['sel']
@@ -35,6 +36,7 @@ def __init__ (self, jdata):
         self.filter_neuron = class_data['neuron']
         self.filter_resnet_dt = class_data['resnet_dt']
         self.seed = class_data['seed']
+        self.trainable = class_data['trainable']
 
         # descrpt config
         self.sel_a = [ 0 for ii in range(len(self.sel_r)) ]
@@ -145,7 +147,7 @@ def build (self,
 
         self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt])
 
-        self.dout = self._pass_filter(self.descrpt_reshape, natoms, suffix = suffix, reuse = reuse)
+        self.dout = self._pass_filter(self.descrpt_reshape, natoms, suffix = suffix, reuse = reuse, trainable = self.trainable)
 
         return self.dout
 
@@ -171,7 +173,8 @@ def _pass_filter(self,
                      inputs,
                      natoms,
                      reuse = None,
-                     suffix = '') :
+                     suffix = '',
+                     trainable = True) :
         start_index = 0
         inputs = tf.reshape(inputs, [-1, self.ndescrpt * natoms[0]])
         shape = inputs.get_shape().as_list()
@@ -181,7 +184,7 @@ def _pass_filter(self,
                                   [ 0, start_index* self.ndescrpt],
                                   [-1, natoms[2+type_i]* self.ndescrpt] )
             inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
-            layer = self._filter_r(inputs_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed)
+            layer = self._filter_r(inputs_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable)
             layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
             output.append(layer)
             start_index += natoms[2+type_i]
@@ -253,7 +256,8 @@ def _filter_r(self,
                   bavg=0.0,
                   name='linear',
                   reuse=None,
-                  seed=None):
+                  seed=None,
+                  trainable = True):
         # natom x nei
         shape = inputs.get_shape().as_list()
         outputs_size = [1] + self.filter_neuron
@@ -274,16 +278,19 @@ def _filter_r(self,
                 w = tf.get_variable('matrix_'+str(ii)+'_'+str(type_i),
                                     [outputs_size[ii - 1], outputs_size[ii]],
                                     global_tf_float_precision,
-                                    tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed))
+                                    tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed),
+                                    trainable = trainable)
                 b = tf.get_variable('bias_'+str(ii)+'_'+str(type_i),
                                     [1, outputs_size[ii]],
                                     global_tf_float_precision,
-                                    tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed))
+                                    tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed),
+                                    trainable = trainable)
                 if self.filter_resnet_dt :
                     idt = tf.get_variable('idt_'+str(ii)+'_'+str(type_i),
                                           [1, outputs_size[ii]],
                                           global_tf_float_precision,
-                                          tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed))
+                                          tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed),
+                                          trainable = trainable)
                 if outputs_size[ii] == outputs_size[ii-1]:
                     if self.filter_resnet_dt :
                         xyz_scatter += activation_fn(tf.matmul(xyz_scatter, w) + b) * idt
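
DescrptSeR gets the same option with the same default (True), so existing inputs are unaffected. A hypothetical descriptor jdata is sketched below to illustrate the new key; all values and the keys not visible in the hunks above (e.g. the cutoff options) are illustrative, and the import path is an assumption:

# Hypothetical se_r descriptor jdata; setting 'trainable': False freezes all
# filter-network variables of this descriptor while the rest of the model
# (e.g. the fitting network) continues to train.
se_r_jdata = {
    'sel'      : [46, 92],      # illustrative neighbor counts
    'rcut'     : 6.0,           # cutoff options not shown in this diff
    'rcut_smth': 5.5,
    'neuron'   : [10, 20, 40],
    'resnet_dt': False,
    'trainable': False,         # the option added by this commit
    'seed'     : 1,
}
from DescrptSeR import DescrptSeR   # assuming source/train is on the path
descrpt = DescrptSeR(se_r_jdata)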

source/train/Trainer.py (3 additions, 1 deletion)
@@ -386,6 +386,7 @@ def train (self,
         fp = open(self.disp_file, "a")
 
         cur_batch = self.sess.run(self.global_step)
+        is_first_step = True
         self.cur_batch = cur_batch
         self.run_opt.message("start training at lr %.2e (== %.2e), final lr will be %.2e" %
                              (self.sess.run(self.learning_rate),
@@ -417,8 +418,9 @@ def train (self,
                     feed_dict_batch[self.place_holders[ii]] = batch_data[ii]
                 feed_dict_batch[self.place_holders['is_training']] = True
 
-                if self.display_in_training and cur_batch == 0 :
+                if self.display_in_training and is_first_step :
                     self.test_on_the_fly(fp, data, feed_dict_batch)
+                    is_first_step = False
                 if self.timing_in_training : tic = time.time()
                 self.sess.run([self.train_op], feed_dict = feed_dict_batch, options=prf_options, run_metadata=prf_run_metadata)
                 if self.timing_in_training : toc = time.time()
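
The Trainer change fixes the initial on-the-fly test for restarted runs: cur_batch is the restored global step, so the old 'cur_batch == 0' condition only fired on a fresh run. A small self-contained sketch (hypothetical helper, not the real Trainer loop) of the control-flow difference:

# Sketch only: 'tested' records the batches at which the on-the-fly test runs.
def run_training(start_batch, stop_batch, display_in_training=True):
    tested = []
    cur_batch = start_batch              # restored global step on a restart
    is_first_step = True
    while cur_batch < stop_batch:
        if display_in_training and is_first_step:
            tested.append(cur_batch)     # stands in for self.test_on_the_fly(...)
            is_first_step = False
        cur_batch += 1                   # stands in for running train_op
    return tested

# Fresh run: behaves exactly as before, testing once at batch 0.
assert run_training(0, 5) == [0]
# Restarted run (global step restored to 100): the old 'cur_batch == 0' check
# would never fire; with is_first_step the test still runs once.
assert run_training(100, 105) == [100]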
