Skip to content

Commit b2f86ad

Browse files
committed
Add a completeness check for samples on CI.
1 parent 357aaad commit b2f86ad

File tree

7 files changed

+112
-5
lines changed

7 files changed

+112
-5
lines changed

.github/workflows/Codestyle-Check.yml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,17 @@ jobs:
5555
set +e
5656
bash -x tools/codestyle/pre_commit.sh;EXCODE=$?
5757
exit $EXCODE
58+
59+
- name: Check samples
60+
if: steps.check-bypass.outputs.can-skip != 'true'
61+
run: |
62+
set +e
63+
python3.10 tools/check_samples.py;EXCODE=$?
64+
exit $EXCODE
65+
66+
- name: Count samples
67+
if: steps.check-bypass.outputs.can-skip != 'true'
68+
run: |
69+
set +e
70+
python3.10 tools/count_sample.py;EXCODE=$?
71+
exit $EXCODE

samples/transformers-auto-model/Qwen1.5-0.5B/graph_net.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@
44
"num_nodes_required": 1,
55
"dynamic": false,
66
"model_name": "Qwen/Qwen1.5-0.5B",
7-
"heuristic_tag": "unknown"
8-
}
7+
"heuristic_tag": "nlp"
8+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"framework": "torch",
3+
"num_devices_required": 1,
4+
"num_nodes_required": 1,
5+
"source": "huggingface_hub",
6+
"heuristic_tag": "nlp"
7+
}

samples/transformers-auto-model/joeddav_xlm-roberta-large-xnli/input_meta.py

Whitespace-only changes.

tools/check_samples.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import os
2+
3+
4+
def check_completeness(samples_dir):
    """Report which samples under ``samples_dir`` lack their companion files.

    A sample is any directory found by walking ``samples_dir`` that contains
    ``model.py`` and whose path does not include ``"shape_patches_"``. Each
    sample is expected to also hold ``graph_hash.txt``, ``graph_net.json``,
    ``input_meta.py`` and ``weight_meta.py``.

    Args:
        samples_dir: Root directory to scan recursively.

    Returns:
        True when no sample is missing any of the expected files, else False.
        A three-part summary listing the offending directories is printed.
    """

    def _lacks(sample_dir, filename):
        # True when `filename` does not exist inside `sample_dir`.
        return not os.path.exists(os.path.join(sample_dir, filename))

    no_hash, no_json, no_meta = [], [], []
    for current_dir, _subdirs, filenames in os.walk(samples_dir):
        # Only directories holding a model.py outside of shape patches count.
        if "shape_patches_" in current_dir or "model.py" not in filenames:
            continue
        if _lacks(current_dir, "graph_hash.txt"):
            no_hash.append(current_dir)
        if _lacks(current_dir, "graph_net.json"):
            no_json.append(current_dir)
        # One entry per sample even when both meta files are absent.
        if _lacks(current_dir, "input_meta.py") or _lacks(
            current_dir, "weight_meta.py"
        ):
            no_meta.append(current_dir)

    print(f"Check completeness result for {samples_dir}:")
    report = (
        (f"1. {len(no_hash)} samples missing graph_hash.txt", no_hash),
        (f"2. {len(no_json)} samples missing graph_net.json", no_json),
        (
            f"3. {len(no_meta)} samples missing input_meta.py or weight_meta.py",
            no_meta,
        ),
    )
    for headline, offenders in report:
        print(headline)
        for offender in offenders:
            print(f" - {offender}")
    print()
    return not (no_hash or no_json or no_meta)
