11import os
22import torch
33import shutil
4- import inspect
54import tempfile
65from graph_net .torch .fx_graph_module_util import get_torch_module_and_inputs
76from graph_net .torch .fx_graph_parse_util import parse_sole_graph_module
@@ -81,7 +80,7 @@ def __call__(self, rel_model_path):
8180 src_model_path = os .path .join (self .config ["model_path_prefix" ], rel_model_path )
8281 module , inputs = get_torch_module_and_inputs (src_model_path )
8382 gm = parse_sole_graph_module (module , inputs )
84- gm = self .rename_graph_variables (gm , inputs , src_model_path )
83+ gm , rename_map = self .rename_graph_variables (gm , inputs , src_model_path )
8584 dst_model_path = os .path .realpath (
8685 os .path .join (self .config ["output_dir" ], rel_model_path )
8786 )
@@ -94,8 +93,10 @@ def __call__(self, rel_model_path):
9493 temp_model_path = os .path .join (temp_dir , os .path .basename (dst_model_path ))
9594 shutil .copytree (src_model_path , temp_model_path , dirs_exist_ok = True )
9695 self ._update_model_py_file (gm , temp_model_path )
97- self ._update_weight_meta_py_file (src_model_path , temp_model_path )
98- self ._update_input_meta_py_file (src_model_path , temp_model_path )
96+ self ._update_weight_meta_py_file (
97+ src_model_path , temp_model_path , rename_map
98+ )
99+ self ._update_input_meta_py_file (src_model_path , temp_model_path , rename_map )
99100 print ("Try to run renamed model..." )
100101 self ._try_run (temp_model_path )
101102 shutil .copytree (temp_model_path , dst_model_path )
@@ -111,99 +112,99 @@ def _update_model_py_file(self, graph_module, model_path):
111112 file_hash = get_sha256_hash (py_code )
112113 (Path (model_path ) / "graph_hash.txt" ).write_text (file_hash )
113114
114- def _update_weight_meta_py_file (self , src_model_path , dst_model_path ):
115- old_name_to_new_name = self ._get_original_name_to_new_name (
116- src_model_path , dst_model_path
117- )
115+ def _update_weight_meta_py_file (self , src_model_path , dst_model_path , rename_map ):
118116 tensor_metas = TensorMeta .unserialize_from_py_file (
119117 os .path .join (src_model_path , "weight_meta.py" ),
120118 )
121119 for weight_meta in tensor_metas :
122- assert weight_meta .name in old_name_to_new_name
120+ meta_name = self ._find_name_in_rename_map (weight_meta .name , rename_map )
121+ assert (
122+ meta_name is not None
123+ ), f"{ weight_meta .name } is not found in rename_map"
123124 if weight_meta .original_name is None :
124125 weight_meta .original_name = weight_meta .name
125- weight_meta .name = old_name_to_new_name [weight_meta .name ]
126+ weight_meta .name = rename_map [meta_name ]
127+
126128 py_code = "\n \n " .join (
127129 [weight_meta .serialize_to_py_str () for weight_meta in tensor_metas ]
128130 )
129131 (Path (dst_model_path ) / "weight_meta.py" ).write_text (py_code )
130132
131- def _update_input_meta_py_file (self , src_model_path , dst_model_path ):
132- old_name_to_new_name = self ._get_original_name_to_new_name (
133- src_model_path , dst_model_path
134- )
133+ def _update_input_meta_py_file (self , src_model_path , dst_model_path , rename_map ):
135134 tensor_metas = TensorMeta .unserialize_from_py_file (
136135 os .path .join (src_model_path , "input_meta.py" ),
137136 )
138137 for input_meta in tensor_metas :
139- assert input_meta .name in old_name_to_new_name
138+ meta_name = self ._find_name_in_rename_map (input_meta .name , rename_map )
139+ assert (
140+ meta_name is not None
141+ ), f"{ input_meta .name } is not found in rename_map"
140142 if input_meta .original_name is None :
141143 input_meta .original_name = input_meta .name
142- input_meta .name = old_name_to_new_name [input_meta .name ]
144+ input_meta .name = rename_map [meta_name ]
145+
143146 py_code = "\n \n " .join (
144147 [input_meta .serialize_to_py_str () for input_meta in tensor_metas ]
145148 )
146149 (Path (dst_model_path ) / "input_meta.py" ).write_text (py_code )
147150
148- def _get_original_name_to_new_name (self , src_model_path , dst_model_path ):
149- src_model = self ._get_model (src_model_path )
150- dst_model = self ._get_model (dst_model_path )
151- old_name_and_new_name_pairs = zip (
152- self ._get_input_names_from_signature (src_model ),
153- self ._get_input_names_from_signature (dst_model ),
154- strict = True ,
155- )
156- return {
157- old_name : new_name for old_name , new_name in old_name_and_new_name_pairs
158- }
151+ def _find_name_in_rename_map (self , raw_name , rename_map ):
152+ if raw_name in rename_map :
153+ return raw_name
154+ # s1 -> s1_
155+ elif (raw_name + "_" ) in rename_map :
156+ return raw_name + "_"
157+ else :
158+ return None
159159
160160 def _get_model (self , model_path ):
161161 py_module = load_module (os .path .join (model_path , "model.py" ))
162162 GraphModule = getattr (py_module , "GraphModule" )
163163 GraphModule .__graph_net_file_path__ = py_module .__graph_net_file_path__
164164 return GraphModule ()
165165
166- def _get_input_names_from_signature (self , module ):
167- return inspect .signature (module .forward ).parameters
168-
169166 def rename_graph_variables (
170167 self , gm : torch .fx .GraphModule , sample_inputs , model_path
171168 ):
172169 counters = {"in" : 0 , "w" : 0 , "tmp" : 0 }
170+ rename_map = {}
173171 # graph may not have input, only contain weights
174172 arg_iter = iter (sample_inputs ) if sample_inputs else iter ([])
175173 for node in gm .graph .nodes :
176- self ._process_single_node (node , arg_iter , counters , model_path )
174+ self ._process_single_node (node , arg_iter , counters , model_path , rename_map )
177175 gm .graph .lint ()
178176 gm .recompile ()
179- return gm
177+ return gm , rename_map
180178
181- def _process_single_node (self , node , arg_iter , counters , model_path ):
179+ def _process_single_node (self , node , arg_iter , counters , model_path , rename_map ):
182180 if "original_name" not in node .meta :
183181 node .meta ["original_name" ] = node .name
184182 if node .op == "placeholder" :
185- self ._handle_placeholder (node , arg_iter , counters , model_path )
183+ self ._handle_placeholder (node , arg_iter , counters , model_path , rename_map )
186184 elif node .op == "get_attr" :
187- self ._apply_rename (node , "w" , counters )
185+ self ._apply_rename (node , "w" , counters , rename_map )
188186 elif node .op != "output" :
189- self ._apply_rename (node , "tmp" , counters )
187+ self ._apply_rename (node , "tmp" , counters , rename_map )
190188 else :
191189 # Do nothing
192190 pass
193191
194- def _handle_placeholder (self , node , arg_iter , counters , model_path ):
192+ def _handle_placeholder (self , node , arg_iter , counters , model_path , rename_map ):
195193 real_arg = next (arg_iter , None )
196194 is_weight = self ._is_weight_node (node , real_arg , model_path )
197195 prefix = "w" if is_weight else "in"
198- self ._apply_rename (node , prefix , counters , update_target = True )
196+ self ._apply_rename (node , prefix , counters , rename_map , update_target = True )
199197
200- def _apply_rename (self , node , prefix , counters , update_target = False ):
198+ def _apply_rename (self , node , prefix , counters , rename_map , update_target = False ):
199+ old_name = node .name
201200 new_name = f"{ prefix } _{ counters [prefix ]} "
202201 counters [prefix ] += 1
203202 node .name = new_name
204203 if update_target :
205204 node .target = new_name
206205
206+ rename_map [old_name ] = new_name
207+
207208 def _is_weight_node (self , node , real_arg , model_path ):
208209 is_not_data_input = not self .data_input_predicator (model_path , node .name )
209210 is_parameter_type = (
0 commit comments