Skip to content

Commit 14f2963

Browse files
committed
Merge count_sample into check_and_count_samples.
1 parent b2f86ad commit 14f2963

File tree

5 files changed

+131
-143
lines changed

5 files changed

+131
-143
lines changed

.github/workflows/Codestyle-Check.yml

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,4 @@ jobs:
5454
run: |
5555
set +e
5656
bash -x tools/codestyle/pre_commit.sh;EXCODE=$?
57-
exit $EXCODE
58-
59-
- name: Check samples
60-
if: steps.check-bypass.outputs.can-skip != 'true'
61-
run: |
62-
set +e
63-
python3.10 tools/check_samples.py;EXCODE=$?
64-
exit $EXCODE
65-
66-
- name: Count samples
67-
if: steps.check-bypass.outputs.can-skip != 'true'
68-
run: |
69-
set +e
70-
python3.10 tools/count_sample.py;EXCODE=$?
7157
exit $EXCODE

tools/check_and_count_samples.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import os
2+
import json
3+
4+
5+
def check_completeness(samples_dir):
6+
samples_missing_hash = []
7+
samples_missing_json = []
8+
samples_missing_meta = []
9+
for root, dirs, files in os.walk(samples_dir):
10+
if "shape_patches_" not in root and "model.py" in files:
11+
model_path = root
12+
if not os.path.exists(os.path.join(model_path, "graph_hash.txt")):
13+
samples_missing_hash.append(model_path)
14+
if not os.path.exists(os.path.join(model_path, "graph_net.json")):
15+
samples_missing_json.append(model_path)
16+
if not os.path.exists(
17+
os.path.join(model_path, "input_meta.py")
18+
) or not os.path.exists(os.path.join(model_path, "weight_meta.py")):
19+
samples_missing_meta.append(model_path)
20+
21+
all_samples_complete = (
22+
len(samples_missing_hash) == 0
23+
and len(samples_missing_json) == 0
24+
and len(samples_missing_meta) == 0
25+
)
26+
27+
if not all_samples_complete:
28+
print(f"Check completeness result for {samples_dir}:")
29+
print(f"1. {len(samples_missing_hash)} samples missing graph_hash.txt")
30+
for model_path in samples_missing_hash:
31+
print(f" - {model_path}")
32+
33+
print(f"2. {len(samples_missing_json)} samples missing graph_net.json")
34+
for model_path in samples_missing_json:
35+
print(f" - {model_path}")
36+
37+
print(
38+
f"3. {len(samples_missing_meta)} samples missing input_meta.py or weight_meta.py"
39+
)
40+
for model_path in samples_missing_meta:
41+
print(f" - {model_path}")
42+
print()
43+
44+
return all_samples_complete
45+
46+
47+
def check_redandancy(samples_dir):
48+
graph_hash2model_paths = {}
49+
for root, dirs, files in os.walk(samples_dir):
50+
if "graph_hash.txt" in files:
51+
model_path = root
52+
graph_hash_path = os.path.join(model_path, "graph_hash.txt")
53+
graph_hash = open(graph_hash_path).read()
54+
if graph_hash not in graph_hash2model_paths.keys():
55+
graph_hash2model_paths[graph_hash] = [model_path]
56+
else:
57+
graph_hash2model_paths[graph_hash].append(model_path)
58+
59+
has_duplicates = False
60+
print(f"Totally {len(graph_hash2model_paths)} unique graphs under {samples_dir}.")
61+
for graph_hash, model_paths in graph_hash2model_paths.items():
62+
graph_hash2model_paths[graph_hash] = sorted(model_paths)
63+
if len(model_paths) > 1:
64+
has_duplicates = True
65+
print(f"Redundant models detected for grap_hash {graph_hash}:")
66+
for model_path in model_paths:
67+
print(f" {model_path}")
68+
return has_duplicates, graph_hash2model_paths
69+
70+
71+
def count_samples(samples_dir, framework):
72+
model_sources = os.listdir(samples_dir)
73+
74+
graph_net_count = 0
75+
graph_net_dict = {}
76+
model_names_set = set()
77+
for source in model_sources:
78+
source_dir = os.path.join(samples_dir, source)
79+
if os.path.isdir(source_dir):
80+
graph_net_dict[source] = 0
81+
for root, dirs, files in os.walk(source_dir):
82+
if "graph_net.json" in files:
83+
with open(os.path.join(root, "graph_net.json"), "r") as f:
84+
data = json.load(f)
85+
model_name = data.get("model_name", None)
86+
if model_name is not None and model_name != "NO_VALID_MATCH_FOUND":
87+
if model_name not in model_names_set:
88+
model_names_set.add(model_name)
89+
graph_net_count += 1
90+
graph_net_dict[source] += 1
91+
else:
92+
graph_net_count += 1
93+
graph_net_dict[source] += 1
94+
95+
print(f"Number of {framework} samples: {graph_net_count}")
96+
for name, number in graph_net_dict.items():
97+
print(f"- {name:24}: {number}")
98+
print()
99+
100+
101+
def main():
102+
filename = os.path.abspath(__file__)
103+
root_dir = os.path.dirname(os.path.dirname(filename))
104+
105+
framework2dirname = {
106+
"torch": "samples",
107+
"paddle": "paddle_samples",
108+
}
109+
110+
all_samples_complete = True
111+
for samples_dirname in framework2dirname.values():
112+
samples_dir = os.path.join(root_dir, samples_dirname)
113+
all_samples_complete = all_samples_complete and check_completeness(samples_dir)
114+
assert all_samples_complete, "Please fix the incompleted samples!"
115+
116+
all_samples_has_duplicates = False
117+
for samples_dirname in framework2dirname.values():
118+
samples_dir = os.path.join(root_dir, samples_dirname)
119+
has_duplicates, graph_hash2model_paths = check_redandancy(samples_dir)
120+
all_samples_has_duplicates = all_samples_has_duplicates or has_duplicates
121+
print()
122+
assert not all_samples_has_duplicates, "Please remove the redundant samples!"
123+
124+
for framework in framework2dirname.keys():
125+
samples_dir = os.path.join(root_dir, framework2dirname[framework])
126+
count_samples(samples_dir, framework)
127+
128+
129+
if __name__ == "__main__":
130+
main()

tools/check_samples.py

Lines changed: 0 additions & 88 deletions
This file was deleted.

tools/ci/check_validate.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ function main() {
163163
check_validation_info=$(check_paddle_validation)
164164
check_validation_code=$?
165165
summary_problems $check_validation_code "$check_validation_info"
166-
python ${GRAPH_NET_EXTRACT_WORKSPACE}/tools/count_sample.py
166+
python ${GRAPH_NET_EXTRACT_WORKSPACE}/tools/check_and_count_samples.py
167167
LOG "[INFO] check_validation run success and no error!"
168168
}
169169

tools/count_sample.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

0 commit comments

Comments
 (0)