Skip to content

Commit 2a76ca4

Browse files
author
Han Wang
committed
append detail file when testing multi systems
1 parent 9912a79 commit 2a76ca4

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

source/train/test.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ def test (args):
3232
dp = DeepWFC(args.model)
3333
else :
3434
raise RuntimeError('unknow model type '+de.model_type)
35-
for ii in all_sys:
35+
for cc,ii in enumerate(all_sys):
3636
args.system = ii
3737
print ("# ---------------output of dp test--------------- ")
3838
print ("# testing system : " + ii)
3939
if de.model_type == 'ener':
40-
err, siz = test_ener(dp, args)
40+
err, siz = test_ener(dp, args, append_detail = (cc!=0))
4141
elif de.model_type == 'dipole':
4242
err, siz = test_dipole(dp, args)
4343
elif de.model_type == 'polar':
@@ -89,7 +89,14 @@ def weighted_average(err_coll, siz_coll):
8989
return sum_err
9090

9191

92-
def test_ener (dp, args) :
92+
def save_txt_file(fname, data, header = "", append = False):
93+
fp = fname
94+
if append : fp = open(fp, 'ab')
95+
np.savetxt(fp, data, header = header)
96+
if append : fp.close()
97+
98+
99+
def test_ener (dp, args, append_detail = False) :
93100
if args.rand_seed is not None :
94101
np.random.seed(args.rand_seed % (2**32))
95102

@@ -158,18 +165,21 @@ def test_ener (dp, args) :
158165
pe = np.concatenate((np.reshape(test_data["energy"][:numb_test], [-1,1]),
159166
np.reshape(energy, [-1,1])),
160167
axis = 1)
161-
np.savetxt(detail_file+".e.out", pe,
162-
header = 'data_e pred_e')
168+
save_txt_file(detail_file+".e.out", pe,
169+
header = '%s: data_e pred_e' % args.system,
170+
append = append_detail)
163171
pf = np.concatenate((np.reshape(test_data["force"] [:numb_test], [-1,3]),
164172
np.reshape(force, [-1,3])),
165173
axis = 1)
166-
np.savetxt(detail_file+".f.out", pf,
167-
header = 'data_fx data_fy data_fz pred_fx pred_fy pred_fz')
174+
save_txt_file(detail_file+".f.out", pf,
175+
header = '%s: data_fx data_fy data_fz pred_fx pred_fy pred_fz' % args.system,
176+
append = append_detail)
168177
pv = np.concatenate((np.reshape(test_data["virial"][:numb_test], [-1,9]),
169178
np.reshape(virial, [-1,9])),
170179
axis = 1)
171-
np.savetxt(detail_file+".v.out", pv,
172-
header = 'data_vxx data_vxy data_vxz data_vyx data_vyy data_vyz data_vzx data_vzy data_vzz pred_vxx pred_vxy pred_vxz pred_vyx pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz')
180+
save_txt_file(detail_file+".v.out", pv,
181+
header = '%s: data_vxx data_vxy data_vxz data_vyx data_vyy data_vyz data_vzx data_vzy data_vzz pred_vxx pred_vxy pred_vxz pred_vyx pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz' % args.system,
182+
append = append_detail)
173183
return [l2ea, l2f, l2va], [energy.size, force.size, virial.size]
174184

175185

0 commit comments

Comments
 (0)