@@ -178,6 +178,100 @@ def print_on_training(self,
178178 return print_str
179179
180180
181+ class EnerDipoleLoss () :
182+ def __init__ (self , jdata , ** kwarg ) :
183+ self .starter_learning_rate = kwarg ['starter_learning_rate' ]
184+ args = ClassArg ()\
185+ .add ('start_pref_e' , float , must = True , default = 0.1 ) \
186+ .add ('limit_pref_e' , float , must = True , default = 1.00 )\
187+ .add ('start_pref_ed' , float , must = True , default = 1.00 )\
188+ .add ('limit_pref_ed' , float , must = True , default = 1.00 )
189+ class_data = args .parse (jdata )
190+ self .start_pref_e = class_data ['start_pref_e' ]
191+ self .limit_pref_e = class_data ['limit_pref_e' ]
192+ self .start_pref_ed = class_data ['start_pref_ed' ]
193+ self .limit_pref_ed = class_data ['limit_pref_ed' ]
194+ # data required
195+ add_data_requirement ('energy' , 1 , atomic = False , must = True , high_prec = True )
196+ add_data_requirement ('energy_dipole' , 3 , atomic = False , must = True , high_prec = False )
197+
198+ def build (self ,
199+ learning_rate ,
200+ natoms ,
201+ model_dict ,
202+ label_dict ,
203+ suffix ):
204+ coord = model_dict ['coord' ]
205+ energy = model_dict ['energy' ]
206+ atom_ener = model_dict ['atom_ener' ]
207+ nframes = tf .shape (atom_ener )[0 ]
208+ natoms = tf .shape (atom_ener )[1 ]
209+ # build energy dipole
210+ atom_ener0 = atom_ener - tf .reshape (tf .tile (tf .reshape (energy / global_cvt_2_ener_float (natoms ), [- 1 , 1 ]), [1 , natoms ]), [nframes , natoms ])
211+ coord = tf .reshape (coord , [nframes , natoms , 3 ])
212+ atom_ener0 = tf .reshape (atom_ener0 , [nframes , 1 , natoms ])
213+ ener_dipole = tf .matmul (atom_ener0 , coord )
214+ ener_dipole = tf .reshape (ener_dipole , [nframes , 3 ])
215+
216+ energy_hat = label_dict ['energy' ]
217+ ener_dipole_hat = label_dict ['energy_dipole' ]
218+ find_energy = label_dict ['find_energy' ]
219+ find_ener_dipole = label_dict ['find_energy_dipole' ]
220+
221+ l2_ener_loss = tf .reduce_mean ( tf .square (energy - energy_hat ), name = 'l2_' + suffix )
222+
223+ ener_dipole_reshape = tf .reshape (ener_dipole , [- 1 ])
224+ ener_dipole_hat_reshape = tf .reshape (ener_dipole_hat , [- 1 ])
225+ l2_ener_dipole_loss = tf .reduce_mean ( tf .square (ener_dipole_reshape - ener_dipole_hat_reshape ), name = 'l2_' + suffix )
226+
227+ # atom_norm_ener = 1./ global_cvt_2_ener_float(natoms[0])
228+ atom_norm_ener = 1. / global_cvt_2_ener_float (natoms )
229+ pref_e = global_cvt_2_ener_float (find_energy * (self .limit_pref_e + (self .start_pref_e - self .limit_pref_e ) * learning_rate / self .starter_learning_rate ) )
230+ pref_ed = global_cvt_2_tf_float (find_ener_dipole * (self .limit_pref_ed + (self .start_pref_ed - self .limit_pref_ed ) * learning_rate / self .starter_learning_rate ) )
231+
232+ l2_loss = 0
233+ more_loss = {}
234+ l2_loss += atom_norm_ener * (pref_e * l2_ener_loss )
235+ l2_loss += global_cvt_2_ener_float (pref_ed * l2_ener_dipole_loss )
236+ more_loss ['l2_ener_loss' ] = l2_ener_loss
237+ more_loss ['l2_ener_dipole_loss' ] = l2_ener_dipole_loss
238+
239+ self .l2_l = l2_loss
240+ self .l2_more = more_loss
241+ return l2_loss , more_loss
242+
243+
244+ def print_header (self ) :
245+ prop_fmt = ' %9s %9s'
246+ print_str = ''
247+ print_str += prop_fmt % ('l2_tst' , 'l2_trn' )
248+ print_str += prop_fmt % ('l2_e_tst' , 'l2_e_trn' )
249+ print_str += prop_fmt % ('l2_ed_tst' , 'l2_ed_trn' )
250+ return print_str
251+
252+
253+ def print_on_training (self ,
254+ sess ,
255+ natoms ,
256+ feed_dict_test ,
257+ feed_dict_batch ) :
258+ error_test , error_e_test , error_ed_test \
259+ = sess .run ([self .l2_l , \
260+ self .l2_more ['l2_ener_loss' ], \
261+ self .l2_more ['l2_ener_dipole_loss' ]],
262+ feed_dict = feed_dict_test )
263+ error_train , error_e_train , error_ed_train \
264+ = sess .run ([self .l2_l , \
265+ self .l2_more ['l2_ener_loss' ], \
266+ self .l2_more ['l2_ener_dipole_loss' ]],
267+ feed_dict = feed_dict_batch )
268+ print_str = ""
269+ prop_fmt = " %9.2e %9.2e"
270+ print_str += prop_fmt % (np .sqrt (error_test ), np .sqrt (error_train ))
271+ print_str += prop_fmt % (np .sqrt (error_e_test ) / natoms [0 ], np .sqrt (error_e_train ) / natoms [0 ])
272+ print_str += prop_fmt % (np .sqrt (error_ed_test ), np .sqrt (error_ed_train ))
273+ return print_str
274+
181275
182276class TensorLoss () :
183277 def __init__ (self , jdata , ** kwarg ) :
0 commit comments