Skip to content

Commit 9ec295e

Browse files
authored
Merge pull request #171 from njzjz/typemask
Allows excluding interactions between specified pairs of atom types (exclude_types) and supports user-specified atomic energies (atom_ener)
2 parents a370fbf + be68d16 commit 9ec295e

File tree

3 files changed

+117
-64
lines changed

3 files changed

+117
-64
lines changed

source/train/DescrptSeA.py

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ def __init__ (self, jdata):
1616
.add('axis_neuron', int, default = 4, alias = 'n_axis_neuron') \
1717
.add('resnet_dt',bool, default = False) \
1818
.add('trainable',bool, default = True) \
19-
.add('seed', int)
19+
.add('seed', int) \
20+
.add('exclude_types', list, default = []) \
21+
.add('set_davg_zero', bool, default = False)
2022
class_data = args.parse(jdata)
2123
self.sel_a = class_data['sel']
2224
self.rcut_r = class_data['rcut']
@@ -26,6 +28,13 @@ def __init__ (self, jdata):
2628
self.filter_resnet_dt = class_data['resnet_dt']
2729
self.seed = class_data['seed']
2830
self.trainable = class_data['trainable']
31+
exclude_types = class_data['exclude_types']
32+
self.exclude_types = set()
33+
for tt in exclude_types:
34+
assert(len(tt) == 2)
35+
self.exclude_types.add((tt[0], tt[1]))
36+
self.exclude_types.add((tt[1], tt[0]))
37+
self.set_davg_zero = class_data['set_davg_zero']
2938

3039
# descrpt config
3140
self.sel_r = [ 0 for ii in range(len(self.sel_a)) ]
@@ -124,7 +133,8 @@ def compute_input_stats (self,
124133
all_davg.append(davg)
125134
all_dstd.append(dstd)
126135

127-
self.davg = np.array(all_davg)
136+
if not self.set_davg_zero:
137+
self.davg = np.array(all_davg)
128138
self.dstd = np.array(all_dstd)
129139

130140

@@ -235,7 +245,7 @@ def _pass_filter(self,
235245
[ 0, start_index* self.ndescrpt],
236246
[-1, natoms[2+type_i]* self.ndescrpt] )
237247
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
238-
layer, qmat = self._filter(inputs_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable)
248+
layer, qmat = self._filter(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable)
239249
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
240250
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3])
241251
output.append(layer)
@@ -300,6 +310,7 @@ def _compute_std (self,sumv2, sumv, sumn) :
300310

301311
def _filter(self,
302312
inputs,
313+
type_input,
303314
natoms,
304315
activation_fn=tf.nn.tanh,
305316
stddev=1.0,
@@ -326,35 +337,39 @@ def _filter(self,
326337
# with (natom x nei_type_i) x 4
327338
inputs_reshape = tf.reshape(inputs_i, [-1, 4])
328339
xyz_scatter = tf.reshape(tf.slice(inputs_reshape, [0,0],[-1,1]),[-1,1])
329-
for ii in range(1, len(outputs_size)):
330-
w = tf.get_variable('matrix_'+str(ii)+'_'+str(type_i),
331-
[outputs_size[ii - 1], outputs_size[ii]],
332-
global_tf_float_precision,
333-
tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed),
334-
trainable = trainable)
335-
b = tf.get_variable('bias_'+str(ii)+'_'+str(type_i),
336-
[1, outputs_size[ii]],
337-
global_tf_float_precision,
338-
tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed),
339-
trainable = trainable)
340-
if self.filter_resnet_dt :
341-
idt = tf.get_variable('idt_'+str(ii)+'_'+str(type_i),
342-
[1, outputs_size[ii]],
343-
global_tf_float_precision,
344-
tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed),
345-
trainable = trainable)
346-
if outputs_size[ii] == outputs_size[ii-1]:
347-
if self.filter_resnet_dt :
348-
xyz_scatter += activation_fn(tf.matmul(xyz_scatter, w) + b) * idt
349-
else :
350-
xyz_scatter += activation_fn(tf.matmul(xyz_scatter, w) + b)
351-
elif outputs_size[ii] == outputs_size[ii-1] * 2:
352-
if self.filter_resnet_dt :
353-
xyz_scatter = tf.concat([xyz_scatter,xyz_scatter], 1) + activation_fn(tf.matmul(xyz_scatter, w) + b) * idt
354-
else :
355-
xyz_scatter = tf.concat([xyz_scatter,xyz_scatter], 1) + activation_fn(tf.matmul(xyz_scatter, w) + b)
356-
else:
357-
xyz_scatter = activation_fn(tf.matmul(xyz_scatter, w) + b)
340+
if (type_input, type_i) not in self.exclude_types:
341+
for ii in range(1, len(outputs_size)):
342+
w = tf.get_variable('matrix_'+str(ii)+'_'+str(type_i),
343+
[outputs_size[ii - 1], outputs_size[ii]],
344+
global_tf_float_precision,
345+
tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed),
346+
trainable = trainable)
347+
b = tf.get_variable('bias_'+str(ii)+'_'+str(type_i),
348+
[1, outputs_size[ii]],
349+
global_tf_float_precision,
350+
tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed),
351+
trainable = trainable)
352+
if self.filter_resnet_dt :
353+
idt = tf.get_variable('idt_'+str(ii)+'_'+str(type_i),
354+
[1, outputs_size[ii]],
355+
global_tf_float_precision,
356+
tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed),
357+
trainable = trainable)
358+
if outputs_size[ii] == outputs_size[ii-1]:
359+
if self.filter_resnet_dt :
360+
xyz_scatter += activation_fn(tf.matmul(xyz_scatter, w) + b) * idt
361+
else :
362+
xyz_scatter += activation_fn(tf.matmul(xyz_scatter, w) + b)
363+
elif outputs_size[ii] == outputs_size[ii-1] * 2:
364+
if self.filter_resnet_dt :
365+
xyz_scatter = tf.concat([xyz_scatter,xyz_scatter], 1) + activation_fn(tf.matmul(xyz_scatter, w) + b) * idt
366+
else :
367+
xyz_scatter = tf.concat([xyz_scatter,xyz_scatter], 1) + activation_fn(tf.matmul(xyz_scatter, w) + b)
368+
else:
369+
xyz_scatter = activation_fn(tf.matmul(xyz_scatter, w) + b)
370+
else:
371+
w = tf.zeros((outputs_size[0], outputs_size[-1]), dtype=global_tf_float_precision)
372+
xyz_scatter = tf.matmul(xyz_scatter, w)
358373
# natom x nei_type_i x out_size
359374
xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1]//4, outputs_size[-1]))
360375
xyz_scatter_total.append(xyz_scatter)

