Skip to content

Commit 5061492

Browse files
author
Han Wang
committed
global_polar: print loss not normalized by sqrt(natoms). add dp test for global_polar
1 parent 846a78e commit 5061492

File tree

2 files changed

+27
-7
lines changed

2 files changed

+27
-7
lines changed

source/train/Loss.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,11 +301,12 @@ def build (self,
301301
polar_hat = label_dict[self.label_name]
302302
polar = model_dict[self.tensor_name]
303303
l2_loss = tf.reduce_mean( tf.square(self.scale*(polar - polar_hat)), name='l2_'+suffix)
304+
more_loss = {'nonorm': l2_loss}
304305
if not self.atomic :
305306
atom_norm = 1./ global_cvt_2_tf_float(natoms[0])
306307
l2_loss = l2_loss * atom_norm
307308
self.l2_l = l2_loss
308-
more_loss = {}
309+
self.l2_more = more_loss['nonorm']
309310

310311
return l2_loss, more_loss
311312

@@ -321,10 +322,10 @@ def print_on_training(self,
321322
feed_dict_test,
322323
feed_dict_batch) :
323324
error_test\
324-
= sess.run([self.l2_l], \
325+
= sess.run([self.l2_more], \
325326
feed_dict=feed_dict_test)
326327
error_train\
327-
= sess.run([self.l2_l], \
328+
= sess.run([self.l2_more], \
328329
feed_dict=feed_dict_batch)
329330
print_str = ""
330331
prop_fmt = " %9.2e %9.2e"

source/train/test.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from deepmd import DeepPot
1313
from deepmd import DeepDipole
1414
from deepmd import DeepPolar
15+
from deepmd import DeepGlobalPolar
1516
from deepmd import DeepWFC
1617
from tensorflow.python.framework import ops
1718

@@ -28,6 +29,8 @@ def test (args):
2829
dp = DeepDipole(args.model)
2930
elif de.model_type == 'polar':
3031
dp = DeepPolar(args.model)
32+
elif de.model_type == 'global_polar':
33+
dp = DeepGlobalPolar(args.model)
3134
elif de.model_type == 'wfc':
3235
dp = DeepWFC(args.model)
3336
else :
@@ -41,7 +44,9 @@ def test (args):
4144
elif de.model_type == 'dipole':
4245
err, siz = test_dipole(dp, args)
4346
elif de.model_type == 'polar':
44-
err, siz = test_polar(dp, args)
47+
err, siz = test_polar(dp, args, global_polar=False)
48+
elif de.model_type == 'global_polar':
49+
err, siz = test_polar(dp, args, global_polar=True)
4550
elif de.model_type == 'wfc':
4651
err, siz = test_wfc(dp, args)
4752
else :
@@ -61,6 +66,8 @@ def test (args):
6166
print_dipole_sys_avg(avg_err)
6267
elif de.model_type == 'polar':
6368
print_polar_sys_avg(avg_err)
69+
elif de.model_type == 'global_polar':
70+
print_polar_sys_avg(avg_err)
6471
elif de.model_type == 'wfc':
6572
print_wfc_sys_avg(avg_err)
6673
else :
@@ -223,12 +230,15 @@ def print_wfc_sys_avg(avg):
223230
print ("WFC L2err : %e eV/A" % avg[0])
224231

225232

226-
def test_polar (dp, args) :
233+
def test_polar (dp, args, global_polar = False) :
227234
if args.rand_seed is not None :
228235
np.random.seed(args.rand_seed % (2**32))
229236

230237
data = DeepmdData(args.system, args.set_prefix, shuffle_test = args.shuffle_test)
231-
data.add('polarizability', 9, atomic=True, must=True, high_prec=False, type_sel = dp.get_sel_type())
238+
if not global_polar:
239+
data.add('polarizability', 9, atomic=True, must=True, high_prec=False, type_sel = dp.get_sel_type())
240+
else:
241+
data.add('polarizability', 9, atomic=False, must=True, high_prec=False, type_sel = dp.get_sel_type())
232242
test_data = data.get_test ()
233243
numb_test = args.numb_test
234244
natoms = len(test_data["type"][0])
@@ -239,12 +249,21 @@ def test_polar (dp, args) :
239249
box = test_data["box"][:numb_test]
240250
atype = test_data["type"][0]
241251
polar = dp.eval(coord, box, atype)
252+
sel_type = dp.get_sel_type()
253+
sel_natoms = 0
254+
for ii in sel_type:
255+
sel_natoms += sum(atype == ii)
242256

243257
polar = polar.reshape([numb_test,-1])
244258
l2f = (l2err (polar - test_data["polarizability"] [:numb_test]))
259+
l2fs = l2f/np.sqrt(sel_natoms)
260+
l2fa = l2f/sel_natoms
245261

246262
print ("# number of test data : %d " % numb_test)
247-
print ("Polarizability L2err : %e eV/A" % l2f)
263+
print ("Polarizability L2err : %e eV/A" % l2f)
264+
if global_polar:
265+
print ("Polarizability L2err/sqrtN : %e eV/A" % l2fs)
266+
print ("Polarizability L2err/N : %e eV/A" % l2fa)
248267

249268
detail_file = args.detail_file
250269
if detail_file is not None :

0 commit comments

Comments
 (0)