Skip to content

Commit eb78505

Browse files
committed
Add completeness and redundancy checks for all samples on CI.
1 parent 357aaad commit eb78505

File tree

3 files changed

+132
-43
lines changed

3 files changed

+132
-43
lines changed

tools/check_and_count_samples.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import os
2+
import json
3+
4+
5+
def check_completeness(samples_dir):
    """Check that every sample directory under ``samples_dir`` is complete.

    A directory visited by ``os.walk`` is treated as a sample when it contains
    ``model.py`` and its path does not contain ``shape_patches_``. Each sample
    is expected to also hold ``graph_hash.txt``, ``graph_net.json``,
    ``input_meta.py`` and ``weight_meta.py``.

    When at least one file is missing, a three-part report listing the
    offending directories is printed to stdout; nothing is printed otherwise.

    Args:
        samples_dir: Root directory to scan recursively.

    Returns:
        True when no sample is missing any of the expected files, else False.
    """

    def _lacks(sample_dir, filename):
        # True when `filename` does not exist inside `sample_dir`.
        return not os.path.exists(os.path.join(sample_dir, filename))

    no_hash, no_json, no_meta = [], [], []
    for current_dir, _subdirs, filenames in os.walk(samples_dir):
        # Only directories holding a model.py outside shape-patch trees count.
        if "shape_patches_" in current_dir or "model.py" not in filenames:
            continue
        if _lacks(current_dir, "graph_hash.txt"):
            no_hash.append(current_dir)
        if _lacks(current_dir, "graph_net.json"):
            no_json.append(current_dir)
        # Either meta file being absent puts the sample on the same list.
        if _lacks(current_dir, "input_meta.py") or _lacks(
            current_dir, "weight_meta.py"
        ):
            no_meta.append(current_dir)

    if not (no_hash or no_json or no_meta):
        return True

    sections = (
        (f"1. {len(no_hash)} samples missing graph_hash.txt", no_hash),
        (f"2. {len(no_json)} samples missing graph_net.json", no_json),
        (
            f"3. {len(no_meta)} samples missing input_meta.py or weight_meta.py",
            no_meta,
        ),
    )
    print(f"Check completeness result for {samples_dir}:")
    for headline, offenders in sections:
        print(headline)
        for sample_dir in offenders:
            print(f" - {sample_dir}")
    print()
    return False
45+
46+
47+
def check_redandancy(samples_dir):
    """Group sample directories by graph hash and report duplicates.

    Every directory under ``samples_dir`` that contains ``graph_hash.txt`` is
    grouped by the exact contents of that file. The number of distinct hashes
    is printed, followed by the directories of every hash shared by more than
    one directory.

    The misspelt function name is kept because callers reference it.

    Args:
        samples_dir: Root directory to scan recursively.

    Returns:
        A ``(has_duplicates, graph_hash2model_paths)`` tuple, where
        ``has_duplicates`` is True when any hash maps to more than one
        directory and ``graph_hash2model_paths`` maps each hash string to the
        sorted list of directories containing it.
    """
    graph_hash2model_paths = {}
    for model_path, _dirs, files in os.walk(samples_dir):
        if "graph_hash.txt" not in files:
            continue
        # Use a context manager so the handle is closed promptly instead of
        # being left to the garbage collector for every sample visited.
        # NOTE(review): the hash is compared verbatim (no whitespace
        # stripping); confirm all graph_hash.txt files are written the same way.
        with open(
            os.path.join(model_path, "graph_hash.txt"), encoding="utf-8"
        ) as hash_file:
            graph_hash = hash_file.read()
        graph_hash2model_paths.setdefault(graph_hash, []).append(model_path)

    has_duplicates = False
    print(f"Totally {len(graph_hash2model_paths)} unique graphs under {samples_dir}.")
    for graph_hash, model_paths in graph_hash2model_paths.items():
        # Sort in place so both the returned mapping and the printed report
        # use the same deterministic order (os.walk order is arbitrary).
        model_paths.sort()
        if len(model_paths) > 1:
            has_duplicates = True
            print(f"Redundant models detected for grap_hash {graph_hash}:")
            for model_path in model_paths:
                print(f" {model_path}")
    return has_duplicates, graph_hash2model_paths
