@@ -43,42 +43,62 @@ def is_single_model_dir(model_dir):
4343
4444
4545def main (args ):
46- assert os .path .isdir (args .model_path )
47- assert os .path .isdir (args .graph_net_samples_path )
48- current_model_graph_hash_pathes = set (
49- graph_hash_path
50- for model_path in get_recursively_model_pathes (args .model_path )
51- for graph_hash_path in [f"{ model_path } /graph_hash.txt" ]
46+ assert os .path .isdir (
47+ args .graph_net_samples_path
48+ ), f"args.graph_net_samples_path ({ args .graph_net_samples_path } ) is not a directory!"
49+ find_redundant = False
50+ graph_hash2graph_net_model_path = {}
51+ for model_path in get_recursively_model_pathes (args .graph_net_samples_path ):
52+ graph_hash_path = f"{ model_path } /graph_hash.txt"
53+ if os .path .isfile (graph_hash_path ):
54+ graph_hash = open (graph_hash_path ).read ()
55+ if graph_hash not in graph_hash2graph_net_model_path .keys ():
56+ graph_hash2graph_net_model_path [graph_hash ] = [graph_hash_path ]
57+ else :
58+ find_redundant = True
59+ graph_hash2graph_net_model_path [graph_hash ].append (graph_hash_path )
60+ print (
61+ f"Totally { len (graph_hash2graph_net_model_path )} unique samples under { args .graph_net_samples_path } ."
5262 )
53- graph_hash2graph_net_model_path = {
54- graph_hash : graph_hash_path
55- for model_path in get_recursively_model_pathes (args .graph_net_samples_path )
56- for graph_hash_path in [f"{ model_path } /graph_hash.txt" ]
57- if os .path .isfile (graph_hash_path )
58- if graph_hash_path not in current_model_graph_hash_pathes
59- for graph_hash in [open (graph_hash_path ).read ()]
60- }
61- for current_model_graph_hash_path in current_model_graph_hash_pathes :
62- graph_hash = open (current_model_graph_hash_path ).read ()
63- assert (
64- graph_hash not in graph_hash2graph_net_model_path
65- ), f"Redundant models detected. old-model-path:{ current_model_graph_hash_path } , new-model-path:{ graph_hash2graph_net_model_path [graph_hash ]} ."
63+ for graph_hash , graph_paths in graph_hash2graph_net_model_path .items ():
64+ if len (graph_paths ) > 1 :
65+ print (f"Redundant models detected for grap_hash { graph_hash } :" )
66+ for model_path in graph_paths :
67+ print (f" { model_path } " )
68+ assert (
69+ not find_redundant
70+ ), f"Redundant models detected under { args .graph_net_samples_path } ."
71+
72+ if args .model_path :
73+ assert os .path .isdir (
74+ args .model_path
75+ ), f"args.model_path { args .model_path } is not a directory!"
76+ current_model_graph_hash_pathes = set (
77+ graph_hash_path
78+ for model_path in get_recursively_model_pathes (args .model_path )
79+ for graph_hash_path in [f"{ model_path } /graph_hash.txt" ]
80+ )
81+ for current_model_graph_hash_path in current_model_graph_hash_pathes :
82+ graph_hash = open (current_model_graph_hash_path ).read ()
83+ assert (
84+ graph_hash not in graph_hash2graph_net_model_path
85+ ), f"Redundant models detected. old-model-path:{ current_model_graph_hash_path } , new-model-path:{ graph_hash2graph_net_model_path [graph_hash ]} ."
6686
6787
6888if __name__ == "__main__" :
6989 parser = argparse .ArgumentParser (description = "Test compiler performance." )
7090 parser .add_argument (
7191 "--model-path" ,
7292 type = str ,
73- required = True ,
93+ required = False ,
7494 help = "Path to model file(s), each subdirectory containing graph_net.json will be regarded as a model" ,
7595 )
7696 parser .add_argument (
7797 "--graph-net-samples-path" ,
7898 type = str ,
79- required = False ,
80- default = "default" ,
99+ required = True ,
81100 help = "Path to GraphNet samples" ,
82101 )
83102 args = parser .parse_args ()
103+ print (args )
84104 main (args = args )
0 commit comments