@@ -48,13 +48,8 @@ def make_config(
4848 output_dir = "./tmp/naive_decomposer_dir" ,
4949 filter_path = None ,
5050 filter_config = None ,
51- post_extract_process_path = None ,
52- post_extract_process_class_name = None ,
53- post_extract_process_config = None ,
5451 ** kwargs ,
5552 ):
56- if post_extract_process_config is None :
57- post_extract_process_config = {}
5853 for pos in split_positions :
5954 assert isinstance (
6055 pos , int
@@ -66,9 +61,6 @@ def make_config(
6661 "output_dir" : output_dir ,
6762 "filter_path" : filter_path ,
6863 "filter_config" : filter_config if filter_config is not None else {},
69- "post_extract_process_path" : post_extract_process_path ,
70- "post_extract_process_class_name" : post_extract_process_class_name ,
71- "post_extract_process_config" : post_extract_process_config ,
7264 }
7365
7466 def __call__ (self , gm : torch .fx .GraphModule , sample_inputs ):
@@ -111,14 +103,9 @@ def _make_config(
111103 chain_style = False ,
112104 filter_path = None ,
113105 filter_config = None ,
114- post_extract_process_path = None ,
115- post_extract_process_class_name = None ,
116- post_extract_process_config = None ,
117106 model_path_prefix = "" ,
118107 ** kwargs ,
119108 ):
120- if post_extract_process_config is None :
121- post_extract_process_config = {}
122109 for pos in split_positions :
123110 assert isinstance (
124111 pos , int
@@ -130,9 +117,6 @@ def _make_config(
130117 "output_dir" : output_dir ,
131118 "filter_path" : filter_path ,
132119 "filter_config" : filter_config if filter_config is not None else {},
133- "post_extract_process_path" : post_extract_process_path ,
134- "post_extract_process_class_name" : post_extract_process_class_name ,
135- "post_extract_process_config" : post_extract_process_config ,
136120 "model_path_prefix" : model_path_prefix ,
137121 }
138122
@@ -186,9 +170,6 @@ def _make_config(
186170 output_dir = "./tmp/naive_decomposer_dir" ,
187171 filter_path = None ,
188172 filter_config = None ,
189- post_extract_process_path = None ,
190- post_extract_process_class_name = None ,
191- post_extract_process_config = None ,
192173 model_path_prefix = "" ,
193174 ** kwargs ,
194175 ):
@@ -198,18 +179,13 @@ def _make_config(
198179 raise ValueError (
199180 f"split_results_path should be a valid JSON file path, but got { split_results_path = } "
200181 )
201- if post_extract_process_config is None :
202- post_extract_process_config = {}
203182 return {
204183 "split_results_path" : split_results_path ,
205184 "group_head_and_tail" : group_head_and_tail ,
206185 "chain_style" : chain_style ,
207186 "output_dir" : output_dir ,
208187 "filter_path" : filter_path ,
209188 "filter_config" : filter_config if filter_config is not None else {},
210- "post_extract_process_path" : post_extract_process_path ,
211- "post_extract_process_class_name" : post_extract_process_class_name ,
212- "post_extract_process_config" : post_extract_process_config ,
213189 "model_path_prefix" : model_path_prefix ,
214190 }
215191
@@ -274,7 +250,6 @@ def __init__(
274250 ),
275251 )
276252 self .filter = self .make_filter (self .config )
277- self .post_extract_process = self .make_post_extract_process (self .config )
278253
279254 def _get_model_path (self ):
280255 return os .path .join (
@@ -284,33 +259,19 @@ def _get_model_path(self):
284259 )
285260
286261 def forward (self , * args ):
287- logger .warning ("naive decomposer forwarding" )
288262 if not self .extracted :
289263 if self .need_extract (self .submodule , args ):
290264 self .builtin_extractor (self .submodule , args )
291- self ._post_extract_process ()
292265 self .extracted = True
293- logger .warning ("naive decomposer end" )
294266 return self .submodule (* args )
295267
296268 def need_extract (self , gm , sample_inputs ):
297269 if self .filter is None :
298270 return True
299271 return self .filter (gm , sample_inputs )
300272
301- def _post_extract_process (self ):
302- model_path = self ._get_model_path ()
303- return self .post_extract_process (model_path )
304-
305273 def make_filter (self , config ):
306274 if config ["filter_path" ] is None :
307275 return None
308276 module = imp_util .load_module (config ["filter_path" ])
309277 return module .GraphFilter (config ["filter_config" ])
310-
311- def make_post_extract_process (self , config ):
312- if config .get ("post_extract_process_path" ) is None :
313- return lambda * args , ** kwargs : None
314- module = imp_util .load_module (config ["post_extract_process_path" ])
315- cls = getattr (module , config ["post_extract_process_class_name" ])
316- return cls (config ["post_extract_process_config" ])
0 commit comments