11import os
2- import torch
32import shutil
43import tempfile
5-
6- from graph_net .torch .fx_graph_module_util import get_torch_module_and_inputs
7- from graph_net .torch .fx_graph_parse_util import parse_sole_graph_module
8- from graph_net .tensor_meta import TensorMeta
4+ import ast
5+ import inspect
6+ import torch
97from pathlib import Path
10- from graph_net .torch .utils import apply_templates
118from graph_net .imp_util import load_module
9+ from graph_net .tensor_meta import TensorMeta
1210from graph_net .hash_util import get_sha256_hash
1311
1412
class AstGraphRenamer(ast.NodeTransformer):
    """Rename the arguments and local variables of a ``forward`` function.

    Arguments (other than ``self``) are renamed according to ``rename_map``.
    Every other name bound by an assignment inside ``forward`` receives a
    fresh ``tmp_<n>`` name, which is recorded in ``rename_map`` as well.
    Statements that only clear an argument (``arg = None`` / ``del arg``)
    are dropped. Functions not named ``forward`` are returned unchanged.
    """

    def __init__(self, rename_map, input_arg_names, weight_arg_names):
        """Store the rename table and the original argument names.

        Args:
            rename_map: dict mapping original names to new names. It is
                mutated: temporaries discovered in ``forward`` are added.
            input_arg_names: iterable of original input argument names.
            weight_arg_names: iterable of original weight argument names.
        """
        self.rename_map = rename_map
        self.input_arg_names = set(input_arg_names)
        self.weight_arg_names = set(weight_arg_names)
        self.counters = {"tmp": 0}
        self.in_forward = False

    def visit_FunctionDef(self, node):
        """Rewrite the signature and body of ``forward``; skip other functions."""
        if node.name != "forward":
            return node

        previous_state = self.in_forward
        self.in_forward = True
        try:
            node.args.args = [self._rename_arg(arg) for arg in node.args.args]

            new_body = []
            for stmt in node.body:
                # Clear-statements are matched against the ORIGINAL argument
                # names, so they must be filtered before the rename visit.
                stmt = self._remove_clear_stmt_of_args(stmt)
                if stmt is None:
                    continue
                new_body.append(self.visit(stmt))

            # An empty body would unparse to syntactically invalid code.
            node.body = new_body or [ast.Pass()]
        finally:
            # Restore the flag even if a visitor raises.
            self.in_forward = previous_state
        return node

    def visit_Assign(self, node):
        """Register assignment targets as temporaries, then rename all names."""
        if not self.in_forward:
            return node

        for target in node.targets:
            self._register_bound_names(target)

        self.generic_visit(node)
        return node

    def visit_AnnAssign(self, node):
        """Treat an annotated assignment (``x: T = ...``) like a plain one."""
        if not self.in_forward:
            return node

        self._register_bound_names(node.target)
        self.generic_visit(node)
        return node

    def visit_Name(self, node):
        """Replace a name inside ``forward`` if it has a rename_map entry."""
        if not self.in_forward:
            return node
        if node.id in self.rename_map:
            renamed = ast.Name(id=self.rename_map[node.id], ctx=node.ctx)
            return ast.copy_location(renamed, node)
        return node

    def _rename_arg(self, arg):
        """Return ``arg`` renamed via rename_map; ``self`` is never renamed."""
        if arg.arg == "self" or arg.arg not in self.rename_map:
            return arg
        renamed = ast.arg(arg=self.rename_map[arg.arg], annotation=arg.annotation)
        return ast.copy_location(renamed, arg)

    def _register_bound_names(self, target):
        """Assign a ``tmp_<n>`` name to every not-yet-mapped name in ``target``.

        Descends into tuple/list unpacking and starred targets so that
        ``a, *b = ...`` registers both ``a`` and ``b``. Attribute and
        subscript targets bind no local name and are ignored.
        """
        if isinstance(target, ast.Name):
            if target.id not in self.rename_map:
                self.rename_map[target.id] = f"tmp_{self.counters['tmp']}"
                self.counters["tmp"] += 1
        elif isinstance(target, (ast.Tuple, ast.List)):
            for element in target.elts:
                self._register_bound_names(element)
        elif isinstance(target, ast.Starred):
            self._register_bound_names(target.value)

    def _remove_clear_stmt_of_args(self, stmt):
        """Strip argument names from ``x = None`` and ``del x`` statements.

        Returns:
            The (possibly modified) statement, or ``None`` when every target
            of the statement was an argument and nothing is left to keep.
        """
        args_names = self.input_arg_names | self.weight_arg_names

        def _need_remove(target):
            return isinstance(target, ast.Name) and target.id in args_names

        if (
            isinstance(stmt, ast.Assign)
            and isinstance(stmt.value, ast.Constant)
            and stmt.value.value is None
        ):
            # remove stmt like w_0 = None
            new_targets = [t for t in stmt.targets if not _need_remove(t)]
            if not new_targets:
                return None
            stmt.targets = new_targets
        elif isinstance(stmt, ast.Delete):
            # remove stmt like del w_0
            new_targets = []
            for t in stmt.targets:
                if isinstance(t, ast.Tuple):
                    kept = [e for e in t.elts if not _need_remove(e)]
                    if kept:
                        new_targets.append(ast.Tuple(elts=kept, ctx=ast.Del()))
                elif not _need_remove(t):
                    new_targets.append(t)
            if not new_targets:
                return None
            stmt.targets = new_targets
        return stmt
