Skip to content

Commit b687ea8

Browse files
committed
[Feature Enhancement] AST-based graph variable renamer
1 parent 43c2ab5 commit b687ea8

File tree

4 files changed

+341
-9
lines changed

4 files changed

+341
-9
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
4+
RENAMED_PATH=/tmp/ast_graph_variable_rename_workspace
5+
6+
mkdir -p "$RENAMED_PATH"
7+
model_list="$GRAPH_NET_ROOT/graph_net/config/small100_torch_samples_list.txt"
8+
9+
python3 -m graph_net.model_path_handler \
10+
--model-path-list $model_list \
11+
--handler-config=$(base64 -w 0 <<EOF
12+
{
13+
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/ast_graph_variable_renamer.py",
14+
"handler_class_name": "AstGraphVariableRenamer",
15+
"handler_config": {
16+
"device": "cuda",
17+
"resume": true,
18+
"model_path_prefix": "$GRAPH_NET_ROOT/",
19+
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
20+
"data_input_predicator_class_name": "NaiveDataInputPredicator",
21+
"model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
22+
"model_runnable_predicator_class_name": "ModelRunnablePredicator",
23+
"output_dir": "$RENAMED_PATH"
24+
}
25+
}
26+
EOF
27+
) \
28+
2>&1 | tee "$RENAMED_PATH/graph_rename.log"
29+
30+
python3 -m graph_net.torch.test_compiler \
31+
--model-path-prefix $GRAPH_NET_ROOT \
32+
--allow-list $model_list \
33+
--compiler graph_variable_renamer_validator \
34+
--device cuda \
35+
--config $(base64 -w 0 <<EOF
36+
{
37+
"model_path_prefix": "$GRAPH_NET_ROOT",
38+
"renamed_root": "$RENAMED_PATH"
39+
}
40+
EOF
41+
) \
42+
2>&1 | tee "$RENAMED_PATH/validation.log"
43+
44+
python3 -m graph_net.plot_ESt \
45+
--benchmark-path "$RENAMED_PATH/validation.log" \
46+
--output-dir "$RENAMED_PATH"
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
import os
2+
import shutil
3+
import tempfile
4+
import ast
5+
import inspect
6+
import torch
7+
from pathlib import Path
8+
from graph_net.imp_util import load_module
9+
from graph_net.tensor_meta import TensorMeta
10+
from graph_net.hash_util import get_sha256_hash
11+
12+
13+
class AstGraphRenamer(ast.NodeTransformer):
14+
def __init__(self, rename_map, input_arg_names, weight_arg_names):
15+
self.rename_map = rename_map
16+
self.input_and_weight_arg_names = set(input_arg_names) | set(weight_arg_names)
17+
self.counters = {"tmp": 0}
18+
self.in_forward = False
19+
20+
def visit_FunctionDef(self, node):
21+
if node.name != "forward":
22+
return node
23+
self.in_forward = True
24+
node.args.args = self._rename_function_args(node.args.args)
25+
node.body = self._rename_function_body(node.body)
26+
self.in_forward = False
27+
return node
28+
29+
def _rename_function_args(self, args):
30+
new_function_args = []
31+
for arg in args:
32+
if arg.arg == "self":
33+
new_function_args.append(arg)
34+
else:
35+
new_function_args.append(self._create_renamed_arg(arg))
36+
return new_function_args
37+
38+
def _create_renamed_arg(self, arg):
39+
if arg.arg in self.rename_map:
40+
return ast.arg(arg=self.rename_map[arg.arg], annotation=arg.annotation)
41+
return arg
42+
43+
def _rename_function_body(self, body):
44+
new_function_body = []
45+
for stmt in body:
46+
stmt = self._remove_clear_stmt_of_args(stmt)
47+
if stmt:
48+
stmt = self.visit(stmt)
49+
new_function_body.append(stmt)
50+
return new_function_body
51+
52+
def _remove_clear_stmt_of_args(self, stmt):
53+
# remove stmt like w_0 = None
54+
if self._is_assign_none(stmt):
55+
return self._clean_assign_none(stmt)
56+
# remove stmt like del w_0
57+
elif isinstance(stmt, ast.Delete):
58+
return self._clean_delete(stmt)
59+
else:
60+
pass
61+
return stmt
62+
63+
def _is_assign_none(self, stmt):
64+
return (
65+
isinstance(stmt, ast.Assign)
66+
and isinstance(stmt.value, ast.Constant)
67+
and stmt.value.value is None
68+
)
69+
70+
def _clean_assign_none(self, stmt):
71+
new_targets = [t for t in stmt.targets if not self._is_input_or_weight_var(t)]
72+
if not new_targets:
73+
return None
74+
stmt.targets = new_targets
75+
return stmt
76+
77+
def _is_input_or_weight_var(self, target):
78+
return (
79+
isinstance(target, ast.Name)
80+
and target.id in self.input_and_weight_arg_names
81+
)
82+
83+
def _clean_delete(self, stmt):
84+
new_targets = []
85+
for target in stmt.targets:
86+
kept = self._filter_delete_target(target)
87+
if kept:
88+
new_targets.append(kept)
89+
90+
if not new_targets:
91+
return None
92+
stmt.targets = new_targets
93+
return stmt
94+
95+
def _filter_delete_target(self, target):
96+
if isinstance(target, ast.Tuple): # del (a, b)
97+
kept_elts = [e for e in target.elts if not self._is_protected_var(e)]
98+
return ast.Tuple(elts=kept_elts, ctx=ast.Del()) if kept_elts else None
99+
elif not self._is_protected_var(target): # del a
100+
return target
101+
return None
102+
103+
def visit_Assign(self, node):
104+
if not self.in_forward:
105+
return node
106+
self._register_new_local_variables(node.targets)
107+
self.generic_visit(node)
108+
return node
109+
110+
def _register_new_local_variables(self, targets):
111+
for target in targets:
112+
for name in self._flatten_assignment_target(target):
113+
self._register_if_unknown(name)
114+
115+
def _flatten_assignment_target(self, target):
116+
if isinstance(target, ast.Name):
117+
yield target.id
118+
elif isinstance(target, (ast.Tuple, ast.List)):
119+
for elt in target.elts:
120+
yield from self._flatten_assignment_target(elt)
121+
else:
122+
pass
123+
124+
def _register_if_unknown(self, name):
125+
if name not in self.rename_map:
126+
new_name = f"tmp_{self.counters['tmp']}"
127+
self.counters["tmp"] += 1
128+
self.rename_map[name] = new_name
129+
130+
def visit_Name(self, node):
131+
if not self.in_forward:
132+
return node
133+
if node.id in self.rename_map:
134+
return ast.Name(id=self.rename_map[node.id], ctx=node.ctx)
135+
return node
136+
137+
138+
def load_class_from_file(file_path: str, class_name: str):
139+
print(f"Load {class_name} from {file_path}")
140+
module = load_module(file_path, "unnamed_graph_module")
141+
model_class = getattr(module, class_name, None)
142+
return model_class
143+
144+
145+
class AstGraphVariableRenamer:
146+
"""
147+
Used by graph_net.model_path_handler
148+
"""
149+
150+
def __init__(self, config: dict = None):
151+
if config is None:
152+
config = {}
153+
self.config = self._make_config(**config)
154+
self.data_input_predicator = self._make_data_input_predicator(self.config)
155+
self.model_runnable_predicator = self._make_model_runnable_predicator(
156+
self.config
157+
)
158+
159+
def _make_data_input_predicator(self, config):
160+
module = load_module(config["data_input_predicator_filepath"])
161+
cls = getattr(module, config["data_input_predicator_class_name"])
162+
return cls(config["data_input_predicator_config"])
163+
164+
def _make_model_runnable_predicator(self, config):
165+
module = load_module(config["model_runnable_predicator_filepath"])
166+
cls = getattr(module, config["model_runnable_predicator_class_name"])
167+
return cls(config["model_runnable_predicator_config"])
168+
169+
def _make_config(
170+
self,
171+
resume: bool = False,
172+
data_input_predicator_filepath=None,
173+
model_runnable_predicator_filepath=None,
174+
output_dir="./tmp/graph_variable_renamer_dir",
175+
filter_path=None,
176+
filter_config=None,
177+
data_input_predicator_class_name="DataInputPredicator",
178+
model_runnable_predicator_class_name="ModelRunner",
179+
data_input_predicator_config=None,
180+
model_runnable_predicator_config=None,
181+
model_path_prefix="",
182+
**kwargs,
183+
):
184+
if data_input_predicator_config is None:
185+
data_input_predicator_config = {}
186+
if model_runnable_predicator_config is None:
187+
model_runnable_predicator_config = {}
188+
return {
189+
"resume": resume,
190+
"output_dir": output_dir,
191+
"filter_path": filter_path,
192+
"filter_config": filter_config if filter_config is not None else {},
193+
"data_input_predicator_filepath": data_input_predicator_filepath,
194+
"data_input_predicator_class_name": data_input_predicator_class_name,
195+
"data_input_predicator_config": data_input_predicator_config,
196+
"model_runnable_predicator_filepath": model_runnable_predicator_filepath,
197+
"model_runnable_predicator_class_name": model_runnable_predicator_class_name,
198+
"model_runnable_predicator_config": model_runnable_predicator_config,
199+
"model_path_prefix": model_path_prefix,
200+
}
201+
202+
def __call__(self, rel_model_path):
203+
torch.cuda.empty_cache()
204+
205+
dst_model_path = os.path.realpath(
206+
os.path.join(self.config["output_dir"], rel_model_path)
207+
)
208+
if self.config["resume"] and os.path.exists(
209+
os.path.join(dst_model_path, "model.py")
210+
):
211+
return
212+
213+
src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
214+
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
215+
graph_module_class = load_class_from_file(
216+
os.path.join(src_model_path, "model.py"), class_name="GraphModule"
217+
)
218+
input_arg_names, weight_arg_names = self._get_input_and_weight_arg_names(
219+
graph_module_class, src_model_path
220+
)
221+
222+
rename_map = {}
223+
for idx, name in enumerate(input_arg_names):
224+
rename_map[name] = f"in_{idx}"
225+
for idx, name in enumerate(weight_arg_names):
226+
rename_map[name] = f"w_{idx}"
227+
228+
with tempfile.TemporaryDirectory(prefix="graph_variable_renamer_") as temp_dir:
229+
temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
230+
shutil.copytree(src_model_path, temp_model_path, dirs_exist_ok=True)
231+
self._update_model_py_file(
232+
temp_model_path, rename_map, input_arg_names, weight_arg_names
233+
)
234+
self._update_meta_file(temp_model_path, "weight_meta.py", rename_map)
235+
self._update_meta_file(temp_model_path, "input_meta.py", rename_map)
236+
self._try_run(temp_model_path)
237+
shutil.copytree(temp_model_path, dst_model_path, dirs_exist_ok=True)
238+
239+
def _get_input_and_weight_arg_names(self, graph_module, model_path):
240+
input_arg_names = []
241+
weight_arg_names = []
242+
sig = inspect.signature(graph_module.forward)
243+
for name, param in sig.parameters.items():
244+
if name == "self":
245+
continue
246+
is_not_data_input = not self.data_input_predicator(model_path, name)
247+
if is_not_data_input:
248+
weight_arg_names.append(name)
249+
else:
250+
input_arg_names.append(name)
251+
return input_arg_names, weight_arg_names
252+
253+
def _update_model_py_file(
254+
self, model_path, rename_map, input_arg_names, weight_arg_names
255+
):
256+
model_file = Path(model_path) / "model.py"
257+
source = model_file.read_text(encoding="utf-8")
258+
tree = ast.parse(source)
259+
node = self._get_graph_module_ast(tree)
260+
graph_renamer = AstGraphRenamer(rename_map, input_arg_names, weight_arg_names)
261+
graph_renamer.visit(node)
262+
py_code = ast.unparse(tree)
263+
model_file.write_text(py_code, encoding="utf-8")
264+
file_hash = get_sha256_hash(py_code)
265+
(Path(model_path) / "graph_hash.txt").write_text(file_hash)
266+
267+
def _get_graph_module_ast(self, tree):
268+
for node in tree.body:
269+
if isinstance(node, ast.ClassDef) and node.name == "GraphModule":
270+
return node
271+
return None
272+
273+
def _update_meta_file(self, model_path, meta_filename, rename_map):
274+
meta_file = Path(model_path) / meta_filename
275+
tensor_metas = TensorMeta.unserialize_from_py_file(str(meta_file))
276+
for meta in tensor_metas:
277+
assert (
278+
meta.name in rename_map
279+
), f"[Warning] {meta.name} in {meta_filename} not found in rename_map."
280+
if meta.original_name is None:
281+
meta.original_name = meta.name
282+
meta.name = rename_map[meta.name]
283+
284+
py_code = "\n\n".join([meta.serialize_to_py_str() for meta in tensor_metas])
285+
meta_file.write_text(py_code)
286+
287+
def _try_run(self, model_path):
288+
(f"[AstGraphVariableRenamer] Try to run {model_path}")
289+
assert self.model_runnable_predicator(
290+
model_path
291+
), f"{model_path} is not a runnable model"

