Skip to content

Commit 7fea6e2

Browse files
committed
Merge count_sample into check_and_count_samples.
1 parent b2f86ad commit 7fea6e2

File tree

4 files changed

+47
-60
lines changed

4 files changed

+47
-60
lines changed

.github/workflows/Codestyle-Check.yml

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,4 @@ jobs:
5454
run: |
5555
set +e
5656
bash -x tools/codestyle/pre_commit.sh;EXCODE=$?
57-
exit $EXCODE
58-
59-
- name: Check samples
60-
if: steps.check-bypass.outputs.can-skip != 'true'
61-
run: |
62-
set +e
63-
python3.10 tools/check_samples.py;EXCODE=$?
64-
exit $EXCODE
65-
66-
- name: Count samples
67-
if: steps.check-bypass.outputs.can-skip != 'true'
68-
run: |
69-
set +e
70-
python3.10 tools/count_sample.py;EXCODE=$?
7157
exit $EXCODE
Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import json
23

34

45
def check_completeness(samples_dir):
@@ -31,7 +32,7 @@ def check_completeness(samples_dir):
3132
)
3233
for model_path in samples_missing_meta:
3334
print(f" - {model_path}")
34-
print()
35+
3536
return (
3637
len(samples_missing_hash) == 0
3738
and len(samples_missing_json) == 0
@@ -52,37 +53,77 @@ def check_redandancy(samples_dir):
5253
graph_hash2model_paths[graph_hash].append(model_path)
5354

5455
has_duplicates = False
55-
print(f"Totally {len(graph_hash2model_paths)} unique samples under {samples_dir}.")
56+
print(f"Totally {len(graph_hash2model_paths)} unique graphs under {samples_dir}.")
5657
for graph_hash, model_paths in graph_hash2model_paths.items():
5758
graph_hash2model_paths[graph_hash] = sorted(model_paths)
5859
if len(model_paths) > 1:
5960
has_duplicates = True
6061
print(f"Redundant models detected for grap_hash {graph_hash}:")
6162
for model_path in model_paths:
6263
print(f" {model_path}")
63-
6464
return has_duplicates, graph_hash2model_paths
6565

6666

67+
def count_samples(samples_dir, framework):
    """Print how many samples live under ``samples_dir``, in total and per source.

    Every immediate sub-directory of ``samples_dir`` is treated as one model
    source and is walked recursively for ``graph_net.json`` files. A sample
    whose ``model_name`` is a real name is counted only the first time that
    name is seen (across all sources); a sample whose ``model_name`` is
    missing, ``null`` or ``"NO_VALID_MATCH_FOUND"`` is always counted.

    Directories are visited in sorted order so that the source credited for a
    model name shared by several samples, and the order of the printed
    report, do not depend on the file system's listing order.

    Args:
        samples_dir: Directory whose sub-directories are the model sources.
        framework: Label interpolated into the summary line.
    """
    total_count = 0
    count_by_source = {}
    seen_model_names = set()

    for source in sorted(os.listdir(samples_dir)):
        source_dir = os.path.join(samples_dir, source)
        if not os.path.isdir(source_dir):
            continue
        count_by_source[source] = 0

        for root, dirs, files in os.walk(source_dir):
            # Sorting in place makes os.walk descend in a stable order.
            dirs.sort()
            if "graph_net.json" not in files:
                continue

            json_path = os.path.join(root, "graph_net.json")
            # Only the load needs the file handle; close it before counting.
            with open(json_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            model_name = data.get("model_name", None)
            # NOTE(review): the counting rule below was reconstructed from a
            # flattened diff (unnamed samples always count, named ones count
            # once) -- confirm against the original indentation.
            if model_name is not None and model_name != "NO_VALID_MATCH_FOUND":
                if model_name in seen_model_names:
                    continue
                seen_model_names.add(model_name)
            total_count += 1
            count_by_source[source] += 1

    print(f"Number of {framework} samples: {total_count}")
    for name, number in count_by_source.items():
        print(f"- {name:24}: {number}")
    print()
95+
96+
6797
def main():
    """Check every sample directory for completeness and redundancy, then print counts.

    The sample directories are resolved relative to the parent of the
    directory containing this file. All directories are checked for
    completeness before the first gate fires, so one run reports the
    problems of every directory rather than only the first failing one.

    Raises:
        AssertionError: If any sample directory is incomplete, or if any
            sample directory contains redundant samples.
    """
    filename = os.path.abspath(__file__)
    root_dir = os.path.dirname(os.path.dirname(filename))

    framework2dirname = {
        "torch": "samples",
        "paddle": "paddle_samples",
    }

    all_samples_completed = True
    for samples_dirname in framework2dirname.values():
        samples_dir = os.path.join(root_dir, samples_dirname)
        # Call first, combine second: `flag and check(...)` would stop calling
        # the check (and printing its report) after the first failure.
        completed = check_completeness(samples_dir)
        all_samples_completed = all_samples_completed and completed
        print()
    # Raised explicitly instead of `assert` so the gate survives `python -O`.
    if not all_samples_completed:
        raise AssertionError("Please fix the incompleted samples!")

    all_samples_has_duplicates = False
    for samples_dirname in framework2dirname.values():
        samples_dir = os.path.join(root_dir, samples_dirname)
        # The hash-to-paths mapping is not needed here, only the flag.
        has_duplicates, _ = check_redandancy(samples_dir)
        all_samples_has_duplicates = all_samples_has_duplicates or has_duplicates
        print()
    if all_samples_has_duplicates:
        raise AssertionError("Please remove the redundant samples!")

    for framework, samples_dirname in framework2dirname.items():
        samples_dir = os.path.join(root_dir, samples_dirname)
        count_samples(samples_dir, framework)
126+
86127

87128
if __name__ == "__main__":
88129
main()

tools/ci/check_validate.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ function main() {
163163
check_validation_info=$(check_paddle_validation)
164164
check_validation_code=$?
165165
summary_problems $check_validation_code "$check_validation_info"
166-
python ${GRAPH_NET_EXTRACT_WORKSPACE}/tools/count_sample.py
166+
python ${GRAPH_NET_EXTRACT_WORKSPACE}/tools/check_and_count_samples.py
167167
LOG "[INFO] check_validation run success and no error!"
168168
}
169169

tools/count_sample.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

0 commit comments

Comments
 (0)