source/train/DescrptSeR.py

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ def __init__ (self, jdata):
1515
.add('neuron', list, default = [10, 20, 40]) \
1616
.add('resnet_dt',bool, default = False) \
1717
.add('trainable',bool, default = True) \
18-
.add('seed', int)
18+
.add('seed', int) \
19+
.add('exclude_types', list, default = []) \
20+
.add('set_davg_zero', bool, default = False)
1921
class_data = args.parse(jdata)
2022
self.sel_r = class_data['sel']
2123
self.rcut = class_data['rcut']
@@ -24,6 +26,13 @@ def __init__ (self, jdata):
2426
self.filter_resnet_dt = class_data['resnet_dt']
2527
self.seed = class_data['seed']
2628
self.trainable = class_data['trainable']
29+
exclude_types = class_data['exclude_types']
30+
self.exclude_types = set()
31+
for tt in exclude_types:
32+
assert(len(tt) == 2)
33+
self.exclude_types.add((tt[0], tt[1]))
34+
self.exclude_types.add((tt[1], tt[0]))
35+
self.set_davg_zero = class_data['set_davg_zero']
2736

2837
# descrpt config
2938
self.sel_a = [ 0 for ii in range(len(self.sel_r)) ]
@@ -104,7 +113,8 @@ def compute_input_stats (self,
104113
all_davg.append(davg)
105114
all_dstd.append(dstd)
106115

107-
self.davg = np.array(all_davg)
116+
if not self.set_davg_zero:
117+
self.davg = np.array(all_davg)
108118
self.dstd = np.array(all_dstd)
109119

110120

@@ -194,7 +204,7 @@ def _pass_filter(self,
194204
[ 0, start_index* self.ndescrpt],
195205
[-1, natoms[2+type_i]* self.ndescrpt] )
196206
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
197-
layer = self._filter_r(inputs_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable)
207+
layer = self._filter_r(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable)
198208
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
199209
output.append(layer)
200210
start_index += natoms[2+type_i]
@@ -248,6 +258,7 @@ def _compute_std (self,sumv2, sumv, sumn) :
248258

249259
def _filter_r(self,
250260
inputs,
261+
type_input,
251262
natoms,
252263
activation_fn=tf.nn.tanh,
253264
stddev=1.0,
@@ -271,35 +282,39 @@ def _filter_r(self,
271282
shape_i = inputs_i.get_shape().as_list()
272283
# with (natom x nei_type_i) x 1
273284
xyz_scatter = tf.reshape(inputs_i, [-1, 1])
274-
for ii in range(1, len(outputs_size)):
275-
w = tf.get_variable('matrix_'+str(ii)+'_'+str(type_i),
276-
[outputs_size[ii - 1], outputs_size[ii]],
277-
global_tf_float_precision,
278-
tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed),
279-
trainable = trainable)
280-
b = tf.get_variable('bias_'+str(ii)+'_'+str(type_i),
281-
[1, outputs_size[ii]],
282-
global_tf_float_precision,
283-
tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed),
284-
trainable = trainable)
285-
if self.filter_resnet_dt :
286-
idt = tf.get_variable('idt_'+str(ii)+'_'+str(type_i),
287-
[1, outputs_size[ii]],
288-
global_tf_float_precision,
289-
tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed),
290-
trainable = trainable)
291-
if outputs_size[ii] == outputs_size[ii-1]:
285+
if (type_input, type_i) not in self.exclude_types:
286+
for ii in range(1, len(outputs_size)):
287+
w = tf.get_variable('matrix_'+str(ii)+'_'+str(type_i),
288+
[outputs_size[ii - 1], outputs_size[ii]],
289+
global_tf_float_precision,
290+
tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed),
291+
trainable = trainable)
292+
b = tf.get_variable('bias_'+str(ii)+'_'+str(type_i),
293+
[1, outputs_size[ii]],
294+
global_tf_float_precision,
295+
tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed),
296+
trainable = trainable)
292297
if self.filter_resnet_dt :
293-
xyz_scatter += activation_fn(tf.matmul(xyz_scatter, w) + b) * idt
294-
else :
295-
xyz_scatter += activation_fn(tf.matmul(xyz_scatter, w) + b)
296-
elif outputs_size[ii] == outputs_size[ii-1] * 2:
297-
if self.filter_resnet_dt :
298-
xyz_scatter = tf.concat([xyz_scatter,xyz_scatter], 1) + activation_fn(tf.matmul(xyz_scatter, w) + b) * idt
299-
else :
300-
xyz_scatter = tf.concat([xyz_scatter,xyz_scatter], 1) + activation_fn(tf.matmul(xyz_scatter, w) + b)
301-
else:
302-
xyz_scatter = activation_fn(tf.matmul(xyz_scatter, w) + b)
298+
idt = tf.get_variable('idt_'+str(ii)+'_'+str(type_i),
299+
[1, outputs_size[ii]],
300+
global_tf_float_precision,
301+
tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed),
302+
trainable = trainable)
303+
if outputs_size[ii] == outputs_size[ii-1]:
304+
if self.filter_resnet_dt :
305+
xyz_scatter += activation_fn(tf.matmul(xyz_scatter, w) + b) * idt
306+
else :
307+
xyz_scatter += activation_fn(tf.matmul(xyz_scatter, w) + b)
308+
elif outputs_size[ii] == outputs_size[ii-1] * 2:
309+
if self.filter_resnet_dt :
310+
xyz_scatter = tf.concat([xyz_scatter,xyz_scatter], 1) + activation_fn(tf.matmul(xyz_scatter, w) + b) * idt
311+
else :
312+
xyz_scatter = tf.concat([xyz_scatter,xyz_scatter], 1) + activation_fn(tf.matmul(xyz_scatter, w) + b)
313+
else:
314+
xyz_scatter = activation_fn(tf.matmul(xyz_scatter, w) + b)
315+
else:
316+
w = tf.zeros((outputs_size[0], outputs_size[-1]), dtype=global_tf_float_precision)
317+
xyz_scatter = tf.matmul(xyz_scatter, w)
303318
# natom x nei_type_i x out_size
304319
xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1], outputs_size[-1]))
305320
xyz_scatter_total.append(xyz_scatter)

