@@ -33,7 +33,7 @@ def make_config(
3333 filter_path = None ,
3434 filter_config = None ,
3535 post_extract_process_path = None ,
36- post_extract_process_config = None ,
36+ post_extract_process_class_name = None ,
3737 ):
3838 for pos in split_positions :
3939 assert isinstance (
@@ -47,7 +47,7 @@ def make_config(
4747 "filter_path" : filter_path ,
4848 "filter_config" : filter_config if filter_config is not None else {},
4949 "post_extract_process_path" : post_extract_process_path ,
50- "post_extract_process_config " : post_extract_process_config ,
50+ "post_extract_process_class_name " : post_extract_process_class_name ,
5151 }
5252
5353 def __call__ (self , gm : torch .fx .GraphModule , sample_inputs ):
@@ -75,7 +75,7 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
7575 self .seq_no = seq_no
7676 self .extracted = False
7777 name = f"{ parent_graph_extractor .name } _{ self .seq_no } "
78- self .modelname = name
78+ self .model_name = name
7979 self .builtin_extractor = BuiltinGraphExtractor (
8080 name = name ,
8181 dynamic = False ,
@@ -103,7 +103,7 @@ def need_extract(self, gm, sample_inputs):
103103
104104 def _post_extract_process (self ):
105105 model_path = os .path .join (
106- self .parent_graph_extractor .config ["output_dir" ], self .modelname
106+ self .parent_graph_extractor .config ["output_dir" ], self .model_name
107107 )
108108 return self .post_extract_process (model_path )
109109
@@ -117,4 +117,4 @@ def make_post_extract_process(self, config):
117117 if config ["post_extract_process_path" ] is None :
118118 return None
119119 module = imp_util .load_module (config ["post_extract_process_path" ])
120- return module .PostExtractProcess (config ["post_extract_process_config " ])
120+ return module .PostExtractProcess (config ["post_extract_process_path " ])
0 commit comments