@@ -43,42 +43,59 @@ def is_single_model_dir(model_dir):
4343
4444
4545def main (args ):
46- assert os .path .isdir (args .model_path )
47- assert os .path .isdir (args .graph_net_samples_path )
48- current_model_graph_hash_pathes = set (
49- graph_hash_path
50- for model_path in get_recursively_model_pathes (args .model_path )
51- for graph_hash_path in [f"{ model_path } /graph_hash.txt" ]
46+ assert os .path .isdir (
47+ args .graph_net_samples_path
48+ ), f"args.graph_net_samples_path ({ args .graph_net_samples_path } ) is not a directory!"
49+ find_redundant = False
50+ graph_hash2graph_net_model_path = {}
51+ for model_path in get_recursively_model_pathes (args .graph_net_samples_path ):
52+ graph_hash_path = f"{ model_path } /graph_hash.txt"
53+ if os .path .isfile (graph_hash_path ):
54+ graph_hash = open (graph_hash_path ).read ()
55+ if graph_hash not in graph_hash2graph_net_model_path .keys ():
56+ graph_hash2graph_net_model_path [graph_hash ] = graph_hash_path
57+ else :
58+ find_redundant = True
59+ print (
60+ f"Redundant models detected: { graph_hash2graph_net_model_path [graph_hash ]} vs { graph_hash_path } "
61+ )
62+ print (
63+ f"Totally { len (graph_hash2graph_net_model_path )} unique samples under { args .graph_net_samples_path } ."
5264 )
53- graph_hash2graph_net_model_path = {
54- graph_hash : graph_hash_path
55- for model_path in get_recursively_model_pathes (args .graph_net_samples_path )
56- for graph_hash_path in [f"{ model_path } /graph_hash.txt" ]
57- if os .path .isfile (graph_hash_path )
58- if graph_hash_path not in current_model_graph_hash_pathes
59- for graph_hash in [open (graph_hash_path ).read ()]
60- }
61- for current_model_graph_hash_path in current_model_graph_hash_pathes :
62- graph_hash = open (current_model_graph_hash_path ).read ()
63- assert (
64- graph_hash not in graph_hash2graph_net_model_path
65- ), f"Redundant models detected. old-model-path:{ current_model_graph_hash_path } , new-model-path:{ graph_hash2graph_net_model_path [graph_hash ]} ."
65+ assert (
66+ not find_redundant
67+ ), f"Redundant models detected under { args .graph_net_samples_path } ."
68+
69+ if args .model_path :
70+ assert os .path .isdir (
71+ args .model_path
72+ ), f"args.model_path { args .model_path } is not a directory!"
73+ current_model_graph_hash_pathes = set (
74+ graph_hash_path
75+ for model_path in get_recursively_model_pathes (args .model_path )
76+ for graph_hash_path in [f"{ model_path } /graph_hash.txt" ]
77+ )
78+ for current_model_graph_hash_path in current_model_graph_hash_pathes :
79+ graph_hash = open (current_model_graph_hash_path ).read ()
80+ assert (
81+ graph_hash not in graph_hash2graph_net_model_path
82+ ), f"Redundant models detected. old-model-path:{ current_model_graph_hash_path } , new-model-path:{ graph_hash2graph_net_model_path [graph_hash ]} ."
6683
6784
6885if __name__ == "__main__" :
6986 parser = argparse .ArgumentParser (description = "Test compiler performance." )
7087 parser .add_argument (
7188 "--model-path" ,
7289 type = str ,
73- required = True ,
90+ required = False ,
7491 help = "Path to model file(s), each subdirectory containing graph_net.json will be regarded as a model" ,
7592 )
7693 parser .add_argument (
7794 "--graph-net-samples-path" ,
7895 type = str ,
79- required = False ,
80- default = "default" ,
96+ required = True ,
8197 help = "Path to GraphNet samples" ,
8298 )
8399 args = parser .parse_args ()
100+ print (args )
84101 main (args = args )
0 commit comments