Skip to content

Commit 6efe946

Browse files
authored
Merge pull request #186 from amcadmus/devel
compute averaged testing error for many systems
2 parents 77eac60 + 12fdcf4 commit 6efe946

File tree

1 file changed

+62
-4
lines changed

1 file changed

+62
-4
lines changed

source/train/test.py

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,37 @@
1818
def test (args):
1919
de = DeepEval(args.model)
2020
all_sys = expand_sys_str(args.system)
21+
err_coll = []
22+
siz_coll = []
2123
for ii in all_sys:
2224
args.system = ii
2325
print ("# ---------------output of dp test--------------- ")
2426
print ("# testing system : " + ii)
2527
if de.model_type == 'ener':
26-
test_ener(args)
28+
err, siz = test_ener(args)
2729
elif de.model_type == 'dipole':
28-
test_dipole(args)
30+
err, siz = test_dipole(args)
2931
elif de.model_type == 'polar':
30-
test_polar(args)
32+
err, siz = test_polar(args)
3133
elif de.model_type == 'wfc':
32-
test_wfc(args)
34+
err, siz = test_wfc(args)
35+
else :
36+
raise RuntimeError('unknow model type '+de.model_type)
37+
print ("# ----------------------------------------------- ")
38+
err_coll.append(err)
39+
siz_coll.append(siz)
40+
avg_err = weighted_average(err_coll, siz_coll)
41+
if len(all_sys) > 1:
42+
print ("# ----------weighted average of errors----------- ")
43+
print ("# number of systems : %d" % len(all_sys))
44+
if de.model_type == 'ener':
45+
print_ener_sys_avg(avg_err)
46+
elif de.model_type == 'dipole':
47+
print_dipole_sys_avg(avg_err)
48+
elif de.model_type == 'polar':
49+
print_polar_sys_avg(avg_err)
50+
elif de.model_type == 'wfc':
51+
print_wfc_sys_avg(avg_err)
3352
else :
3453
raise RuntimeError('unknow model type '+de.model_type)
3554
print ("# ----------------------------------------------- ")
@@ -39,6 +58,23 @@ def l2err (diff) :
3958
return np.sqrt(np.average (diff*diff))
4059

4160

61+
def weighted_average(err_coll, siz_coll):
62+
nsys = len(err_coll)
63+
nitems = len(err_coll[0])
64+
assert(len(err_coll) == len(siz_coll))
65+
sum_err = np.zeros(nitems)
66+
sum_siz = np.zeros(nitems)
67+
for sys_error, sys_size in zip(err_coll, siz_coll):
68+
for ii in range(nitems):
69+
ee = sys_error[ii]
70+
ss = sys_size [ii]
71+
sum_err[ii] += ee * ee * ss
72+
sum_siz[ii] += ss
73+
for ii in range(nitems):
74+
sum_err[ii] = np.sqrt(sum_err[ii] / sum_siz[ii])
75+
return sum_err
76+
77+
4278
def test_ener (args) :
4379
if args.rand_seed is not None :
4480
np.random.seed(args.rand_seed % (2**32))
@@ -121,6 +157,13 @@ def test_ener (args) :
121157
axis = 1)
122158
np.savetxt(detail_file+".v.out", pv,
123159
header = 'data_vxx data_vxy data_vxz data_vyx data_vyy data_vyz data_vzx data_vzy data_vzz pred_vxx pred_vxy pred_vxz pred_vyx pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz')
160+
return [l2ea, l2f, l2va], [energy.size, force.size, virial.size]
161+
162+
163+
def print_ener_sys_avg(avg):
164+
print ("Energy L2err/Natoms : %e eV" % avg[0])
165+
print ("Force L2err : %e eV/A" % avg[1])
166+
print ("Virial L2err/Natoms : %e eV" % avg[2])
124167

125168

126169
def test_wfc (args) :
@@ -154,6 +197,11 @@ def test_wfc (args) :
154197
axis = 1)
155198
np.savetxt(detail_file+".out", pe,
156199
header = 'ref_wfc(12 dofs) predicted_wfc(12 dofs)')
200+
return [l2f], [wfc.size]
201+
202+
203+
def print_wfc_sys_avg(avg):
204+
print ("WFC L2err : %e eV/A" % avg[0])
157205

158206

159207
def test_polar (args) :
@@ -187,6 +235,11 @@ def test_polar (args) :
187235
axis = 1)
188236
np.savetxt(detail_file+".out", pe,
189237
header = 'data_pxx data_pxy data_pxz data_pyx data_pyy data_pyz data_pzx data_pzy data_pzz pred_pxx pred_pxy pred_pxz pred_pyx pred_pyy pred_pyz pred_pzx pred_pzy pred_pzz')
238+
return [l2f], [polar.size]
239+
240+
241+
def print_polar_sys_avg(avg):
242+
print ("Polarizability L2err : %e eV/A" % avg[0])
190243

191244

192245
def test_dipole (args) :
@@ -220,3 +273,8 @@ def test_dipole (args) :
220273
axis = 1)
221274
np.savetxt(detail_file+".out", pe,
222275
header = 'data_x data_y data_z pred_x pred_y pred_z')
276+
return [l2f], [dipole.size]
277+
278+
279+
def print_dipole_sys_avg(avg):
280+
print ("Dipole L2err : %e eV/A" % avg[0])

0 commit comments

Comments
 (0)