11import os
2+ import json
23
34
45def check_completeness (samples_dir ):
@@ -31,7 +32,7 @@ def check_completeness(samples_dir):
3132 )
3233 for model_path in samples_missing_meta :
3334 print (f" - { model_path } " )
34- print ()
35+
3536 return (
3637 len (samples_missing_hash ) == 0
3738 and len (samples_missing_json ) == 0
@@ -52,37 +53,77 @@ def check_redandancy(samples_dir):
5253 graph_hash2model_paths [graph_hash ].append (model_path )
5354
5455 has_duplicates = False
55- print (f"Totally { len (graph_hash2model_paths )} unique samples under { samples_dir } ." )
56+ print (f"Totally { len (graph_hash2model_paths )} unique graphs under { samples_dir } ." )
5657 for graph_hash , model_paths in graph_hash2model_paths .items ():
5758 graph_hash2model_paths [graph_hash ] = sorted (model_paths )
5859 if len (model_paths ) > 1 :
5960 has_duplicates = True
6061 print (f"Redundant models detected for grap_hash { graph_hash } :" )
6162 for model_path in model_paths :
6263 print (f" { model_path } " )
63-
6464 return has_duplicates , graph_hash2model_paths
6565
6666
def count_samples(samples_dir, framework):
    """Print how many samples ``samples_dir`` holds, in total and per source.

    Each immediate sub-directory of ``samples_dir`` is treated as one model
    source. Every ``graph_net.json`` found beneath a source counts as one
    sample, except that samples sharing a usable ``model_name`` are counted
    only once across all sources. A sample whose ``model_name`` is missing or
    equals ``"NO_VALID_MATCH_FOUND"`` cannot be de-duplicated and is always
    counted.

    Args:
        samples_dir: Directory whose sub-directories are the model sources.
        framework: Label inserted into the printed summary line.

    Returns:
        None. The summary is written to stdout.
    """
    unmatched_name = "NO_VALID_MATCH_FOUND"
    total = 0
    per_source = {}
    seen_model_names = set()

    # Sort the sources: a duplicated model name is credited to the first source
    # visited, so raw os.listdir() order would make the per-source numbers
    # depend on the filesystem.
    for source in sorted(os.listdir(samples_dir)):
        source_dir = os.path.join(samples_dir, source)
        if not os.path.isdir(source_dir):
            continue

        per_source[source] = 0
        for root, _dirs, files in os.walk(source_dir):
            if "graph_net.json" not in files:
                continue

            # JSON text is UTF-8; do not rely on the locale default encoding.
            meta_path = os.path.join(root, "graph_net.json")
            with open(meta_path, "r", encoding="utf-8") as f:
                model_name = json.load(f).get("model_name", None)

            if model_name is not None and model_name != unmatched_name:
                if model_name in seen_model_names:
                    # Same model already counted under an earlier sample.
                    continue
                seen_model_names.add(model_name)

            total += 1
            per_source[source] += 1

    print(f"Number of {framework} samples: {total}")
    for name, number in per_source.items():
        print(f"- {name:24}: {number}")
    print()
95+
96+
def main():
    """Validate every sample directory, then print per-framework counts.

    The sample directories are resolved relative to the parent of the
    directory containing this file. The steps run in order:

    1. ``check_completeness`` on every directory.
    2. ``check_redandancy`` on every directory.
    3. ``count_samples`` on every directory.

    Raises:
        AssertionError: If any directory is reported incomplete, or if any
            directory is reported to contain redundant samples.
    """
    filename = os.path.abspath(__file__)
    root_dir = os.path.dirname(os.path.dirname(filename))

    framework2dirname = {
        "torch": "samples",
        "paddle": "paddle_samples",
    }
    framework2samples_dir = {
        framework: os.path.join(root_dir, dirname)
        for framework, dirname in framework2dirname.items()
    }

    all_samples_completed = True
    for samples_dir in framework2samples_dir.values():
        # Run the check before combining the flags: `flag and check(...)`
        # would short-circuit and skip the remaining directories after the
        # first failure, hiding their problems from the report.
        completed = check_completeness(samples_dir)
        all_samples_completed = all_samples_completed and completed
        print()
    # Raise explicitly rather than `assert`, which is stripped under
    # `python -O` and would silently disable this gate.
    if not all_samples_completed:
        raise AssertionError("Please fix the incompleted samples!")

    all_samples_has_duplicates = False
    for samples_dir in framework2samples_dir.values():
        has_duplicates, _ = check_redandancy(samples_dir)
        all_samples_has_duplicates = all_samples_has_duplicates or has_duplicates
        print()
    if all_samples_has_duplicates:
        raise AssertionError("Please remove the redundant samples!")

    for framework, samples_dir in framework2samples_dir.items():
        count_samples(samples_dir, framework)
86127
# Run the checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
0 commit comments