@@ -46,23 +46,44 @@ def main(args):
4646 assert os .path .isdir (
4747 args .graph_net_samples_path
4848 ), f"args.graph_net_samples_path ({ args .graph_net_samples_path } ) is not a directory!"
49+
50+ current_model_graph_hash_pathes = set ()
51+ if args .model_path :
52+ assert os .path .isdir (
53+ args .model_path
54+ ), f"args.model_path { args .model_path } is not a directory!"
55+ current_model_graph_hash_pathes = set (
56+ graph_hash_path
57+ for model_path in get_recursively_model_pathes (args .model_path )
58+ for graph_hash_path in [f"{ model_path } /graph_hash.txt" ]
59+ )
60+
4961 find_redundant = False
5062 graph_hash2graph_net_model_path = {}
5163 for model_path in get_recursively_model_pathes (args .graph_net_samples_path ):
52- if args .model_path is None or args .model_path != model_path :
53- graph_hash_path = f"{ model_path } /graph_hash.txt"
54- if os .path .isfile (graph_hash_path ):
55- graph_hash = open (graph_hash_path ).read ()
56- if graph_hash not in graph_hash2graph_net_model_path .keys ():
57- graph_hash2graph_net_model_path [graph_hash ] = [graph_hash_path ]
58- else :
59- find_redundant = True
60- graph_hash2graph_net_model_path [graph_hash ].append (graph_hash_path )
64+ graph_hash_path = f"{ model_path } /graph_hash.txt"
65+ if (
66+ os .path .isfile (graph_hash_path )
67+ and graph_hash_path not in current_model_graph_hash_pathes
68+ ):
69+ graph_hash = open (graph_hash_path ).read ()
70+ if graph_hash not in graph_hash2graph_net_model_path .keys ():
71+ graph_hash2graph_net_model_path [graph_hash ] = [graph_hash_path ]
72+ else :
73+ find_redundant = True
74+ graph_hash2graph_net_model_path [graph_hash ].append (graph_hash_path )
6175 print (
6276 f"Totally { len (graph_hash2graph_net_model_path )} unique samples under { args .graph_net_samples_path } ."
6377 )
6478
6579 if not args .model_path :
80+ # Check whether the specified model is redundant.
81+ for current_model_graph_hash_path in current_model_graph_hash_pathes :
82+ graph_hash = open (current_model_graph_hash_path ).read ()
83+ assert (
84+ graph_hash not in graph_hash2graph_net_model_path
85+ ), f"Redundant models detected.\n \t graph_hash:{ graph_hash } , newly-added-model-path:{ current_model_graph_hash_path } , existing-model-path:{ graph_hash2graph_net_model_path [graph_hash ]} ."
86+ else :
6687 # Check whether there are redundant samples under samples directory.
6788 for graph_hash , graph_paths in graph_hash2graph_net_model_path .items ():
6889 if len (graph_paths ) > 1 :
@@ -72,21 +93,6 @@ def main(args):
7293 assert (
7394 not find_redundant
7495 ), f"Redundant models detected under { args .graph_net_samples_path } ."
75- else :
76- # Check whether the specified model is redundant.
77- assert os .path .isdir (
78- args .model_path
79- ), f"args.model_path { args .model_path } is not a directory!"
80- current_model_graph_hash_pathes = set (
81- graph_hash_path
82- for model_path in get_recursively_model_pathes (args .model_path )
83- for graph_hash_path in [f"{ model_path } /graph_hash.txt" ]
84- )
85- for current_model_graph_hash_path in current_model_graph_hash_pathes :
86- graph_hash = open (current_model_graph_hash_path ).read ()
87- assert (
88- graph_hash not in graph_hash2graph_net_model_path
89- ), f"Redundant models detected. old-model-path:{ current_model_graph_hash_path } , new-model-path:{ graph_hash2graph_net_model_path [graph_hash ]} ."
9096
9197
9298if __name__ == "__main__" :
0 commit comments