Skip to content

Commit 6b3d014

Browse files
authored
Merge pull request #237 from amcadmus/devel
add one-sided embedding net
2 parents 604ced7 + f280cc0 commit 6b3d014

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

source/train/DescrptSeA.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def __init__ (self, jdata):
1717
.add('resnet_dt',bool, default = False) \
1818
.add('trainable',bool, default = True) \
1919
.add('seed', int) \
20+
.add('type_one_side', bool, default = False) \
2021
.add('exclude_types', list, default = []) \
2122
.add('set_davg_zero', bool, default = False) \
2223
.add('activation_function', str, default = 'tanh') \
@@ -39,6 +40,9 @@ def __init__ (self, jdata):
3940
self.exclude_types.add((tt[0], tt[1]))
4041
self.exclude_types.add((tt[1], tt[0]))
4142
self.set_davg_zero = class_data['set_davg_zero']
43+
self.type_one_side = class_data['type_one_side']
44+
if self.type_one_side and len(exclude_types) != 0:
45+
raise RuntimeError('"type_one_side" is not compatible with "exclude_types"')
4246

4347
# descrpt config
4448
self.sel_r = [ 0 for ii in range(len(self.sel_a)) ]
@@ -244,17 +248,27 @@ def _pass_filter(self,
244248
inputs = tf.reshape(inputs, [-1, self.ndescrpt * natoms[0]])
245249
output = []
246250
output_qmat = []
247-
for type_i in range(self.ntypes):
248-
inputs_i = tf.slice (inputs,
249-
[ 0, start_index* self.ndescrpt],
250-
[-1, natoms[2+type_i]* self.ndescrpt] )
251+
if not self.type_one_side:
252+
for type_i in range(self.ntypes):
253+
inputs_i = tf.slice (inputs,
254+
[ 0, start_index* self.ndescrpt],
255+
[-1, natoms[2+type_i]* self.ndescrpt] )
256+
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
257+
layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn)
258+
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
259+
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3])
260+
output.append(layer)
261+
output_qmat.append(qmat)
262+
start_index += natoms[2+type_i]
263+
else :
264+
inputs_i = inputs
251265
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
252-
layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn)
253-
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
254-
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3])
266+
type_i = -1
267+
layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn)
268+
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()])
269+
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3])
255270
output.append(layer)
256271
output_qmat.append(qmat)
257-
start_index += natoms[2+type_i]
258272
output = tf.concat(output, axis = 1)
259273
output_qmat = tf.concat(output_qmat, axis = 1)
260274
return output, output_qmat

0 commit comments

Comments
 (0)