Skip to content

Commit f78195d

Browse files
author
Han Wang
committed
recursively test systems
1 parent 2b98a4e commit f78195d

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

source/train/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def main () :
3737
parser_tst.add_argument("-m", "--model", default="frozen_model.pb", type=str,
3838
help="Frozen model file to import")
3939
parser_tst.add_argument("-s", "--system", default=".", type=str,
40-
help="The system dir")
40+
help="The system dir. Recursively detect systems in this directory")
4141
parser_tst.add_argument("-S", "--set-prefix", default="set", type=str,
4242
help="The set prefix")
4343
parser_tst.add_argument("-n", "--numb-test", default=100, type=int,

source/train/test.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,31 @@
1616

1717
def test (args):
1818
de = DeepEval(args.model)
19-
if de.model_type == 'ener':
20-
test_ener(args)
21-
elif de.model_type == 'dipole':
22-
test_dipole(args)
23-
elif de.model_type == 'polar':
24-
test_polar(args)
25-
elif de.model_type == 'wfc':
26-
test_wfc(args)
27-
else :
28-
raise RuntimeError('unknow model type '+de.model_type)
19+
all_sys = []
20+
from pathlib import Path
21+
for filename in Path(args.system).rglob('type.raw'):
22+
all_sys.append(os.path.dirname(filename))
23+
for ii in all_sys:
24+
args.system = ii
25+
print ("# ---------------output of dp test--------------- ")
26+
print ("# testing system : " + ii)
27+
if de.model_type == 'ener':
28+
test_ener(args)
29+
elif de.model_type == 'dipole':
30+
test_dipole(args)
31+
elif de.model_type == 'polar':
32+
test_polar(args)
33+
elif de.model_type == 'wfc':
34+
test_wfc(args)
35+
else :
36+
raise RuntimeError('unknow model type '+de.model_type)
37+
print ("# ----------------------------------------------- ")
38+
2939

3040
def l2err (diff) :
3141
return np.sqrt(np.average (diff*diff))
3242

43+
3344
def test_ener (args) :
3445
if args.rand_seed is not None :
3546
np.random.seed(args.rand_seed % (2**32))

0 commit comments

Comments
 (0)