105+
106+
def load_class_from_file(file_path: str, class_name: str):
    """Load the Python file at ``file_path`` and fetch ``class_name`` from it.

    The file is loaded through ``load_module`` under the module name
    ``"unnamed_graph_module"``. Returns ``None`` when the loaded module has
    no attribute called ``class_name``.
    """
    print(f"Load {class_name} from {file_path}")
    loaded_module = load_module(file_path, "unnamed_graph_module")
    return getattr(loaded_module, class_name, None)
112+
113+
15114class GraphVariableRenamer :
16115 """
17116 Used by graph_net.model_path_handler
@@ -44,18 +143,13 @@ def _make_config(
44143 output_dir = "./tmp/graph_variable_renamer_dir" ,
45144 filter_path = None ,
46145 filter_config = None ,
47- post_extract_process_path = None ,
48- post_extract_process_class_name = None ,
49- post_extract_process_config = None ,
50146 data_input_predicator_class_name = "DataInputPredicator" ,
51147 model_runnable_predicator_class_name = "ModelRunner" ,
52148 data_input_predicator_config = None ,
53149 model_runnable_predicator_config = None ,
54150 model_path_prefix = "" ,
55151 ** kwargs ,
56152 ):
57- if post_extract_process_config is None :
58- post_extract_process_config = {}
59153 if data_input_predicator_config is None :
60154 data_input_predicator_config = {}
61155 if model_runnable_predicator_config is None :
@@ -65,9 +159,6 @@ def _make_config(
65159 "output_dir" : output_dir ,
66160 "filter_path" : filter_path ,
67161 "filter_config" : filter_config if filter_config is not None else {},
68- "post_extract_process_path" : post_extract_process_path ,
69- "post_extract_process_class_name" : post_extract_process_class_name ,
70- "post_extract_process_config" : post_extract_process_config ,
71162 "data_input_predicator_filepath" : data_input_predicator_filepath ,
72163 "data_input_predicator_class_name" : data_input_predicator_class_name ,
73164 "data_input_predicator_config" : data_input_predicator_config ,
@@ -89,133 +180,94 @@ def __call__(self, rel_model_path):
89180 return
90181
91182 src_model_path = os .path .join (self .config ["model_path_prefix" ], rel_model_path )
92- module , inputs = get_torch_module_and_inputs (src_model_path )
93- gm = parse_sole_graph_module (module , inputs )
94- gm , rename_map = self .rename_graph_variables (gm , inputs , src_model_path )
95-
96183 Path (dst_model_path ).parent .mkdir (parents = True , exist_ok = True )
184+ graph_module = load_class_from_file (
185+ os .path .join (src_model_path , "model.py" ), class_name = "GraphModule"
186+ )
187+ input_arg_names , weight_arg_names = self ._get_input_and_weight_arg_names (
188+ graph_module , src_model_path
189+ )
190+
191+ rename_map = {}
192+ for idx , name in enumerate (input_arg_names ):
193+ rename_map [name ] = f"in_{ idx } "
194+ for idx , name in enumerate (weight_arg_names ):
195+ rename_map [name ] = f"w_{ idx } "
196+
97197 with tempfile .TemporaryDirectory (prefix = "graph_variable_renamer_" ) as temp_dir :
98198 temp_model_path = os .path .join (temp_dir , os .path .basename (dst_model_path ))
99199 shutil .copytree (src_model_path , temp_model_path , dirs_exist_ok = True )
100- self ._update_model_py_file (gm , temp_model_path )
101- self ._update_weight_meta_py_file (
102- src_model_path , temp_model_path , rename_map
200+ self ._update_model_py_file (
201+ temp_model_path , rename_map , input_arg_names , weight_arg_names
103202 )
104- self ._update_input_meta_py_file (src_model_path , temp_model_path , rename_map )
105- # print("Try to run renamed model...")
106- # self._try_run(temp_model_path)
107- shutil .copytree (temp_model_path , dst_model_path )
203+ self ._update_meta_file (temp_model_path , "weight_meta.py" , rename_map )
204+ self ._update_meta_file (temp_model_path , "input_meta.py" , rename_map )
205+ print (f"Verifying { rel_model_path } ..." )
206+ self ._try_run (temp_model_path )
207+ shutil .copytree (temp_model_path , dst_model_path , dirs_exist_ok = True )
108208
109- def _try_run (self , model_path ):
110- assert self .model_runnable_predicator (
111- model_path
112- ), f"{ model_path } is not a runnable model"
209+ def _get_input_and_weight_arg_names (self , graph_module , model_path ):
210+ input_arg_names = []
211+ weight_arg_names = []
212+ sig = inspect .signature (graph_module .forward )
213+ for name , param in sig .parameters .items ():
214+ if name == "self" :
215+ continue
216+ is_not_data_input = not self .data_input_predicator (model_path , name )
217+ is_parameter_type = self ._is_parameter_type (param .annotation )
218+ if is_not_data_input or is_parameter_type :
219+ weight_arg_names .append (name )
220+ else :
221+ input_arg_names .append (name )
222+ return input_arg_names , weight_arg_names
113223
114- def _update_model_py_file (self , graph_module , model_path ):
115- py_code = apply_templates (graph_module .code )
116- (Path (model_path ) / "model.py" ).write_text (py_code )
117- file_hash = get_sha256_hash (py_code )
118- (Path (model_path ) / "graph_hash.txt" ).write_text (file_hash )
224+ def _is_parameter_type (self , annotation ):
225+ return annotation is torch .nn .parameter .Parameter
119226
120- def _update_weight_meta_py_file (self , src_model_path , dst_model_path , rename_map ):
121- tensor_metas = TensorMeta .unserialize_from_py_file (
122- os .path .join (src_model_path , "weight_meta.py" ),
123- )
124- for weight_meta in tensor_metas :
125- meta_name = self ._find_name_in_rename_map (weight_meta .name , rename_map )
126- assert (
127- meta_name is not None
128- ), f"{ weight_meta .name } is not found in rename_map"
129- if weight_meta .original_name is None :
130- weight_meta .original_name = weight_meta .name
131- weight_meta .name = rename_map [meta_name ]
132-
133- py_code = "\n \n " .join (
134- [weight_meta .serialize_to_py_str () for weight_meta in tensor_metas ]
135- )
136- (Path (dst_model_path ) / "weight_meta.py" ).write_text (py_code )
227+ def _update_model_py_file (
228+ self , model_path , rename_map , input_arg_names , weight_arg_names
229+ ):
230+ model_file = Path (model_path ) / "model.py"
231+ with open (model_file , "r" , encoding = "utf-8" ) as f :
232+ source = f .read ()
137233
138- def _update_input_meta_py_file (self , src_model_path , dst_model_path , rename_map ):
139- tensor_metas = TensorMeta .unserialize_from_py_file (
140- os .path .join (src_model_path , "input_meta.py" ),
141- )
142- for input_meta in tensor_metas :
143- meta_name = self ._find_name_in_rename_map (input_meta .name , rename_map )
144- assert (
145- meta_name is not None
146- ), f"{ input_meta .name } is not found in rename_map"
147- if input_meta .original_name is None :
148- input_meta .original_name = input_meta .name
149- input_meta .name = rename_map [meta_name ]
150-
151- py_code = "\n \n " .join (
152- [input_meta .serialize_to_py_str () for input_meta in tensor_metas ]
153- )
154- (Path (dst_model_path ) / "input_meta.py" ).write_text (py_code )
155-
156- def _find_name_in_rename_map (self , raw_name , rename_map ):
157- if raw_name in rename_map :
158- return raw_name
159- # s1 -> s1_
160- elif (raw_name + "_" ) in rename_map :
161- return raw_name + "_"
234+ tree = ast .parse (source )
235+ for node in tree .body :
236+ if isinstance (node , ast .ClassDef ) and node .name == "GraphModule" :
237+ transformer = AstGraphRenamer (
238+ rename_map , input_arg_names , weight_arg_names
239+ )
240+ transformer .visit (node )
241+ break
242+
243+ if hasattr (ast , "unparse" ):
244+ py_code = ast .unparse (tree )
162245 else :
163- return None
246+ import astor
164247
165- def _get_model (self , model_path ):
166- py_module = load_module (os .path .join (model_path , "model.py" ))
167- GraphModule = getattr (py_module , "GraphModule" )
168- GraphModule .__graph_net_file_path__ = py_module .__graph_net_file_path__
169- return GraphModule ()
248+ py_code = astor .to_source (tree )
170249
171- def rename_graph_variables (
172- self , gm : torch .fx .GraphModule , sample_inputs , model_path
173- ):
174- counters = {"in" : 0 , "w" : 0 , "tmp" : 0 }
175- rename_map = {}
176- # graph may not have input, only contain weights
177- arg_iter = iter (sample_inputs ) if sample_inputs else iter ([])
178- for node in gm .graph .nodes :
179- self ._process_single_node (node , arg_iter , counters , model_path , rename_map )
180- gm .graph .lint ()
181- gm .recompile ()
182- return gm , rename_map
183-
184- def _process_single_node (self , node , arg_iter , counters , model_path , rename_map ):
185- if "original_name" not in node .meta :
186- node .meta ["original_name" ] = node .name
187- if node .op == "placeholder" :
188- self ._handle_placeholder (node , arg_iter , counters , model_path , rename_map )
189- elif node .op == "get_attr" :
190- self ._apply_rename (node , "w" , counters , rename_map )
191- elif node .op != "output" :
192- self ._apply_rename (node , "tmp" , counters , rename_map )
193- else :
194- # Do nothing
195- pass
196-
197- def _handle_placeholder (self , node , arg_iter , counters , model_path , rename_map ):
198- real_arg = next (arg_iter , None )
199- is_weight = self ._is_weight_node (node , real_arg , model_path )
200- prefix = "w" if is_weight else "in"
201- self ._apply_rename (node , prefix , counters , rename_map , update_target = True )
202-
203- def _apply_rename (self , node , prefix , counters , rename_map , update_target = False ):
204- old_name = node .name
205- new_name = f"{ prefix } _{ counters [prefix ]} "
206- counters [prefix ] += 1
207- node .name = new_name
208- if update_target :
209- node .target = new_name
210-
211- rename_map [old_name ] = new_name
212-
213- def _is_weight_node (self , node , real_arg , model_path ):
214- is_not_data_input = not self .data_input_predicator (model_path , node .name )
215- is_parameter_type = (
216- node .type is not None
217- and isinstance (node .type , type )
218- and issubclass (node .type , torch .nn .parameter .Parameter )
219- )
220- is_parameter_value = isinstance (real_arg , torch .nn .Parameter )
221- return is_not_data_input or is_parameter_type or is_parameter_value
250+ model_file .write_text (py_code , encoding = "utf-8" )
251+ file_hash = get_sha256_hash (py_code )
252+ (Path (model_path ) / "graph_hash.txt" ).write_text (file_hash )
253+
254+ def _update_meta_file (self , model_path , meta_filename , rename_map ):
255+ meta_file = Path (model_path ) / meta_filename
256+ tensor_metas = TensorMeta .unserialize_from_py_file (str (meta_file ))
257+ for meta in tensor_metas :
258+ if meta .name in rename_map :
259+ if meta .original_name is None :
260+ meta .original_name = meta .name
261+ meta .name = rename_map [meta .name ]
262+ else :
263+ print (
264+ f"[Warning] { meta .name } in { meta_filename } not found in rename_map, skipping."
265+ )
266+
267+ py_code = "\n \n " .join ([meta .serialize_to_py_str () for meta in tensor_metas ])
268+ meta_file .write_text (py_code )
269+
270+ def _try_run (self , model_path ):
271+ assert self .model_runnable_predicator (
272+ model_path
273+ ), f"{ model_path } is not a runnable model"
0 commit comments