|
18 | 18 | def test (args): |
19 | 19 | de = DeepEval(args.model) |
20 | 20 | all_sys = expand_sys_str(args.system) |
| 21 | + if len(all_sys) == 0: |
| 22 | + print('Did not find valid system') |
21 | 23 | err_coll = [] |
22 | 24 | siz_coll = [] |
| 25 | + if de.model_type == 'ener': |
| 26 | + dp = DeepPot(args.model) |
| 27 | + elif de.model_type == 'dipole': |
| 28 | + dp = DeepDipole(args.model) |
| 29 | + elif de.model_type == 'polar': |
| 30 | + dp = DeepPolar(args.model) |
| 31 | + elif de.model_type == 'wfc': |
| 32 | + dp = DeepWFC(args.model) |
| 33 | + else : |
| 34 | + raise RuntimeError('unknow model type '+de.model_type) |
23 | 35 | for ii in all_sys: |
24 | 36 | args.system = ii |
25 | 37 | print ("# ---------------output of dp test--------------- ") |
26 | 38 | print ("# testing system : " + ii) |
27 | 39 | if de.model_type == 'ener': |
28 | | - err, siz = test_ener(args) |
| 40 | + err, siz = test_ener(dp, args) |
29 | 41 | elif de.model_type == 'dipole': |
30 | | - err, siz = test_dipole(args) |
| 42 | + err, siz = test_dipole(dp, args) |
31 | 43 | elif de.model_type == 'polar': |
32 | | - err, siz = test_polar(args) |
| 44 | + err, siz = test_polar(dp, args) |
33 | 45 | elif de.model_type == 'wfc': |
34 | | - err, siz = test_wfc(args) |
| 46 | + err, siz = test_wfc(dp, args) |
35 | 47 | else : |
36 | 48 | raise RuntimeError('unknow model type '+de.model_type) |
37 | 49 | print ("# ----------------------------------------------- ") |
38 | 50 | err_coll.append(err) |
39 | 51 | siz_coll.append(siz) |
40 | 52 | avg_err = weighted_average(err_coll, siz_coll) |
| 53 | + if len(all_sys) != len(err): |
| 54 | + print('Not all systems are tested! Check if the systems are valid') |
41 | 55 | if len(all_sys) > 1: |
42 | 56 | print ("# ----------weighted average of errors----------- ") |
43 | 57 | print ("# number of systems : %d" % len(all_sys)) |
@@ -75,11 +89,10 @@ def weighted_average(err_coll, siz_coll): |
75 | 89 | return sum_err |
76 | 90 |
|
77 | 91 |
|
78 | | -def test_ener (args) : |
| 92 | +def test_ener (dp, args) : |
79 | 93 | if args.rand_seed is not None : |
80 | 94 | np.random.seed(args.rand_seed % (2**32)) |
81 | 95 |
|
82 | | - dp = DeepPot(args.model) |
83 | 96 | data = DeepmdData(args.system, args.set_prefix, shuffle_test = args.shuffle_test, type_map = dp.get_type_map()) |
84 | 97 | data.add('energy', 1, atomic=False, must=False, high_prec=True) |
85 | 98 | data.add('force', 3, atomic=True, must=False, high_prec=False) |
@@ -166,11 +179,10 @@ def print_ener_sys_avg(avg): |
166 | 179 | print ("Virial L2err/Natoms : %e eV" % avg[2]) |
167 | 180 |
|
168 | 181 |
|
169 | | -def test_wfc (args) : |
| 182 | +def test_wfc (dp, args) : |
170 | 183 | if args.rand_seed is not None : |
171 | 184 | np.random.seed(args.rand_seed % (2**32)) |
172 | 185 |
|
173 | | - dp = DeepWFC(args.model) |
174 | 186 | data = DeepmdData(args.system, args.set_prefix, shuffle_test = args.shuffle_test) |
175 | 187 | data.add('wfc', 12, atomic=True, must=True, high_prec=False, type_sel = dp.get_sel_type()) |
176 | 188 | test_data = data.get_test () |
@@ -204,11 +216,10 @@ def print_wfc_sys_avg(avg): |
204 | 216 | print ("WFC L2err : %e eV/A" % avg[0]) |
205 | 217 |
|
206 | 218 |
|
207 | | -def test_polar (args) : |
| 219 | +def test_polar (dp, args) : |
208 | 220 | if args.rand_seed is not None : |
209 | 221 | np.random.seed(args.rand_seed % (2**32)) |
210 | 222 |
|
211 | | - dp = DeepPolar(args.model) |
212 | 223 | data = DeepmdData(args.system, args.set_prefix, shuffle_test = args.shuffle_test) |
213 | 224 | data.add('polarizability', 9, atomic=True, must=True, high_prec=False, type_sel = dp.get_sel_type()) |
214 | 225 | test_data = data.get_test () |
@@ -242,11 +253,10 @@ def print_polar_sys_avg(avg): |
242 | 253 | print ("Polarizability L2err : %e eV/A" % avg[0]) |
243 | 254 |
|
244 | 255 |
|
245 | | -def test_dipole (args) : |
| 256 | +def test_dipole (dp, args) : |
246 | 257 | if args.rand_seed is not None : |
247 | 258 | np.random.seed(args.rand_seed % (2**32)) |
248 | 259 |
|
249 | | - dp = DeepDipole(args.model) |
250 | 260 | data = DeepmdData(args.system, args.set_prefix, shuffle_test = args.shuffle_test) |
251 | 261 | data.add('dipole', 3, atomic=True, must=True, high_prec=False, type_sel = dp.get_sel_type()) |
252 | 262 | test_data = data.get_test () |
|
0 commit comments