@@ -60,14 +60,20 @@ def _get_sub_ranges(self):
6060 ), f"Invalid range generated: start={ start_pos } , end={ end_pos } , max={ self .config ['max_nodes' ]} "
6161 yield start_pos , end_pos
6262
63- def _handle_success (self , temp_dir : str , start_pos : int , end_pos : int ) -> str :
64- target_name = f"_start{ start_pos } _end{ end_pos } "
63+ def _handle_success (
64+ self , temp_dir : str , start_pos : int , end_pos : int , model_name
65+ ) -> str :
66+ target_name = f"{ model_name } _start{ start_pos } _end{ end_pos } "
6567 target_path = os .path .join (
6668 self .config ["output_dir" ],
6769 target_name ,
6870 )
6971 os .makedirs (target_path , exist_ok = True )
70- shutil .move (temp_dir , target_path )
72+ # shutil.move(temp_dir, target_path)
73+ for item in os .listdir (temp_dir ):
74+ source = os .path .join (temp_dir , item )
75+ destination = os .path .join (target_path , item )
76+ shutil .move (source , destination )
7177 return target_path
7278
7379 def _build_decompose_config (
@@ -76,7 +82,7 @@ def _build_decompose_config(
7682 graph_net_root = os .path .dirname (graph_net .__file__ )
7783
7884 check_fusible_config = {
79- "handler_path" : f"{ graph_net_root } /torch/naive_graph_decomposer .py" ,
85+ "handler_path" : f"{ graph_net_root } /torch/graph_decomposer .py" ,
8086 "handler_class_name" : "NaiveDecomposerExtractor" ,
8187 "handler_config" : {
8288 "model_path_prefix" : model_path_prefix ,
@@ -105,7 +111,9 @@ def __call__(self, rel_model_path):
105111 )
106112 success = predicator (model_path )
107113 if success :
108- target_path = self ._handle_success (temp_dir , start_pos , end_pos )
114+ target_path = self ._handle_success (
115+ temp_dir , start_pos , end_pos , os .path .basename (model_path )
116+ )
109117 print (
110118 f"SUCCESS in finding the biggest fully fusible subgraph. Result saved to: { target_path } "
111119 )
0 commit comments