@@ -22,7 +22,8 @@ def __init__ (self, jdata, **kwarg) :
2222 .add ('start_pref_ae' , float , default = 0 )\
2323 .add ('limit_pref_ae' , float , default = 0 )\
2424 .add ('start_pref_pf' , float , default = 0 )\
25- .add ('limit_pref_pf' , float , default = 0 )
25+ .add ('limit_pref_pf' , float , default = 0 )\
26+ .add ('relative_f' , float )
2627 class_data = args .parse (jdata )
2728 self .start_pref_e = class_data ['start_pref_e' ]
2829 self .limit_pref_e = class_data ['limit_pref_e' ]
@@ -34,6 +35,7 @@ def __init__ (self, jdata, **kwarg) :
3435 self .limit_pref_ae = class_data ['limit_pref_ae' ]
3536 self .start_pref_pf = class_data ['start_pref_pf' ]
3637 self .limit_pref_pf = class_data ['limit_pref_pf' ]
38+ self .relative_f = class_data ['relative_f' ]
3739 self .has_e = (self .start_pref_e != 0 or self .limit_pref_e != 0 )
3840 self .has_f = (self .start_pref_f != 0 or self .limit_pref_f != 0 )
3941 self .has_v = (self .start_pref_v != 0 or self .limit_pref_v != 0 )
@@ -72,8 +74,15 @@ def build (self,
7274 force_reshape = tf .reshape (force , [- 1 ])
7375 force_hat_reshape = tf .reshape (force_hat , [- 1 ])
7476 atom_pref_reshape = tf .reshape (atom_pref , [- 1 ])
75- l2_force_loss = tf .reduce_mean (tf .square (force_hat_reshape - force_reshape ), name = "l2_force_" + suffix )
76- l2_pref_force_loss = tf .reduce_mean (tf .multiply (tf .square (force_hat_reshape - force_reshape ), atom_pref_reshape ), name = "l2_pref_force_" + suffix )
77+ diff_f = force_hat_reshape - force_reshape
78+ if self .relative_f is not None :
79+ force_hat_3 = tf .reshape (force_hat , [- 1 , 3 ])
80+ norm_f = tf .reshape (tf .norm (force_hat_3 , axis = 1 ), [- 1 , 1 ]) + self .relative_f
81+ diff_f_3 = tf .reshape (diff_f , [- 1 , 3 ])
82+ diff_f_3 = diff_f_3 / norm_f
83+ diff_f = tf .reshape (diff_f_3 , [- 1 ])
84+ l2_force_loss = tf .reduce_mean (tf .square (diff_f ), name = "l2_force_" + suffix )
85+ l2_pref_force_loss = tf .reduce_mean (tf .multiply (tf .square (diff_f ), atom_pref_reshape ), name = "l2_pref_force_" + suffix )
7786
7887 virial_reshape = tf .reshape (virial , [- 1 ])
7988 virial_hat_reshape = tf .reshape (virial_hat , [- 1 ])
0 commit comments