Skip to content

Commit 1847b0d

Browse files
author
Han Wang
committed
scale tensor loss
1 parent 58610f3 commit 1847b0d

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

source/train/Fitting.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,12 @@ def __init__ (self, jdata, descrpt) :
395395
self.seed = class_data['seed']
396396
self.diag_shift = class_data['diag_shift']
397397
self.scale = class_data['scale']
398+
if type(self.sel_type) is not list:
399+
self.sel_type = [self.sel_type]
400+
if type(self.diag_shift) is not list:
401+
self.diag_shift = [self.diag_shift]
402+
if type(self.scale) is not list:
403+
self.scale = [self.scale]
398404
self.dim_rot_mat_1 = descrpt.get_dim_rot_mat_1()
399405
self.dim_rot_mat = self.dim_rot_mat_1 * 3
400406
self.useBN = False

source/train/Loss.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,10 @@ def __init__ (self, jdata, **kwarg) :
284284
self.tensor_size = kwarg['tensor_size']
285285
self.label_name = kwarg['label_name']
286286
self.atomic = kwarg.get('atomic', True)
287+
if jdata is not None:
288+
self.scale = jdata.get('scale', 1.0)
289+
else:
290+
self.scale = 1.0
287291
# data required
288292
add_data_requirement(self.label_name,
289293
self.tensor_size,
@@ -300,7 +304,7 @@ def build (self,
300304
suffix):
301305
polar_hat = label_dict[self.label_name]
302306
polar = model_dict[self.tensor_name]
303-
l2_loss = tf.reduce_mean( tf.square(polar - polar_hat), name='l2_'+suffix)
307+
l2_loss = tf.reduce_mean( tf.square(self.scale*(polar - polar_hat)), name='l2_'+suffix)
304308
if not self.atomic :
305309
atom_norm = 1./ global_cvt_2_tf_float(natoms[0])
306310
l2_loss = l2_loss * atom_norm

0 commit comments

Comments
 (0)