forked from PaddlePaddle/GraphNet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path validate.py
More file actions
70 lines (63 loc) · 2.71 KB
/
validate.py
File metadata and controls
70 lines (63 loc) · 2.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import argparse
import contextlib
import os
import shlex
import subprocess
import sys
import tempfile

import graph_net
@contextlib.contextmanager
def temp_workspace():
    """Point GRAPH_NET_EXTRACT_WORKSPACE at a fresh temporary directory.

    The directory is created on entry and deleted on exit. While the context
    is active, ``os.environ["GRAPH_NET_EXTRACT_WORKSPACE"]`` holds its path,
    so child processes started inside the ``with`` block inherit it. On exit
    the variable's previous value is restored; if it was previously unset, it
    is removed. This happens even if the body raises.

    Yields:
        str: Path of the temporary workspace directory.
    """
    key = "GRAPH_NET_EXTRACT_WORKSPACE"
    with tempfile.TemporaryDirectory() as tmp_dir_name:
        old = os.environ.get(key)
        os.environ[key] = tmp_dir_name
        try:
            yield tmp_dir_name
        finally:
            # os.environ only accepts str values; assigning None raises
            # TypeError, so drop the variable when it was not set before.
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old
def _run_checked(cmd, failure_message=None):
    """Run ``cmd`` (an argv list) and raise if it exits with a non-zero status.

    Args:
        cmd: Program and arguments as a list of strings. No shell is
            involved, so paths containing spaces or shell metacharacters
            are passed through verbatim.
        failure_message: Optional message for the raised error. Defaults to
            one showing the return code and the shell-quoted command.

    Raises:
        AssertionError: If the command exits with a non-zero status. This is
            an explicit raise rather than ``assert``, so the check still runs
            under ``python -O``.
    """
    cmd_ret = subprocess.run(cmd).returncode
    if cmd_ret != 0:
        if failure_message is None:
            cmd_str = shlex.join(cmd)
            failure_message = f"{cmd_ret=}, cmd={cmd_str!r}"
        raise AssertionError(failure_message)


def main(args):
    """Validate that a GraphNet model sample can be run and extracted.

    The extractability steps run in child Python processes inside a temporary
    extraction workspace (see ``temp_workspace``):

    1. Run the model at ``args.model_path`` with
       ``graph_net.torch.single_device_runner``.
    2. Run it again with extraction enabled (extract name ``"temp"``) and
       ``--dump-graph-hash-key``.
    3. Run the extracted sample at ``<workspace>/temp``.

    Then, unless ``args.no_check_redundancy`` is set, run
    ``graph_net.torch.check_redundant_incrementally`` against the samples
    directory. That directory is ``args.graph_net_samples_path``, or the
    package default when that is None.

    Args:
        args: Parsed command-line namespace with ``model_path``,
            ``graph_net_samples_path`` and ``no_check_redundancy``.

    Raises:
        AssertionError: If any step's command exits with a non-zero status.
            For a failed redundancy check, the message contains the command
            that removes the redundant model directories.
    """
    model_path = args.model_path
    runner = [sys.executable, "-m", "graph_net.torch.single_device_runner"]

    with temp_workspace() as tmp_dir_name:
        # flush so our progress line is not reordered after child output
        # when stdout is a pipe.
        print("Check extractability ...", flush=True)
        # 1. The original model must run as-is.
        _run_checked([*runner, "--model-path", model_path])
        # 2. Run again with extraction enabled. Step 3 expects the extracted
        #    sample at <workspace>/<extract_name>.
        extract_name = "temp"
        _run_checked(
            [
                *runner,
                "--model-path",
                model_path,
                "--enable-extract",
                "True",
                "--extract-name",
                extract_name,
                "--dump-graph-hash-key",
            ]
        )
        # 3. The extracted sample must itself be runnable.
        _run_checked(
            [*runner, "--model-path", os.path.join(tmp_dir_name, extract_name)]
        )

    if not args.no_check_redundancy:
        print("Check redundancy ...", flush=True)
        graph_net_samples_path = args.graph_net_samples_path
        if graph_net_samples_path is None:
            # Import the submodule explicitly: a bare ``import graph_net``
            # does not by itself guarantee the submodule is loaded.
            import graph_net.torch.samples_util

            graph_net_samples_path = (
                graph_net.torch.samples_util.get_default_samples_directory()
            )
        sample_args = [
            "--model-path",
            model_path,
            "--graph-net-samples-path",
            str(graph_net_samples_path),
        ]
        check_cmd = [
            sys.executable,
            "-m",
            "graph_net.torch.check_redundant_incrementally",
            *sample_args,
        ]
        rm_cmd = [
            sys.executable,
            "-m",
            "graph_net.torch.remove_redundant_incrementally",
            *sample_args,
        ]
        _run_checked(
            check_cmd,
            failure_message=(
                "\nPlease use the following command to remove redundant "
                f"model directories:\n\n{shlex.join(rm_cmd)}\n"
            ),
        )

    print(f"Validation success, {model_path=}")
if __name__ == "__main__":
    # Command-line entry point: parse options and run the validation steps.
    parser = argparse.ArgumentParser(
        description=(
            "Validate a GraphNet model sample: run it, extract its graph, "
            "rerun the extracted sample and optionally check graph redundancy"
        )
    )
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to folder e.g '../../samples/torch/resnet18'",
    )
    parser.add_argument(
        "--graph-net-samples-path",
        type=str,
        required=False,
        default=None,
        help="Path to GraphNet samples folder. e.g '../../samples'",
    )
    parser.add_argument(
        "--no-check-redundancy",
        action="store_true",
        # The flag disables the check; the previous help text said the opposite.
        help="skip the model graph redundancy check",
    )
    args = parser.parse_args()
    main(args=args)