Skip to content

Commit 7b44d39

Browse files
authored
Merge pull request #245 from amcadmus/devel
fix bugs in dp test
2 parents 9912a79 + 128f86d commit 7b44d39

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

source/train/test.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ def test (args):
3232
dp = DeepWFC(args.model)
3333
else :
3434
raise RuntimeError('unknow model type '+de.model_type)
35-
for ii in all_sys:
35+
for cc,ii in enumerate(all_sys):
3636
args.system = ii
3737
print ("# ---------------output of dp test--------------- ")
3838
print ("# testing system : " + ii)
3939
if de.model_type == 'ener':
40-
err, siz = test_ener(dp, args)
40+
err, siz = test_ener(dp, args, append_detail = (cc!=0))
4141
elif de.model_type == 'dipole':
4242
err, siz = test_dipole(dp, args)
4343
elif de.model_type == 'polar':
@@ -89,7 +89,14 @@ def weighted_average(err_coll, siz_coll):
8989
return sum_err
9090

9191

92-
def test_ener (dp, args) :
92+
def save_txt_file(fname, data, header = "", append = False):
93+
fp = fname
94+
if append : fp = open(fp, 'ab')
95+
np.savetxt(fp, data, header = header)
96+
if append : fp.close()
97+
98+
99+
def test_ener (dp, args, append_detail = False) :
93100
if args.rand_seed is not None :
94101
np.random.seed(args.rand_seed % (2**32))
95102

@@ -122,10 +129,7 @@ def test_ener (dp, args) :
122129
else :
123130
aparam = None
124131
detail_file = args.detail_file
125-
if detail_file is not None:
126-
atomic = True
127-
else:
128-
atomic = False
132+
atomic = False
129133

130134
ret = dp.eval(coord, box, atype, fparam = fparam, aparam = aparam, atomic = atomic)
131135
energy = ret[0]
@@ -158,18 +162,21 @@ def test_ener (dp, args) :
158162
pe = np.concatenate((np.reshape(test_data["energy"][:numb_test], [-1,1]),
159163
np.reshape(energy, [-1,1])),
160164
axis = 1)
161-
np.savetxt(detail_file+".e.out", pe,
162-
header = 'data_e pred_e')
165+
save_txt_file(detail_file+".e.out", pe,
166+
header = '%s: data_e pred_e' % args.system,
167+
append = append_detail)
163168
pf = np.concatenate((np.reshape(test_data["force"] [:numb_test], [-1,3]),
164169
np.reshape(force, [-1,3])),
165170
axis = 1)
166-
np.savetxt(detail_file+".f.out", pf,
167-
header = 'data_fx data_fy data_fz pred_fx pred_fy pred_fz')
171+
save_txt_file(detail_file+".f.out", pf,
172+
header = '%s: data_fx data_fy data_fz pred_fx pred_fy pred_fz' % args.system,
173+
append = append_detail)
168174
pv = np.concatenate((np.reshape(test_data["virial"][:numb_test], [-1,9]),
169175
np.reshape(virial, [-1,9])),
170176
axis = 1)
171-
np.savetxt(detail_file+".v.out", pv,
172-
header = 'data_vxx data_vxy data_vxz data_vyx data_vyy data_vyz data_vzx data_vzy data_vzz pred_vxx pred_vxy pred_vxz pred_vyx pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz')
177+
save_txt_file(detail_file+".v.out", pv,
178+
header = '%s: data_vxx data_vxy data_vxz data_vyx data_vyy data_vyz data_vzx data_vzy data_vzz pred_vxx pred_vxy pred_vxz pred_vyx pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz' % args.system,
179+
append = append_detail)
173180
return [l2ea, l2f, l2va], [energy.size, force.size, virial.size]
174181

175182

0 commit comments

Comments
 (0)