Skip to content

Commit e8bbee4

Browse files
author
Han Wang
committed
add energy dipole loss
1 parent c6aaf49 commit e8bbee4

File tree

3 files changed

+105
-3
lines changed

3 files changed

+105
-3
lines changed

source/train/Loss.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,100 @@ def print_on_training(self,
178178
return print_str
179179

180180

181+
class EnerDipoleLoss () :
182+
def __init__ (self, jdata, **kwarg) :
183+
self.starter_learning_rate = kwarg['starter_learning_rate']
184+
args = ClassArg()\
185+
.add('start_pref_e', float, must = True, default = 0.1) \
186+
.add('limit_pref_e', float, must = True, default = 1.00)\
187+
.add('start_pref_ed', float, must = True, default = 1.00)\
188+
.add('limit_pref_ed', float, must = True, default = 1.00)
189+
class_data = args.parse(jdata)
190+
self.start_pref_e = class_data['start_pref_e']
191+
self.limit_pref_e = class_data['limit_pref_e']
192+
self.start_pref_ed = class_data['start_pref_ed']
193+
self.limit_pref_ed = class_data['limit_pref_ed']
194+
# data required
195+
add_data_requirement('energy', 1, atomic=False, must=True, high_prec=True)
196+
add_data_requirement('energy_dipole', 3, atomic=False, must=True, high_prec=False)
197+
198+
def build (self,
199+
learning_rate,
200+
natoms,
201+
model_dict,
202+
label_dict,
203+
suffix):
204+
coord = model_dict['coord']
205+
energy = model_dict['energy']
206+
atom_ener = model_dict['atom_ener']
207+
nframes = tf.shape(atom_ener)[0]
208+
natoms = tf.shape(atom_ener)[1]
209+
# build energy dipole
210+
atom_ener0 = atom_ener - tf.reshape(tf.tile(tf.reshape(energy/global_cvt_2_ener_float(natoms), [-1, 1]), [1, natoms]), [nframes, natoms])
211+
coord = tf.reshape(coord, [nframes, natoms, 3])
212+
atom_ener0 = tf.reshape(atom_ener0, [nframes, 1, natoms])
213+
ener_dipole = tf.matmul(atom_ener0, coord)
214+
ener_dipole = tf.reshape(ener_dipole, [nframes, 3])
215+
216+
energy_hat = label_dict['energy']
217+
ener_dipole_hat = label_dict['energy_dipole']
218+
find_energy = label_dict['find_energy']
219+
find_ener_dipole = label_dict['find_energy_dipole']
220+
221+
l2_ener_loss = tf.reduce_mean( tf.square(energy - energy_hat), name='l2_'+suffix)
222+
223+
ener_dipole_reshape = tf.reshape(ener_dipole, [-1])
224+
ener_dipole_hat_reshape = tf.reshape(ener_dipole_hat, [-1])
225+
l2_ener_dipole_loss = tf.reduce_mean( tf.square(ener_dipole_reshape - ener_dipole_hat_reshape), name='l2_'+suffix)
226+
227+
# atom_norm_ener = 1./ global_cvt_2_ener_float(natoms[0])
228+
atom_norm_ener = 1./ global_cvt_2_ener_float(natoms)
229+
pref_e = global_cvt_2_ener_float(find_energy * (self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * learning_rate / self.starter_learning_rate) )
230+
pref_ed = global_cvt_2_tf_float(find_ener_dipole * (self.limit_pref_ed + (self.start_pref_ed - self.limit_pref_ed) * learning_rate / self.starter_learning_rate) )
231+
232+
l2_loss = 0
233+
more_loss = {}
234+
l2_loss += atom_norm_ener * (pref_e * l2_ener_loss)
235+
l2_loss += global_cvt_2_ener_float(pref_ed * l2_ener_dipole_loss)
236+
more_loss['l2_ener_loss'] = l2_ener_loss
237+
more_loss['l2_ener_dipole_loss'] = l2_ener_dipole_loss
238+
239+
self.l2_l = l2_loss
240+
self.l2_more = more_loss
241+
return l2_loss, more_loss
242+
243+
244+
def print_header(self) :
245+
prop_fmt = ' %9s %9s'
246+
print_str = ''
247+
print_str += prop_fmt % ('l2_tst', 'l2_trn')
248+
print_str += prop_fmt % ('l2_e_tst', 'l2_e_trn')
249+
print_str += prop_fmt % ('l2_ed_tst', 'l2_ed_trn')
250+
return print_str
251+
252+
253+
def print_on_training(self,
254+
sess,
255+
natoms,
256+
feed_dict_test,
257+
feed_dict_batch) :
258+
error_test, error_e_test, error_ed_test\
259+
= sess.run([self.l2_l, \
260+
self.l2_more['l2_ener_loss'], \
261+
self.l2_more['l2_ener_dipole_loss']],
262+
feed_dict=feed_dict_test)
263+
error_train, error_e_train, error_ed_train\
264+
= sess.run([self.l2_l, \
265+
self.l2_more['l2_ener_loss'], \
266+
self.l2_more['l2_ener_dipole_loss']],
267+
feed_dict=feed_dict_batch)
268+
print_str = ""
269+
prop_fmt = " %9.2e %9.2e"
270+
print_str += prop_fmt % (np.sqrt(error_test), np.sqrt(error_train))
271+
print_str += prop_fmt % (np.sqrt(error_e_test) / natoms[0], np.sqrt(error_e_train) / natoms[0])
272+
print_str += prop_fmt % (np.sqrt(error_ed_test), np.sqrt(error_ed_train))
273+
return print_str
274+
181275

182276
class TensorLoss () :
183277
def __init__ (self, jdata, **kwarg) :

source/train/Model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ def build (self,
270270
model_dict['virial'] = virial
271271
model_dict['atom_ener'] = energy_raw
272272
model_dict['atom_virial'] = atom_virial
273+
model_dict['coord'] = coord
274+
model_dict['atype'] = atype
273275

274276
return model_dict
275277

source/train/Trainer.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from deepmd.DescrptSeR import DescrptSeR
1919
from deepmd.DescrptSeAR import DescrptSeAR
2020
from deepmd.Model import Model, WFCModel, DipoleModel, PolarModel, GlobalPolarModel
21-
from deepmd.Loss import EnerStdLoss, TensorLoss
21+
from deepmd.Loss import EnerStdLoss, EnerDipoleLoss, TensorLoss
2222
from deepmd.LearningRate import LearningRateExp
2323

2424
from tensorflow.python.framework import ops
@@ -146,8 +146,15 @@ def _init_param(self, jdata):
146146
loss_param = jdata['loss']
147147
except:
148148
loss_param = None
149+
loss_type = loss_param.get('type', 'std')
150+
149151
if fitting_type == 'ener':
150-
self.loss = EnerStdLoss(loss_param, starter_learning_rate = self.lr.start_lr())
152+
if loss_type == 'std':
153+
self.loss = EnerStdLoss(loss_param, starter_learning_rate = self.lr.start_lr())
154+
elif loss_type == 'ener_dipole':
155+
self.loss = EnerDipoleLoss(loss_param, starter_learning_rate = self.lr.start_lr())
156+
else:
157+
raise RuntimeError('unknow loss type')
151158
elif fitting_type == 'wfc':
152159
self.loss = TensorLoss(loss_param,
153160
model = self.model,
@@ -262,7 +269,6 @@ def _build_network(self, data):
262269
self.place_holders['natoms_vec'] = tf.placeholder(tf.int32, [self.ntypes+2], name='t_natoms')
263270
self.place_holders['default_mesh'] = tf.placeholder(tf.int32, [None], name='t_mesh')
264271
self.place_holders['is_training'] = tf.placeholder(tf.bool)
265-
266272
self.model_pred\
267273
= self.model.build (self.place_holders['coord'],
268274
self.place_holders['type'],

0 commit comments

Comments
 (0)