11import os
22import torch
33import shutil
4- import inspect
54import tempfile
65
76from graph_net .torch .fx_graph_module_util import get_torch_module_and_inputs
@@ -92,16 +91,18 @@ def __call__(self, rel_model_path):
9291 with cuda_gc (enabled = self .config ["release_gpu_memory" ]):
9392 module , inputs = get_torch_module_and_inputs (src_model_path )
9493 gm = parse_sole_graph_module (module , inputs )
95- gm = self .rename_graph_variables (gm , inputs , src_model_path )
94+ gm , rename_map = self .rename_graph_variables (gm , inputs , src_model_path )
9695 del module , inputs
9796
9897 Path (dst_model_path ).parent .mkdir (parents = True , exist_ok = True )
9998 with tempfile .TemporaryDirectory (prefix = "graph_variable_renamer_" ) as temp_dir :
10099 temp_model_path = os .path .join (temp_dir , os .path .basename (dst_model_path ))
101100 shutil .copytree (src_model_path , temp_model_path , dirs_exist_ok = True )
102101 self ._update_model_py_file (gm , temp_model_path )
103- self ._update_weight_meta_py_file (src_model_path , temp_model_path )
104- self ._update_input_meta_py_file (src_model_path , temp_model_path )
102+ self ._update_weight_meta_py_file (
103+ src_model_path , temp_model_path , rename_map
104+ )
105+ self ._update_input_meta_py_file (src_model_path , temp_model_path , rename_map )
105106 print ("Try to run renamed model..." )
106107 self ._try_run (temp_model_path )
107108 shutil .copytree (temp_model_path , dst_model_path )
@@ -117,99 +118,99 @@ def _update_model_py_file(self, graph_module, model_path):
117118 file_hash = get_sha256_hash (py_code )
118119 (Path (model_path ) / "graph_hash.txt" ).write_text (file_hash )
119120
120- def _update_weight_meta_py_file (self , src_model_path , dst_model_path ):
121- old_name_to_new_name = self ._get_original_name_to_new_name (
122- src_model_path , dst_model_path
123- )
121+ def _update_weight_meta_py_file (self , src_model_path , dst_model_path , rename_map ):
124122 tensor_metas = TensorMeta .unserialize_from_py_file (
125123 os .path .join (src_model_path , "weight_meta.py" ),
126124 )
127125 for weight_meta in tensor_metas :
128- assert weight_meta .name in old_name_to_new_name
126+ meta_name = self ._find_name_in_rename_map (weight_meta .name , rename_map )
127+ assert (
128+ meta_name is not None
129+ ), f"{ weight_meta .name } is not found in rename_map"
129130 if weight_meta .original_name is None :
130131 weight_meta .original_name = weight_meta .name
131- weight_meta .name = old_name_to_new_name [weight_meta .name ]
132+ weight_meta .name = rename_map [meta_name ]
133+
132134 py_code = "\n \n " .join (
133135 [weight_meta .serialize_to_py_str () for weight_meta in tensor_metas ]
134136 )
135137 (Path (dst_model_path ) / "weight_meta.py" ).write_text (py_code )
136138
137- def _update_input_meta_py_file (self , src_model_path , dst_model_path ):
138- old_name_to_new_name = self ._get_original_name_to_new_name (
139- src_model_path , dst_model_path
140- )
139+ def _update_input_meta_py_file (self , src_model_path , dst_model_path , rename_map ):
141140 tensor_metas = TensorMeta .unserialize_from_py_file (
142141 os .path .join (src_model_path , "input_meta.py" ),
143142 )
144143 for input_meta in tensor_metas :
145- assert input_meta .name in old_name_to_new_name
144+ meta_name = self ._find_name_in_rename_map (input_meta .name , rename_map )
145+ assert (
146+ meta_name is not None
147+ ), f"{ input_meta .name } is not found in rename_map"
146148 if input_meta .original_name is None :
147149 input_meta .original_name = input_meta .name
148- input_meta .name = old_name_to_new_name [input_meta .name ]
150+ input_meta .name = rename_map [meta_name ]
151+
149152 py_code = "\n \n " .join (
150153 [input_meta .serialize_to_py_str () for input_meta in tensor_metas ]
151154 )
152155 (Path (dst_model_path ) / "input_meta.py" ).write_text (py_code )
153156
154- def _get_original_name_to_new_name (self , src_model_path , dst_model_path ):
155- src_model = self ._get_model (src_model_path )
156- dst_model = self ._get_model (dst_model_path )
157- old_name_and_new_name_pairs = zip (
158- self ._get_input_names_from_signature (src_model ),
159- self ._get_input_names_from_signature (dst_model ),
160- strict = True ,
161- )
162- return {
163- old_name : new_name for old_name , new_name in old_name_and_new_name_pairs
164- }
157+ def _find_name_in_rename_map (self , raw_name , rename_map ):
158+ if raw_name in rename_map :
159+ return raw_name
160+ # s1 -> s1_
161+ elif (raw_name + "_" ) in rename_map :
162+ return raw_name + "_"
163+ else :
164+ return None
165165
166166 def _get_model (self , model_path ):
167167 py_module = load_module (os .path .join (model_path , "model.py" ))
168168 GraphModule = getattr (py_module , "GraphModule" )
169169 GraphModule .__graph_net_file_path__ = py_module .__graph_net_file_path__
170170 return GraphModule ()
171171
172- def _get_input_names_from_signature (self , module ):
173- return inspect .signature (module .forward ).parameters
174-
175172 def rename_graph_variables (
176173 self , gm : torch .fx .GraphModule , sample_inputs , model_path
177174 ):
178175 counters = {"in" : 0 , "w" : 0 , "tmp" : 0 }
176+ rename_map = {}
179177 # graph may not have input, only contain weights
180178 arg_iter = iter (sample_inputs ) if sample_inputs else iter ([])
181179 for node in gm .graph .nodes :
182- self ._process_single_node (node , arg_iter , counters , model_path )
180+ self ._process_single_node (node , arg_iter , counters , model_path , rename_map )
183181 gm .graph .lint ()
184182 gm .recompile ()
185- return gm
183+ return gm , rename_map
186184
187- def _process_single_node (self , node , arg_iter , counters , model_path ):
185+ def _process_single_node (self , node , arg_iter , counters , model_path , rename_map ):
188186 if "original_name" not in node .meta :
189187 node .meta ["original_name" ] = node .name
190188 if node .op == "placeholder" :
191- self ._handle_placeholder (node , arg_iter , counters , model_path )
189+ self ._handle_placeholder (node , arg_iter , counters , model_path , rename_map )
192190 elif node .op == "get_attr" :
193- self ._apply_rename (node , "w" , counters )
191+ self ._apply_rename (node , "w" , counters , rename_map )
194192 elif node .op != "output" :
195- self ._apply_rename (node , "tmp" , counters )
193+ self ._apply_rename (node , "tmp" , counters , rename_map )
196194 else :
197195 # Do nothing
198196 pass
199197
200- def _handle_placeholder (self , node , arg_iter , counters , model_path ):
198+ def _handle_placeholder (self , node , arg_iter , counters , model_path , rename_map ):
201199 real_arg = next (arg_iter , None )
202200 is_weight = self ._is_weight_node (node , real_arg , model_path )
203201 prefix = "w" if is_weight else "in"
204- self ._apply_rename (node , prefix , counters , update_target = True )
202+ self ._apply_rename (node , prefix , counters , rename_map , update_target = True )
205203
206- def _apply_rename (self , node , prefix , counters , update_target = False ):
204+ def _apply_rename (self , node , prefix , counters , rename_map , update_target = False ):
205+ old_name = node .name
207206 new_name = f"{ prefix } _{ counters [prefix ]} "
208207 counters [prefix ] += 1
209208 node .name = new_name
210209 if update_target :
211210 node .target = new_name
212211
212+ rename_map [old_name ] = new_name
213+
213214 def _is_weight_node (self , node , real_arg , model_path ):
214215 is_not_data_input = not self .data_input_predicator (model_path , node .name )
215216 is_parameter_type = (
0 commit comments