Skip to content

Commit 58610f3

Browse files
author
Han Wang
committed
scale and shift the polar output
1 parent fc78c79 commit 58610f3

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

source/train/Fitting.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,8 @@ def __init__ (self, jdata, descrpt) :
383383
.add('neuron', list, default = [120,120,120], alias = 'n_neuron')\
384384
.add('resnet_dt', bool, default = True)\
385385
.add('fit_diag', bool, default = True)\
386+
.add('diag_shift', [list,float], default = [0.0 for ii in range(self.ntypes)])\
387+
.add('scale', [list,float], default = [1.0 for ii in range(self.ntypes)])\
386388
.add('sel_type', [list,int], default = [ii for ii in range(self.ntypes)], alias = 'pol_type')\
387389
.add('seed', int)
388390
class_data = args.parse(jdata)
@@ -391,6 +393,8 @@ def __init__ (self, jdata, descrpt) :
391393
self.sel_type = class_data['sel_type']
392394
self.fit_diag = class_data['fit_diag']
393395
self.seed = class_data['seed']
396+
self.diag_shift = class_data['diag_shift']
397+
self.scale = class_data['scale']
394398
self.dim_rot_mat_1 = descrpt.get_dim_rot_mat_1()
395399
self.dim_rot_mat = self.dim_rot_mat_1 * 3
396400
self.useBN = False
@@ -477,6 +481,10 @@ def build (self,
477481
final_layer = tf.matmul(rot_mat_i, final_layer, transpose_a = True)
478482
# nframes x natoms x 3 x 3
479483
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i], 3, 3])
484+
# shift and scale
485+
sel_type_idx = self.sel_type.index(type_i)
486+
final_layer = final_layer * self.scale[sel_type_idx]
487+
final_layer = final_layer + self.diag_shift[sel_type_idx] * tf.eye(3, batch_shape=[tf.shape(inputs)[0], natoms[2+type_i]], dtype = global_tf_float_precision)
480488

481489
# concat the results
482490
if count == 0:

0 commit comments

Comments
 (0)