11import os
2+ from pathlib import Path
23import graph_net
34import tempfile
45import shutil
5- from graph_net .torch import constraint_util
6+ from graph_net .torch import fully_fusible_graph_predicator
67from graph_net .torch .fx_graph_module_util import get_torch_module_and_inputs
78from graph_net .torch .fx_graph_parse_util import parse_sole_graph_module
89import logging
@@ -60,20 +61,16 @@ def _get_sub_ranges(self):
6061 ), f"Invalid range generated: start={ start_pos } , end={ end_pos } , max={ self .config ['max_nodes' ]} "
6162 yield start_pos , end_pos
6263
63- def _handle_success (
64- self , temp_dir : str , start_pos : int , end_pos : int , model_name
65- ) -> str :
66- target_name = f" { model_name } _start { start_pos } _end { end_pos } "
64+ def _handle_success (self , temp_dir : str , rel_model_path : str ) -> str :
65+ subdirs = list ( Path ( temp_dir ). iterdir ())
66+ assert len ( subdirs ) == 1
67+ temp_dir = str ( subdirs [ 0 ])
6768 target_path = os .path .join (
6869 self .config ["output_dir" ],
69- target_name ,
70+ rel_model_path ,
7071 )
7172 os .makedirs (target_path , exist_ok = True )
72- # shutil.move(temp_dir, target_path)
73- for item in os .listdir (temp_dir ):
74- source = os .path .join (temp_dir , item )
75- destination = os .path .join (target_path , item )
76- shutil .move (source , destination )
73+ shutil .copytree (temp_dir , target_path , dirs_exist_ok = True )
7774 return target_path
7875
7976 def _build_decompose_config (
@@ -90,7 +87,7 @@ def _build_decompose_config(
9087 "split_positions" : [start_pos , end_pos ],
9188 "group_head_and_tail" : False ,
9289 "post_extract_process_path" : f"{ graph_net_root } /torch/post_extract_process_count_kernels.py" ,
93- "post_extract_process_class_name" : "GraphFullyFusible " ,
90+ "post_extract_process_class_name" : "ThrowExitStatusIfGraphFullyFusible " ,
9491 },
9592 }
9693 return check_fusible_config
@@ -106,14 +103,14 @@ def __call__(self, rel_model_path):
106103 check_fusible_config = self ._build_decompose_config (
107104 temp_dir , start_pos , end_pos , self .config ["model_path_prefix" ]
108105 )
109- predicator = constraint_util . FusibleSubgraphPredicator (
106+ predicator = fully_fusible_graph_predicator . FullyFusibleGraphPredicator (
110107 check_fusible_config
111108 )
109+ logger .warning ("fully_fusible_graph_predicator-begin" )
112110 success = predicator (model_path )
111+ logger .warning ("fully_fusible_graph_predicator-end" )
113112 if success :
114- target_path = self ._handle_success (
115- temp_dir , start_pos , end_pos , os .path .basename (model_path )
116- )
113+ target_path = self ._handle_success (temp_dir , rel_model_path )
117114 print (
118115 f"SUCCESS in finding the biggest fully fusible subgraph. Result saved to: { target_path } "
119116 )
0 commit comments