33from graph_net .torch .decompose_util import convert_to_submodules_graph
44from graph_net .torch .extractor import GraphExtractor as BuiltinGraphExtractor
55import graph_net .imp_util as imp_util
6+ from graph_net .torch .fx_graph_module_util import get_torch_module_and_inputs
7+ from graph_net .torch .fx_graph_parse_util import parse_sole_graph_module
68
79
810class GraphExtractor :
11+ """
12+ Used by graph_net.torch.run_model
13+ """
14+
915 def __init__ (
1016 self ,
1117 config : dict ,
@@ -66,29 +72,109 @@ def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
6672 return rewrited_gm
6773
6874 def get_naive_decomposer_extractor (self , submodule , seq_no ):
69- return NaiveDecomposerExtractor (self , submodule , seq_no )
75+ return NaiveDecomposerExtractorModule (
76+ config = self .config ,
77+ parent_graph_name = self .name ,
78+ submodule = submodule ,
79+ seq_no = seq_no ,
80+ )
81+
82+
83+ class NaiveDecomposerExtractor :
84+ """
85+ Used by graph_net.model_path_handler
86+ """
87+
88+ def __init__ (self , config : dict = None ):
89+ if config is None :
90+ config = {}
91+ self .config = self ._make_config (** config )
92+
93+ def _make_config (
94+ self ,
95+ split_positions = (),
96+ group_head_and_tail = False ,
97+ chain_style = False ,
98+ output_dir = "./tmp/naive_decomposer_dir" ,
99+ filter_path = None ,
100+ filter_config = None ,
101+ post_extract_process_path = None ,
102+ post_extract_process_class_name = None ,
103+ post_extract_process_config = None ,
104+ ** kwargs ,
105+ ):
106+ if post_extract_process_config is None :
107+ post_extract_process_config = {}
108+ for pos in split_positions :
109+ assert isinstance (
110+ pos , int
111+ ), f"split_positions should be list of int, { split_positions = } "
112+ return {
113+ "split_positions" : split_positions ,
114+ "group_head_and_tail" : group_head_and_tail ,
115+ "chain_style" : chain_style ,
116+ "output_dir" : output_dir ,
117+ "filter_path" : filter_path ,
118+ "filter_config" : filter_config if filter_config is not None else {},
119+ "post_extract_process_path" : post_extract_process_path ,
120+ "post_extract_process_class_name" : post_extract_process_class_name ,
121+ "post_extract_process_config" : post_extract_process_config ,
122+ }
123+
124+ def __call__ (self , model_path ):
125+ config = {
126+ k : v
127+ for k , v in self .config .items ()
128+ if k in {"split_positions" , "group_head_and_tail" , "chain_style" }
129+ }
130+ module , inputs = get_torch_module_and_inputs (model_path )
131+ gm = parse_sole_graph_module (module , inputs )
132+ rewrited_gm : torch .fx .GraphModule = convert_to_submodules_graph (
133+ gm ,
134+ submodule_hook = self .get_naive_decomposer_extractor (model_path ),
135+ ** config ,
136+ )
137+ rewrited_gm (* inputs )
138+
139+ def get_naive_decomposer_extractor (self , model_path ):
140+ def fn (submodule , seq_no ):
141+ return NaiveDecomposerExtractorModule (
142+ config = self .config ,
143+ parent_graph_name = os .path .basename (model_path ),
144+ submodule = submodule ,
145+ seq_no = seq_no ,
146+ )
147+
148+ return fn
70149
71150
72- class NaiveDecomposerExtractor (torch .nn .Module ):
73- def __init__ (self , parent_graph_extractor , submodule , seq_no ):
151+ class NaiveDecomposerExtractorModule (torch .nn .Module ):
152+ def __init__ (
153+ self ,
154+ config : dict ,
155+ parent_graph_name : str ,
156+ submodule : torch .nn .Module ,
157+ seq_no : int ,
158+ ):
74159 super ().__init__ ()
75- self .parent_graph_extractor = parent_graph_extractor
160+ self .config = config
76161 self .submodule = submodule
77162 self .seq_no = seq_no
78163 self .extracted = False
79- name = f"{ parent_graph_extractor .name } _{ self .seq_no } "
80- self .model_name = name
164+ if self .seq_no is None :
165+ self .model_name = parent_graph_name
166+ else :
167+ submodule_name = f"{ parent_graph_name } _{ self .seq_no } "
168+ self .model_name = submodule_name
81169 self .builtin_extractor = BuiltinGraphExtractor (
82- name = name ,
170+ name = submodule_name ,
83171 dynamic = False ,
84172 mut_graph_codes = [],
85- placeholder_auto_rename = parent_graph_extractor .placeholder_auto_rename ,
86- workspace_path = self .parent_graph_extractor .config ["output_dir" ],
87- )
88- self .filter = self .make_filter (self .parent_graph_extractor .config )
89- self .post_extract_process = self .make_post_extract_process (
90- self .parent_graph_extractor .config
173+ placeholder_auto_rename = False ,
174+ workspace_path = self .config ["output_dir" ],
91175 )
176+ self .filter = self .make_filter (self .config )
177+ self .post_extract_process = self .make_post_extract_process (self .config )
92178
93179 def forward (self , * args ):
94180 if not self .extracted :
@@ -104,9 +190,7 @@ def need_extract(self, gm, sample_inputs):
104190 return self .filter (gm , sample_inputs )
105191
106192 def _post_extract_process (self ):
107- model_path = os .path .join (
108- self .parent_graph_extractor .config ["output_dir" ], self .model_name
109- )
193+ model_path = os .path .join (self .config ["output_dir" ], self .model_name )
110194 return self .post_extract_process (model_path )
111195
112196 def make_filter (self , config ):
0 commit comments