Skip to content

Commit dd1bf21

Browse files
author
Han Wang
committed
solve the efficiency problem when testing multiple systems
1 parent 5c45e7f commit dd1bf21

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

source/train/test.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,40 @@
1818
def test (args):
1919
de = DeepEval(args.model)
2020
all_sys = expand_sys_str(args.system)
21+
if len(all_sys) == 0:
22+
print('Did not find valid system')
2123
err_coll = []
2224
siz_coll = []
25+
if de.model_type == 'ener':
26+
dp = DeepPot(args.model)
27+
elif de.model_type == 'dipole':
28+
dp = DeepDipole(args.model)
29+
elif de.model_type == 'polar':
30+
dp = DeepPolar(args.model)
31+
elif de.model_type == 'wfc':
32+
dp = DeepWFC(args.model)
33+
else :
34+
raise RuntimeError('unknow model type '+de.model_type)
2335
for ii in all_sys:
2436
args.system = ii
2537
print ("# ---------------output of dp test--------------- ")
2638
print ("# testing system : " + ii)
2739
if de.model_type == 'ener':
28-
err, siz = test_ener(args)
40+
err, siz = test_ener(dp, args)
2941
elif de.model_type == 'dipole':
30-
err, siz = test_dipole(args)
42+
err, siz = test_dipole(dp, args)
3143
elif de.model_type == 'polar':
32-
err, siz = test_polar(args)
44+
err, siz = test_polar(dp, args)
3345
elif de.model_type == 'wfc':
34-
err, siz = test_wfc(args)
46+
err, siz = test_wfc(dp, args)
3547
else :
3648
raise RuntimeError('unknow model type '+de.model_type)
3749
print ("# ----------------------------------------------- ")
3850
err_coll.append(err)
3951
siz_coll.append(siz)
4052
avg_err = weighted_average(err_coll, siz_coll)
53+
if len(all_sys) != len(err):
54+
print('Not all systems are tested! Check if the systems are valid')
4155
if len(all_sys) > 1:
4256
print ("# ----------weighted average of errors----------- ")
4357
print ("# number of systems : %d" % len(all_sys))
@@ -75,11 +89,10 @@ def weighted_average(err_coll, siz_coll):
7589
return sum_err
7690

7791

78-
def test_ener (args) :
92+
def test_ener (dp, args) :
7993
if args.rand_seed is not None :
8094
np.random.seed(args.rand_seed % (2**32))
8195

82-
dp = DeepPot(args.model)
8396
data = DeepmdData(args.system, args.set_prefix, shuffle_test = args.shuffle_test, type_map = dp.get_type_map())
8497
data.add('energy', 1, atomic=False, must=False, high_prec=True)
8598
data.add('force', 3, atomic=True, must=False, high_prec=False)
@@ -166,11 +179,10 @@ def print_ener_sys_avg(avg):
166179
print ("Virial L2err/Natoms : %e eV" % avg[2])
167180

168181

169-
def test_wfc (args) :
182+
def test_wfc (dp, args) :
170183
if args.rand_seed is not None :
171184
np.random.seed(args.rand_seed % (2**32))
172185

173-
dp = DeepWFC(args.model)
174186
data = DeepmdData(args.system, args.set_prefix, shuffle_test = args.shuffle_test)
175187
data.add('wfc', 12, atomic=True, must=True, high_prec=False, type_sel = dp.get_sel_type())
176188
test_data = data.get_test ()
@@ -204,11 +216,10 @@ def print_wfc_sys_avg(avg):
204216
print ("WFC L2err : %e eV/A" % avg[0])
205217

206218

207-
def test_polar (args) :
219+
def test_polar (dp, args) :
208220
if args.rand_seed is not None :
209221
np.random.seed(args.rand_seed % (2**32))
210222

211-
dp = DeepPolar(args.model)
212223
data = DeepmdData(args.system, args.set_prefix, shuffle_test = args.shuffle_test)
213224
data.add('polarizability', 9, atomic=True, must=True, high_prec=False, type_sel = dp.get_sel_type())
214225
test_data = data.get_test ()
@@ -242,11 +253,10 @@ def print_polar_sys_avg(avg):
242253
print ("Polarizability L2err : %e eV/A" % avg[0])
243254

244255

245-
def test_dipole (args) :
256+
def test_dipole (dp, args) :
246257
if args.rand_seed is not None :
247258
np.random.seed(args.rand_seed % (2**32))
248259

249-
dp = DeepDipole(args.model)
250260
data = DeepmdData(args.system, args.set_prefix, shuffle_test = args.shuffle_test)
251261
data.add('dipole', 3, atomic=True, must=True, high_prec=False, type_sel = dp.get_sel_type())
252262
test_data = data.get_test ()

0 commit comments

Comments
 (0)