11import os
22import torch
33import copy
4- import random
54from graph_net .torch .decompose_util import convert_to_submodules_graph
65from graph_net .torch .extractor import GraphExtractor as BuiltinGraphExtractor
76import graph_net .imp_util as imp_util
87
98
10- def generate_split_positions (max_pos = 32 , max_splits = 8 ):
11- num_splits = random .randint (3 , max_splits )
12- positions = random .sample (range (1 , max_pos ), num_splits )
13- positions .sort ()
14- return positions
15-
16-
179class GraphExtractor :
1810 def __init__ (
1911 self ,
@@ -30,6 +22,7 @@ def __init__(
3022 self .placeholder_auto_rename = placeholder_auto_rename
3123 self .config = self .make_config (** config )
3224 self .last_post_process_result = False
25+ self .decompose_finished = False
3326
3427 def make_config (
3528 self ,
@@ -41,6 +34,7 @@ def make_config(
4134 filter_config = None ,
4235 post_extract_process_path = None ,
4336 post_extract_process_class_name = None ,
37+ max_step = 8 ,
4438 ):
4539 for pos in split_positions :
4640 assert isinstance (
@@ -55,45 +49,46 @@ def make_config(
5549 "filter_config" : filter_config if filter_config is not None else {},
5650 "post_extract_process_path" : post_extract_process_path ,
5751 "post_extract_process_class_name" : post_extract_process_class_name ,
52+ "max_step" : max_step ,
5853 }
5954
6055 def __call__ (self , gm : torch .fx .GraphModule , sample_inputs ):
61- max_retries = 20
62- for i in range (max_retries ):
63- print (f"--- Attempt { i + 1 } ---" )
64- self .last_post_process_result = False
65- config = {
66- k : v
67- for k , v in self .config .items ()
68- if k in {"split_positions" , "group_head_and_tail" , "chain_style" }
69- }
70- print (f"Current Config: { config ['split_positions' ]} " )
71-
72- gm_to_process = copy .deepcopy (gm )
73-
74- rewrited_gm = convert_to_submodules_graph (
75- gm_to_process ,
76- submodule_hook = self .get_naive_decomposer_extractor ,
77- ** config ,
78- )
79-
80- try :
81- rewrited_gm (* sample_inputs )
82- except Exception as e :
83- print (f"Run failed: { e } " )
56+ for i in range (self .config ["max_step" ], - 1 , - 1 ):
57+ start_pos = 0
58+ for start_pos in range (32 - i ):
59+ end_pos = start_pos + i
60+ self .config ["split_positions" ] = [start_pos , end_pos ]
61+ torch ._dynamo .reset ()
8462 self .last_post_process_result = False
63+ config = {
64+ k : v
65+ for k , v in self .config .items ()
66+ if k in {"split_positions" , "group_head_and_tail" , "chain_style" }
67+ }
68+ print (f"Current Config: { config ['split_positions' ]} " )
69+ gm_to_process = copy .deepcopy (gm )
70+ rewrited_gm = convert_to_submodules_graph (
71+ gm_to_process ,
72+ submodule_hook = self .get_naive_decomposer_extractor ,
73+ ** config ,
74+ )
75+ try :
76+ rewrited_gm (* sample_inputs )
77+ except Exception as e :
78+ print (f"Run failed: { e } " )
79+ self .last_post_process_result = False
80+ if self .last_post_process_result and self .decompose_finished :
81+ print ("Success! Subgraph is fully fusionable." )
82+ break
8583 if self .last_post_process_result :
86- print ("Success! Subgraph is fully fusionable." )
8784 break
88- else :
89- print ("Failed. Generating new split positions..." )
90- self .config ["split_positions" ] = generate_split_positions ()
91-
92- if i == max_retries - 1 :
93- print ("failed to find a fully fusionable subgraph" )
9485 return rewrited_gm
9586
96- def get_naive_decomposer_extractor (self , submodule , seq_no ):
87+ def get_naive_decomposer_extractor (
88+ self ,
89+ submodule ,
90+ seq_no ,
91+ ):
9792 return NaiveDecomposerExtractor (self , submodule , seq_no )
9893
9994
@@ -119,14 +114,16 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
119114 )
120115
121116 def forward (self , * args ):
122- print ("forward" )
123117 if not self .extracted :
124118 if self .need_extract (self .submodule , args ):
125119 self .builtin_extractor (self .submodule , args )
126- success = self ._post_extract_process ()
127- if success :
128- print (f"Submodule { self .seq_no } failed fusion check." )
120+ if self ._post_extract_process () and self .seq_no == 1 :
129121 self .parent_graph_extractor .last_post_process_result = True
122+ print ("biggest fully fusionable subgraph found!!" , self .model_name )
123+ if self .seq_no == len (
124+ self .parent_graph_extractor .config ["split_positions" ]
125+ ):
126+ self .parent_graph_extractor .decompose_finished = True
130127 self .extracted = True
131128 return self .submodule (* args )
132129
0 commit comments