Skip to content

Commit 238504a

Browse files
committed
Update check_redundant implementation.
1 parent: 94205f2; commit: 238504a

File tree

1 file changed

+39
-22
lines changed

1 file changed

+39
-22
lines changed

graph_net/paddle/check_redundant_incrementally.py

Lines changed: 39 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -43,42 +43,59 @@ def is_single_model_dir(model_dir):
4343

4444

4545
def main(args):
46-
assert os.path.isdir(args.model_path)
47-
assert os.path.isdir(args.graph_net_samples_path)
48-
current_model_graph_hash_pathes = set(
49-
graph_hash_path
50-
for model_path in get_recursively_model_pathes(args.model_path)
51-
for graph_hash_path in [f"{model_path}/graph_hash.txt"]
46+
assert os.path.isdir(
47+
args.graph_net_samples_path
48+
), f"args.graph_net_samples_path ({args.graph_net_samples_path}) is not a directory!"
49+
find_redundant = False
50+
graph_hash2graph_net_model_path = {}
51+
for model_path in get_recursively_model_pathes(args.graph_net_samples_path):
52+
graph_hash_path = f"{model_path}/graph_hash.txt"
53+
if os.path.isfile(graph_hash_path):
54+
graph_hash = open(graph_hash_path).read()
55+
if graph_hash not in graph_hash2graph_net_model_path.keys():
56+
graph_hash2graph_net_model_path[graph_hash] = graph_hash_path
57+
else:
58+
find_redundant = True
59+
print(
60+
f"Redundant models detected: {graph_hash2graph_net_model_path[graph_hash]} vs {graph_hash_path}"
61+
)
62+
print(
63+
f"Totally {len(graph_hash2graph_net_model_path)} unique samples under {args.graph_net_samples_path}."
5264
)
53-
graph_hash2graph_net_model_path = {
54-
graph_hash: graph_hash_path
55-
for model_path in get_recursively_model_pathes(args.graph_net_samples_path)
56-
for graph_hash_path in [f"{model_path}/graph_hash.txt"]
57-
if os.path.isfile(graph_hash_path)
58-
if graph_hash_path not in current_model_graph_hash_pathes
59-
for graph_hash in [open(graph_hash_path).read()]
60-
}
61-
for current_model_graph_hash_path in current_model_graph_hash_pathes:
62-
graph_hash = open(current_model_graph_hash_path).read()
63-
assert (
64-
graph_hash not in graph_hash2graph_net_model_path
65-
), f"Redundant models detected. old-model-path:{current_model_graph_hash_path}, new-model-path:{graph_hash2graph_net_model_path[graph_hash]}."
65+
assert (
66+
not find_redundant
67+
), f"Redundant models detected under {args.graph_net_samples_path}."
68+
69+
if args.model_path:
70+
assert os.path.isdir(
71+
args.model_path
72+
), f"args.model_path {args.model_path} is not a directory!"
73+
current_model_graph_hash_pathes = set(
74+
graph_hash_path
75+
for model_path in get_recursively_model_pathes(args.model_path)
76+
for graph_hash_path in [f"{model_path}/graph_hash.txt"]
77+
)
78+
for current_model_graph_hash_path in current_model_graph_hash_pathes:
79+
graph_hash = open(current_model_graph_hash_path).read()
80+
assert (
81+
graph_hash not in graph_hash2graph_net_model_path
82+
), f"Redundant models detected. old-model-path:{current_model_graph_hash_path}, new-model-path:{graph_hash2graph_net_model_path[graph_hash]}."
6683

6784

6885
if __name__ == "__main__":
6986
parser = argparse.ArgumentParser(description="Test compiler performance.")
7087
parser.add_argument(
7188
"--model-path",
7289
type=str,
73-
required=True,
90+
required=False,
7491
help="Path to model file(s), each subdirectory containing graph_net.json will be regarded as a model",
7592
)
7693
parser.add_argument(
7794
"--graph-net-samples-path",
7895
type=str,
79-
required=False,
80-
default="default",
96+
required=True,
8197
help="Path to GraphNet samples",
8298
)
8399
args = parser.parse_args()
100+
print(args)
84101
main(args=args)

0 commit comments

Comments
 (0)