@@ -31,8 +31,11 @@ def make_config(
3131 filter_config = None ,
3232 post_extract_process_path = None ,
3333 post_extract_process_class_name = None ,
34+ post_extract_process_config = None ,
3435 ** kwargs ,
3536 ):
37+ if post_extract_process_config is None :
38+ post_extract_process_config = {}
3639 for pos in split_positions :
3740 assert isinstance (
3841 pos , int
@@ -46,6 +49,7 @@ def make_config(
4649 "filter_config" : filter_config if filter_config is not None else {},
4750 "post_extract_process_path" : post_extract_process_path ,
4851 "post_extract_process_class_name" : post_extract_process_class_name ,
52+ "post_extract_process_config" : post_extract_process_config ,
4953 }
5054
5155 def __call__ (self , gm : torch .fx .GraphModule , sample_inputs ):
@@ -112,8 +116,8 @@ def make_filter(self, config):
112116 return module .GraphFilter (config ["filter_config" ])
113117
114118 def make_post_extract_process (self , config ):
115- if config [ "post_extract_process_path" ] is None :
116- return None
119+ if config . get ( "post_extract_process_path" ) is None :
120+ return lambda * args , ** kwargs : None
117121 module = imp_util .load_module (config ["post_extract_process_path" ])
118122 cls = getattr (module , config ["post_extract_process_class_name" ])
119- return cls (config ["post_extract_process_path " ])
123+ return cls (config ["post_extract_process_config " ])
0 commit comments