
Commit c122a9d

AST-based graph variable renamer
1 parent a4fe530 commit c122a9d

2 files changed: +185 additions, -133 deletions

graph_net/test_compiler_util.py

File mode changed: 100644 → 100755
Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ def print_with_log_prompt(key, value, log_prompt):
 def print_basic_config(args, hardware_name, compile_framework_version):
     model_path = os.path.normpath(args.model_path)
     model_name = get_model_name(model_path)
-    print_with_log_prompt("[Config] model:", model_name, args.log_prompt)
+    print_with_log_prompt("[Processing] model:", model_name, args.log_prompt)
 
     print_with_log_prompt("[Config] device:", args.device, args.log_prompt)
     print_with_log_prompt("[Config] hardware:", hardware_name, args.log_prompt)
Lines changed: 184 additions & 132 deletions
@@ -1,17 +1,116 @@
 import os
-import torch
 import shutil
 import tempfile
-
-from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
-from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
-from graph_net.tensor_meta import TensorMeta
+import ast
+import inspect
+import torch
 from pathlib import Path
-from graph_net.torch.utils import apply_templates
 from graph_net.imp_util import load_module
+from graph_net.tensor_meta import TensorMeta
 from graph_net.hash_util import get_sha256_hash
 
 
+class AstGraphRenamer(ast.NodeTransformer):
+    def __init__(self, rename_map, input_arg_names, weight_arg_names):
+        self.rename_map = rename_map
+        self.input_arg_names = set(input_arg_names)
+        self.weight_arg_names = set(weight_arg_names)
+        self.counters = {"tmp": 0}
+        self.in_forward = False
+
+    def visit_FunctionDef(self, node):
+        if node.name == "forward":
+            self.in_forward = True
+            new_args = []
+            for arg in node.args.args:
+                if arg.arg == "self":
+                    new_args.append(arg)
+                    continue
+
+                if arg.arg in self.rename_map:
+                    new_arg_name = self.rename_map[arg.arg]
+                    new_args.append(
+                        ast.arg(arg=new_arg_name, annotation=arg.annotation)
+                    )
+                else:
+                    new_args.append(arg)
+
+            node.args.args = new_args
+
+            new_body = []
+            for stmt in node.body:
+                stmt = self._remove_clear_stmt_of_args(stmt)
+                if stmt is None:
+                    continue
+
+                stmt = self.visit(stmt)
+                new_body.append(stmt)
+
+            node.body = new_body
+            self.in_forward = False
+        return node
+
+    def visit_Assign(self, node):
+        if not self.in_forward:
+            return node
+
+        for target in node.targets:
+            if isinstance(target, ast.Name):
+                old_name = target.id
+                if old_name not in self.rename_map:
+                    new_name = f"tmp_{self.counters['tmp']}"
+                    self.counters["tmp"] += 1
+                    self.rename_map[old_name] = new_name
+
+        self.generic_visit(node)
+        return node
+
+    def visit_Name(self, node):
+        if not self.in_forward:
+            return node
+        if node.id in self.rename_map:
+            return ast.Name(id=self.rename_map[node.id], ctx=node.ctx)
+        return node
+
+    def _remove_clear_stmt_of_args(self, stmt):
+        args_names = self.input_arg_names | self.weight_arg_names
+
+        def _need_remove(target):
+            return isinstance(target, ast.Name) and target.id in args_names
+
+        if (
+            isinstance(stmt, ast.Assign)
+            and isinstance(stmt.value, ast.Constant)
+            and stmt.value.value is None
+        ):
+            # remove stmt like w_0 = None
+            new_targets = [t for t in stmt.targets if not _need_remove(t)]
+            if not new_targets:
+                return None
+            stmt.targets = new_targets
+        elif isinstance(stmt, ast.Delete):
+            # remove stmt like del w_0
+            new_targets = []
+            for t in stmt.targets:
+                if isinstance(t, ast.Tuple):
+                    kept = [e for e in t.elts if not _need_remove(e)]
+                    if kept:
+                        new_targets.append(ast.Tuple(elts=kept, ctx=ast.Del()))
+                elif not _need_remove(t):
+                    new_targets.append(t)
+            if not new_targets:
+                return None
+            stmt.targets = new_targets
+        return stmt
+
+
+def load_class_from_file(file_path: str, class_name: str):
+    print(f"Load {class_name} from {file_path}")
+    module = load_module(file_path, "unnamed_graph_module")
+    model_class = getattr(module, class_name, None)
+    return model_class
+
+
 class GraphVariableRenamer:
     """
     Used by graph_net.model_path_handler
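
For reference, a minimal sketch of what the new AstGraphRenamer does to a toy GraphModule source. The snippet is illustrative only: the source string, argument names, and rename_map below are made up, it assumes AstGraphRenamer is importable from the module added in this commit, and it needs Python 3.9+ for ast.unparse.

import ast

# Toy GraphModule source (hypothetical); arg names mimic torch.fx placeholders.
src = '''
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1, arg1_1):
        mul = arg0_1 * arg1_1
        arg0_1 = None
        del arg1_1
        return (mul,)
'''

# arg0_1 is treated as a data input, arg1_1 as a weight (both names are hypothetical).
rename_map = {"arg0_1": "in_0", "arg1_1": "w_0"}
tree = ast.parse(src)
AstGraphRenamer(rename_map, ["arg0_1"], ["arg1_1"]).visit(tree)
print(ast.unparse(tree))
# class GraphModule(torch.nn.Module):
#     def forward(self, in_0, w_0):
#         tmp_0 = in_0 * w_0
#         return (tmp_0,)

The `arg0_1 = None` and `del arg1_1` clean-up statements are dropped because their targets are renamed arguments, and the intermediate `mul` is renumbered to `tmp_0`.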
@@ -44,18 +143,13 @@ def _make_config(
         output_dir="./tmp/graph_variable_renamer_dir",
         filter_path=None,
         filter_config=None,
-        post_extract_process_path=None,
-        post_extract_process_class_name=None,
-        post_extract_process_config=None,
         data_input_predicator_class_name="DataInputPredicator",
         model_runnable_predicator_class_name="ModelRunner",
         data_input_predicator_config=None,
         model_runnable_predicator_config=None,
         model_path_prefix="",
         **kwargs,
     ):
-        if post_extract_process_config is None:
-            post_extract_process_config = {}
         if data_input_predicator_config is None:
            data_input_predicator_config = {}
        if model_runnable_predicator_config is None:
@@ -65,9 +159,6 @@ def _make_config(
             "output_dir": output_dir,
             "filter_path": filter_path,
             "filter_config": filter_config if filter_config is not None else {},
-            "post_extract_process_path": post_extract_process_path,
-            "post_extract_process_class_name": post_extract_process_class_name,
-            "post_extract_process_config": post_extract_process_config,
             "data_input_predicator_filepath": data_input_predicator_filepath,
             "data_input_predicator_class_name": data_input_predicator_class_name,
             "data_input_predicator_config": data_input_predicator_config,
@@ -89,133 +180,94 @@ def __call__(self, rel_model_path):
             return
 
         src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
-        module, inputs = get_torch_module_and_inputs(src_model_path)
-        gm = parse_sole_graph_module(module, inputs)
-        gm, rename_map = self.rename_graph_variables(gm, inputs, src_model_path)
-
         Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
+        graph_module = load_class_from_file(
+            os.path.join(src_model_path, "model.py"), class_name="GraphModule"
+        )
+        input_arg_names, weight_arg_names = self._get_input_and_weight_arg_names(
+            graph_module, src_model_path
+        )
+
+        rename_map = {}
+        for idx, name in enumerate(input_arg_names):
+            rename_map[name] = f"in_{idx}"
+        for idx, name in enumerate(weight_arg_names):
+            rename_map[name] = f"w_{idx}"
+
         with tempfile.TemporaryDirectory(prefix="graph_variable_renamer_") as temp_dir:
             temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
             shutil.copytree(src_model_path, temp_model_path, dirs_exist_ok=True)
-            self._update_model_py_file(gm, temp_model_path)
-            self._update_weight_meta_py_file(
-                src_model_path, temp_model_path, rename_map
+            self._update_model_py_file(
+                temp_model_path, rename_map, input_arg_names, weight_arg_names
             )
-            self._update_input_meta_py_file(src_model_path, temp_model_path, rename_map)
-            # print("Try to run renamed model...")
-            # self._try_run(temp_model_path)
-            shutil.copytree(temp_model_path, dst_model_path)
+            self._update_meta_file(temp_model_path, "weight_meta.py", rename_map)
+            self._update_meta_file(temp_model_path, "input_meta.py", rename_map)
+            print(f"Verifying {rel_model_path}...")
+            self._try_run(temp_model_path)
+            shutil.copytree(temp_model_path, dst_model_path, dirs_exist_ok=True)
 
-    def _try_run(self, model_path):
-        assert self.model_runnable_predicator(
-            model_path
-        ), f"{model_path} is not a runnable model"
+    def _get_input_and_weight_arg_names(self, graph_module, model_path):
+        input_arg_names = []
+        weight_arg_names = []
+        sig = inspect.signature(graph_module.forward)
+        for name, param in sig.parameters.items():
+            if name == "self":
+                continue
+            is_not_data_input = not self.data_input_predicator(model_path, name)
+            is_parameter_type = self._is_parameter_type(param.annotation)
+            if is_not_data_input or is_parameter_type:
+                weight_arg_names.append(name)
+            else:
+                input_arg_names.append(name)
+        return input_arg_names, weight_arg_names
 
-    def _update_model_py_file(self, graph_module, model_path):
-        py_code = apply_templates(graph_module.code)
-        (Path(model_path) / "model.py").write_text(py_code)
-        file_hash = get_sha256_hash(py_code)
-        (Path(model_path) / "graph_hash.txt").write_text(file_hash)
+    def _is_parameter_type(self, annotation):
+        return annotation is torch.nn.parameter.Parameter
 
-    def _update_weight_meta_py_file(self, src_model_path, dst_model_path, rename_map):
-        tensor_metas = TensorMeta.unserialize_from_py_file(
-            os.path.join(src_model_path, "weight_meta.py"),
-        )
-        for weight_meta in tensor_metas:
-            meta_name = self._find_name_in_rename_map(weight_meta.name, rename_map)
-            assert (
-                meta_name is not None
-            ), f"{weight_meta.name} is not found in rename_map"
-            if weight_meta.original_name is None:
-                weight_meta.original_name = weight_meta.name
-            weight_meta.name = rename_map[meta_name]
-
-        py_code = "\n\n".join(
-            [weight_meta.serialize_to_py_str() for weight_meta in tensor_metas]
-        )
-        (Path(dst_model_path) / "weight_meta.py").write_text(py_code)
+    def _update_model_py_file(
+        self, model_path, rename_map, input_arg_names, weight_arg_names
+    ):
+        model_file = Path(model_path) / "model.py"
+        with open(model_file, "r", encoding="utf-8") as f:
+            source = f.read()
 
-    def _update_input_meta_py_file(self, src_model_path, dst_model_path, rename_map):
-        tensor_metas = TensorMeta.unserialize_from_py_file(
-            os.path.join(src_model_path, "input_meta.py"),
-        )
-        for input_meta in tensor_metas:
-            meta_name = self._find_name_in_rename_map(input_meta.name, rename_map)
-            assert (
-                meta_name is not None
-            ), f"{input_meta.name} is not found in rename_map"
-            if input_meta.original_name is None:
-                input_meta.original_name = input_meta.name
-            input_meta.name = rename_map[meta_name]
-
-        py_code = "\n\n".join(
-            [input_meta.serialize_to_py_str() for input_meta in tensor_metas]
-        )
-        (Path(dst_model_path) / "input_meta.py").write_text(py_code)
-
-    def _find_name_in_rename_map(self, raw_name, rename_map):
-        if raw_name in rename_map:
-            return raw_name
-        # s1 -> s1_
-        elif (raw_name + "_") in rename_map:
-            return raw_name + "_"
+        tree = ast.parse(source)
+        for node in tree.body:
+            if isinstance(node, ast.ClassDef) and node.name == "GraphModule":
+                transformer = AstGraphRenamer(
+                    rename_map, input_arg_names, weight_arg_names
+                )
+                transformer.visit(node)
+                break
+
+        if hasattr(ast, "unparse"):
+            py_code = ast.unparse(tree)
         else:
-            return None
+            import astor
 
-    def _get_model(self, model_path):
-        py_module = load_module(os.path.join(model_path, "model.py"))
-        GraphModule = getattr(py_module, "GraphModule")
-        GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__
-        return GraphModule()
+            py_code = astor.to_source(tree)
 
-    def rename_graph_variables(
-        self, gm: torch.fx.GraphModule, sample_inputs, model_path
-    ):
-        counters = {"in": 0, "w": 0, "tmp": 0}
-        rename_map = {}
-        # graph may not have input, only contain weights
-        arg_iter = iter(sample_inputs) if sample_inputs else iter([])
-        for node in gm.graph.nodes:
-            self._process_single_node(node, arg_iter, counters, model_path, rename_map)
-        gm.graph.lint()
-        gm.recompile()
-        return gm, rename_map
-
-    def _process_single_node(self, node, arg_iter, counters, model_path, rename_map):
-        if "original_name" not in node.meta:
-            node.meta["original_name"] = node.name
-        if node.op == "placeholder":
-            self._handle_placeholder(node, arg_iter, counters, model_path, rename_map)
-        elif node.op == "get_attr":
-            self._apply_rename(node, "w", counters, rename_map)
-        elif node.op != "output":
-            self._apply_rename(node, "tmp", counters, rename_map)
-        else:
-            # Do nothing
-            pass
-
-    def _handle_placeholder(self, node, arg_iter, counters, model_path, rename_map):
-        real_arg = next(arg_iter, None)
-        is_weight = self._is_weight_node(node, real_arg, model_path)
-        prefix = "w" if is_weight else "in"
-        self._apply_rename(node, prefix, counters, rename_map, update_target=True)
-
-    def _apply_rename(self, node, prefix, counters, rename_map, update_target=False):
-        old_name = node.name
-        new_name = f"{prefix}_{counters[prefix]}"
-        counters[prefix] += 1
-        node.name = new_name
-        if update_target:
-            node.target = new_name
-
-        rename_map[old_name] = new_name
-
-    def _is_weight_node(self, node, real_arg, model_path):
-        is_not_data_input = not self.data_input_predicator(model_path, node.name)
-        is_parameter_type = (
-            node.type is not None
-            and isinstance(node.type, type)
-            and issubclass(node.type, torch.nn.parameter.Parameter)
-        )
-        is_parameter_value = isinstance(real_arg, torch.nn.Parameter)
-        return is_not_data_input or is_parameter_type or is_parameter_value
+        model_file.write_text(py_code, encoding="utf-8")
+        file_hash = get_sha256_hash(py_code)
+        (Path(model_path) / "graph_hash.txt").write_text(file_hash)
+
+    def _update_meta_file(self, model_path, meta_filename, rename_map):
+        meta_file = Path(model_path) / meta_filename
+        tensor_metas = TensorMeta.unserialize_from_py_file(str(meta_file))
+        for meta in tensor_metas:
+            if meta.name in rename_map:
+                if meta.original_name is None:
+                    meta.original_name = meta.name
+                meta.name = rename_map[meta.name]
+            else:
+                print(
+                    f"[Warning] {meta.name} in {meta_filename} not found in rename_map, skipping."
+                )
+
+        py_code = "\n\n".join([meta.serialize_to_py_str() for meta in tensor_metas])
+        meta_file.write_text(py_code)
+
+    def _try_run(self, model_path):
+        assert self.model_runnable_predicator(
+            model_path
+        ), f"{model_path} is not a runnable model"

0 commit comments