40+
41+
42+
def check_redandancy(samples_dir):
    """Group samples under ``samples_dir`` by graph hash and report duplicates.

    Every directory found by walking ``samples_dir`` that contains a
    ``graph_hash.txt`` is treated as a sample. Samples are keyed by the raw,
    unstripped content of that file, so hashes that differ only in trailing
    whitespace are treated as distinct.

    Args:
        samples_dir: Root directory to scan recursively.

    Returns:
        A tuple ``(has_duplicates, graph_hash2model_paths)``:
        ``has_duplicates`` is True when at least one hash is shared by more
        than one sample directory; ``graph_hash2model_paths`` maps each hash
        to the sorted list of directories carrying it.
    """
    graph_hash2model_paths = {}
    for root, _dirs, files in os.walk(samples_dir):
        if "graph_hash.txt" not in files:
            continue
        # Read through a context manager so the handle is closed immediately
        # instead of leaking one open file per sample until garbage collection.
        with open(os.path.join(root, "graph_hash.txt")) as hash_file:
            graph_hash = hash_file.read()
        graph_hash2model_paths.setdefault(graph_hash, []).append(root)

    has_duplicates = False
    print(f"Totally {len(graph_hash2model_paths)} unique samples under {samples_dir}.")
    for graph_hash, model_paths in graph_hash2model_paths.items():
        # Sort in place so the printed report and the returned mapping agree;
        # rebinding the dict value to a sorted copy would leave this loop
        # printing the unsorted os.walk order.
        model_paths.sort()
        if len(model_paths) > 1:
            has_duplicates = True
            print(f"Redundant models detected for grap_hash {graph_hash}:")
            for model_path in model_paths:
                print(f" {model_path}")

    return has_duplicates, graph_hash2model_paths
65+
66+
67+
def main():
    """Run the completeness and redundancy checks on the repository samples.

    The repository root is taken to be the parent of the directory holding
    this script; the ``samples`` and ``paddle_samples`` directories below it
    are checked.

    Raises:
        AssertionError: If any sample is incomplete, or (checked only once
            all samples are complete) if redundant samples are found.
    """
    filename = os.path.abspath(__file__)
    root_dir = os.path.dirname(os.path.dirname(filename))
    samples_dirs = [
        os.path.join(root_dir, samples_dirname)
        for samples_dirname in ["samples", "paddle_samples"]
    ]

    # Build the full list before combining the results: chaining with `and`
    # short-circuits, so after the first failing directory the remaining ones
    # would never be scanned or reported.
    completeness_results = [
        check_completeness(samples_dir) for samples_dir in samples_dirs
    ]
    # Raise explicitly instead of using `assert`, which is stripped under
    # `python -O` and would silently turn this gate into a no-op. The
    # exception type is kept as AssertionError for compatibility.
    if not all(completeness_results):
        raise AssertionError("Please fix the incompleted samples!")

    duplicate_results = [
        check_redandancy(samples_dir)[0] for samples_dir in samples_dirs
    ]
    if any(duplicate_results):
        raise AssertionError("Please remove the redundant samples!")
85+
86+
87+
if __name__ == "__main__":
88+
main()

tools/ci/check_validate.sh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ function prepare_torch_env() {
4141
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 > /dev/null
4242
[ $? -ne 0 ] && LOG "[FATAL] Install torch2.9.0 failed!" && exit -1
4343
else
44-
python ${GRAPH_NET_EXTRACT_WORKSPACE}/tools/count_sample.py
4544
LOG "[INFO] This pull request doesn't change any torch samples, skip the CI."
4645
fi
4746
}
@@ -62,7 +61,6 @@ function prepare_paddle_env() {
6261
[ $? -ne 0 ] && LOG "[FATAL] Install paddlepaddle-develop failed!" && exit -1
6362
python -c "import paddle; print('[PaddlePaddle Commit]', paddle.version.commit)"
6463
else
65-
python ${GRAPH_NET_EXTRACT_WORKSPACE}/tools/count_sample.py
6664
LOG "[INFO] This pull request doesn't change any paddle samples, skip the CI."
6765
fi
6866
}

tools/count_sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
with open(os.path.join(root, "graph_net.json"), "r") as f:
2626
data = json.load(f)
2727
model_name = data.get("model_name", None)
28-
if model_name is not None:
28+
if model_name is not None and model_name != "NO_VALID_MATCH_FOUND":
2929
if model_name not in model_names_set:
3030
model_names_set.add(model_name)
3131
graph_net_count += 1

0 commit comments

Comments
 (0)