graph_net/torch/graph_variable_renamer.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,13 @@ def _make_config(
4444
output_dir="./tmp/graph_variable_renamer_dir",
4545
filter_path=None,
4646
filter_config=None,
47-
post_extract_process_path=None,
48-
post_extract_process_class_name=None,
49-
post_extract_process_config=None,
5047
data_input_predicator_class_name="DataInputPredicator",
5148
model_runnable_predicator_class_name="ModelRunner",
5249
data_input_predicator_config=None,
5350
model_runnable_predicator_config=None,
5451
model_path_prefix="",
5552
**kwargs,
5653
):
57-
if post_extract_process_config is None:
58-
post_extract_process_config = {}
5954
if data_input_predicator_config is None:
6055
data_input_predicator_config = {}
6156
if model_runnable_predicator_config is None:
@@ -65,9 +60,6 @@ def _make_config(
6560
"output_dir": output_dir,
6661
"filter_path": filter_path,
6762
"filter_config": filter_config if filter_config is not None else {},
68-
"post_extract_process_path": post_extract_process_path,
69-
"post_extract_process_class_name": post_extract_process_class_name,
70-
"post_extract_process_config": post_extract_process_config,
7163
"data_input_predicator_filepath": data_input_predicator_filepath,
7264
"data_input_predicator_class_name": data_input_predicator_class_name,
7365
"data_input_predicator_config": data_input_predicator_config,

graph_net/torch/test_compiler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,10 @@ def test_single_model(args):
204204
compiler = get_compiler_backend(args)
205205
input_dict = get_input_dict(args)
206206
model = get_model(args)
207-
207+
model_path = os.path.normpath(args.model_path)
208+
test_compiler_util.print_with_log_prompt(
209+
"[Processing]", model_path, args.log_prompt
210+
)
208211
test_compiler_util.print_basic_config(
209212
args, get_hardward_name(args), get_compile_framework_version(args)
210213
)

0 commit comments

Comments
 (0)