Skip to content

Commit d17c857

Browse files
authored
Merge pull request #106 from amcadmus/devel
relative force loss
2 parents a8578b7 + 5deb126 commit d17c857

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

source/train/Loss.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ def __init__ (self, jdata, **kwarg) :
2222
.add('start_pref_ae', float, default = 0)\
2323
.add('limit_pref_ae', float, default = 0)\
2424
.add('start_pref_pf', float, default = 0)\
25-
.add('limit_pref_pf', float, default = 0)
25+
.add('limit_pref_pf', float, default = 0)\
26+
.add('relative_f', float)
2627
class_data = args.parse(jdata)
2728
self.start_pref_e = class_data['start_pref_e']
2829
self.limit_pref_e = class_data['limit_pref_e']
@@ -34,6 +35,7 @@ def __init__ (self, jdata, **kwarg) :
3435
self.limit_pref_ae = class_data['limit_pref_ae']
3536
self.start_pref_pf = class_data['start_pref_pf']
3637
self.limit_pref_pf = class_data['limit_pref_pf']
38+
self.relative_f = class_data['relative_f']
3739
self.has_e = (self.start_pref_e != 0 or self.limit_pref_e != 0)
3840
self.has_f = (self.start_pref_f != 0 or self.limit_pref_f != 0)
3941
self.has_v = (self.start_pref_v != 0 or self.limit_pref_v != 0)
@@ -72,8 +74,15 @@ def build (self,
7274
force_reshape = tf.reshape (force, [-1])
7375
force_hat_reshape = tf.reshape (force_hat, [-1])
7476
atom_pref_reshape = tf.reshape (atom_pref, [-1])
75-
l2_force_loss = tf.reduce_mean (tf.square(force_hat_reshape - force_reshape), name = "l2_force_" + suffix)
76-
l2_pref_force_loss = tf.reduce_mean (tf.multiply(tf.square(force_hat_reshape - force_reshape), atom_pref_reshape), name = "l2_pref_force_" + suffix)
77+
diff_f = force_hat_reshape - force_reshape
78+
if self.relative_f is not None:
79+
force_hat_3 = tf.reshape(force_hat, [-1, 3])
80+
norm_f = tf.reshape(tf.norm(force_hat_3, axis = 1), [-1, 1]) + self.relative_f
81+
diff_f_3 = tf.reshape(diff_f, [-1, 3])
82+
diff_f_3 = diff_f_3 / norm_f
83+
diff_f = tf.reshape(diff_f_3, [-1])
84+
l2_force_loss = tf.reduce_mean(tf.square(diff_f), name = "l2_force_" + suffix)
85+
l2_pref_force_loss = tf.reduce_mean(tf.multiply(tf.square(diff_f), atom_pref_reshape), name = "l2_pref_force_" + suffix)
7786

7887
virial_reshape = tf.reshape (virial, [-1])
7988
virial_hat_reshape = tf.reshape (virial_hat, [-1])

0 commit comments

Comments
 (0)