69+
70+
71+
def count_samples(samples_dir, framework):
    """Print how many samples each model source contributes.

    Each immediate sub-directory of ``samples_dir`` is treated as a model
    source. Every ``graph_net.json`` found beneath a source is loaded and
    counted as follows:

    * if its ``model_name`` is missing or equals ``NO_VALID_MATCH_FOUND``,
      the file always counts as one sample;
    * otherwise it counts only the first time that ``model_name`` is seen
      across all sources.

    Sources are visited in sorted order so the report, and the source that
    gets credited for a model name shared between sources, do not depend on
    the arbitrary order returned by ``os.listdir``. The overall total is
    unaffected by the visiting order.

    Args:
        samples_dir: Directory whose sub-directories are the model sources.
        framework: Label used in the printed headline.

    Returns:
        None. The summary is written to stdout.
    """
    graph_net_count = 0
    graph_net_dict = {}
    model_names_set = set()
    for source in sorted(os.listdir(samples_dir)):
        source_dir = os.path.join(samples_dir, source)
        if not os.path.isdir(source_dir):
            continue
        graph_net_dict[source] = 0
        for root, _dirs, files in os.walk(source_dir):
            if "graph_net.json" not in files:
                continue
            with open(
                os.path.join(root, "graph_net.json"), "r", encoding="utf-8"
            ) as f:
                data = json.load(f)
            model_name = data.get("model_name", None)
            if model_name is not None and model_name != "NO_VALID_MATCH_FOUND":
                # Named models are de-duplicated across all sources.
                if model_name in model_names_set:
                    continue
                model_names_set.add(model_name)
            graph_net_count += 1
            graph_net_dict[source] += 1

    print(f"Number of {framework} samples: {graph_net_count}")
    for name, number in graph_net_dict.items():
        print(f"- {name:24}: {number}")
    print()
99+
100+
101+
def main():
    """Run the completeness, redundancy and counting passes over all samples.

    The sample trees are ``samples`` (torch) and ``paddle_samples`` (paddle),
    resolved relative to the parent of this file's directory.

    Raises:
        AssertionError: if any sample tree has incomplete samples, or if any
            tree contains redundant samples. Raised explicitly rather than
            via ``assert`` so the check still fires under ``python -O``.
    """
    filename = os.path.abspath(__file__)
    root_dir = os.path.dirname(os.path.dirname(filename))

    framework2dirname = {
        "torch": "samples",
        "paddle": "paddle_samples",
    }

    all_samples_complete = True
    for samples_dirname in framework2dirname.values():
        samples_dir = os.path.join(root_dir, samples_dirname)
        # Run the check before combining results: putting the accumulated
        # flag first would short-circuit and skip (and not report on) every
        # tree after the first failing one.
        is_complete = check_completeness(samples_dir)
        all_samples_complete = all_samples_complete and is_complete
    if not all_samples_complete:
        raise AssertionError("Please fix the incompleted samples!")

    all_samples_has_duplicates = False
    for samples_dirname in framework2dirname.values():
        samples_dir = os.path.join(root_dir, samples_dirname)
        has_duplicates, _graph_hash2model_paths = check_redandancy(samples_dir)
        all_samples_has_duplicates = all_samples_has_duplicates or has_duplicates
        print()
    if all_samples_has_duplicates:
        raise AssertionError("Please remove the redundant samples!")

    for framework, samples_dirname in framework2dirname.items():
        samples_dir = os.path.join(root_dir, samples_dirname)
        count_samples(samples_dir, framework)


if __name__ == "__main__":
    main()

tools/ci/check_validate.sh

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ function prepare_torch_env() {
4141
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 > /dev/null
4242
[ $? -ne 0 ] && LOG "[FATAL] Install torch2.9.0 failed!" && exit -1
4343
else
44-
python ${GRAPH_NET_EXTRACT_WORKSPACE}/tools/count_sample.py
4544
LOG "[INFO] This pull request doesn't change any torch samples, skip the CI."
4645
fi
4746
}
@@ -62,7 +61,6 @@ function prepare_paddle_env() {
6261
[ $? -ne 0 ] && LOG "[FATAL] Install paddlepaddle-develop failed!" && exit -1
6362
python -c "import paddle; print('[PaddlePaddle Commit]', paddle.version.commit)"
6463
else
65-
python ${GRAPH_NET_EXTRACT_WORKSPACE}/tools/count_sample.py
6664
LOG "[INFO] This pull request doesn't change any paddle samples, skip the CI."
6765
fi
6866
}
@@ -165,7 +163,8 @@ function main() {
165163
check_validation_info=$(check_paddle_validation)
166164
check_validation_code=$?
167165
summary_problems $check_validation_code "$check_validation_info"
168-
python ${GRAPH_NET_EXTRACT_WORKSPACE}/tools/count_sample.py
166+
python ${GRAPH_NET_EXTRACT_WORKSPACE}/tools/check_and_count_samples.py >&2
167+
[ $? -ne 0 ] && LOG "[FATAL] Check completeness or redundancy failed!" && exit -1
169168
LOG "[INFO] check_validation run success and no error!"
170169
}
171170

tools/count_sample.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

0 commit comments

Comments
 (0)