source/train/Fitting.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,21 @@ def __init__ (self, jdata, descrpt):
2020
.add('neuron', list, default = [120,120,120], alias = 'n_neuron')\
2121
.add('resnet_dt', bool, default = True)\
2222
.add('rcond', float, default = 1e-3) \
23-
.add('seed', int)
23+
.add('seed', int) \
24+
.add('atom_ener', list, default = [])
2425
class_data = args.parse(jdata)
2526
self.numb_fparam = class_data['numb_fparam']
2627
self.numb_aparam = class_data['numb_aparam']
2728
self.n_neuron = class_data['neuron']
2829
self.resnet_dt = class_data['resnet_dt']
2930
self.rcond = class_data['rcond']
3031
self.seed = class_data['seed']
32+
self.atom_ener = []
33+
for at, ae in enumerate(class_data['atom_ener']):
34+
if ae is not None:
35+
self.atom_ener.append(tf.constant(ae, global_tf_float_precision, name = "atom_%d_ener" % at))
36+
else:
37+
self.atom_ener.append(None)
3138
self.useBN = False
3239
self.bias_atom_e = None
3340
# data requirement
@@ -198,6 +205,22 @@ def build (self,
198205
else :
199206
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
200207
final_layer = one_layer(layer, 1, activation_fn = None, bavg = type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
208+
209+
if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None:
210+
inputs_zero = tf.zeros_like(inputs_i, dtype=global_tf_float_precision)
211+
layer = inputs_zero
212+
if self.numb_fparam > 0 :
213+
layer = tf.concat([layer, ext_fparam], axis = 1)
214+
if self.numb_aparam > 0 :
215+
layer = tf.concat([layer, ext_aparam], axis = 1)
216+
for ii in range(0,len(self.n_neuron)) :
217+
if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] :
218+
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=True, seed = self.seed, use_timestep = self.resnet_dt)
219+
else :
220+
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=True, seed = self.seed)
221+
zero_layer = one_layer(layer, 1, activation_fn = None, bavg = type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=True, seed = self.seed)
222+
final_layer += self.atom_ener[type_i] - zero_layer
223+
201224
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i]])
202225

203226
# concat the results

0 commit comments

Comments (0)