Skip to content

Commit e6b5544

Browse files
authored
[Feature Enhancement] Add AST-based graph variable renamer (#477)
* [Feature Enhancement] AST-based graph variable renamer * fix * refactor AstGraphVariableRenamer with samplepass
1 parent 0fefebe commit e6b5544

File tree

4 files changed

+323
-9
lines changed

4 files changed

+323
-9
lines changed
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
from graph_net.sample_pass.sample_pass import SamplePass
2+
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
3+
from pathlib import Path
4+
import os
5+
import shutil
6+
import tempfile
7+
import ast
8+
import inspect
9+
import torch
10+
from graph_net.imp_util import load_module
11+
from graph_net.tensor_meta import TensorMeta
12+
from graph_net.hash_util import get_sha256_hash
13+
14+
15+
class AstGraphVariableRenamer(SamplePass, ResumableSamplePassMixin):
    """Sample pass that rewrites a sample's ``model.py`` with canonical names.

    For one sample directory (``model.py``, ``weight_meta.py`` and
    ``input_meta.py``) the pass:

    1. splits the arguments of ``GraphModule.forward`` into data inputs and
       weights using the configured data-input predicator,
    2. maps them to ``in_<i>`` / ``w_<i>`` and rewrites ``model.py`` through
       ``AstGraphRenamer`` (which also renames assigned locals to ``tmp_<i>``),
    3. applies the same mapping to both meta files,
    4. checks the rewritten sample with the model-runnable predicator, and
    5. copies the result into ``output_dir``.

    All rewriting happens in a temporary directory, so a failure in any step
    leaves ``output_dir`` untouched for that sample.
    """

    def __init__(self, config):
        super().__init__(config)
        self.data_input_predicator = self._make_data_input_predicator(self.config)
        self.model_runnable_predicator = self._make_model_runnable_predicator(
            self.config
        )

    def _make_data_input_predicator(self, config):
        """Instantiate the data-input predicator class named in ``config``."""
        module = load_module(config["data_input_predicator_filepath"])
        cls = getattr(module, config["data_input_predicator_class_name"])
        return cls(config["data_input_predicator_config"])

    def _make_model_runnable_predicator(self, config):
        """Instantiate the model-runnable predicator class named in ``config``."""
        module = load_module(config["model_runnable_predicator_filepath"])
        cls = getattr(module, config["model_runnable_predicator_class_name"])
        return cls(config["model_runnable_predicator_config"])

    def declare_config(
        self,
        model_path_prefix: str,
        output_dir: str,
        device: str,
        resume: bool = False,
        limits_handled_models: int = None,
        data_input_predicator_filepath: str = None,
        data_input_predicator_class_name: str = None,
        data_input_predicator_config: dict = None,
        model_runnable_predicator_filepath: str = None,
        model_runnable_predicator_class_name: str = None,
        model_runnable_predicator_config: dict = None,
    ):
        # Intentionally empty: only the signature matters here.
        # NOTE(review): presumably SamplePass inspects this signature to
        # validate/default the handler config -- confirm against SamplePass.
        pass

    def __call__(self, rel_model_path: str):
        """Process one sample, identified by its path relative to the prefix."""
        self.resumable_handle_sample(rel_model_path)

    def sample_handled(self, rel_model_path: str) -> bool:
        """Report whether this sample was already handled (keyed on model.py)."""
        return self.naive_sample_handled(rel_model_path, search_file_name="model.py")

    def resume(self, rel_model_path: str):
        """Rename one sample and publish the result under ``output_dir``.

        Raises:
            AssertionError: if a meta entry has no rename target, or if the
                rewritten model is rejected by the model-runnable predicator.
            ValueError: if ``model.py`` has no top-level ``GraphModule`` class.
        """
        torch.cuda.empty_cache()
        dst_model_path = os.path.realpath(
            os.path.join(self.config["output_dir"], rel_model_path)
        )
        src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
        graph_module_class = load_class_from_file(
            os.path.join(src_model_path, "model.py"), class_name="GraphModule"
        )
        input_arg_names, weight_arg_names = self._get_input_and_weight_arg_names(
            graph_module_class, src_model_path
        )
        rename_map = self._create_rename_map(input_arg_names, weight_arg_names)
        with tempfile.TemporaryDirectory(prefix="graph_variable_renamer_") as temp_dir:
            temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
            shutil.copytree(src_model_path, temp_model_path, dirs_exist_ok=True)
            self._update_model_py_file(
                temp_model_path, rename_map, input_arg_names, weight_arg_names
            )
            self._update_meta_file(temp_model_path, "weight_meta.py", rename_map)
            self._update_meta_file(temp_model_path, "input_meta.py", rename_map)
            # Validate before publishing so a broken rewrite never reaches
            # the output directory.
            self._try_run(temp_model_path)
            shutil.copytree(temp_model_path, dst_model_path, dirs_exist_ok=True)

    def _get_input_and_weight_arg_names(self, graph_module, model_path):
        """Split the parameters of ``graph_module.forward`` into two lists.

        Returns:
            ``(input_arg_names, weight_arg_names)``, each in signature order.
            A parameter is a data input when the data-input predicator returns
            a truthy value for ``(model_path, name)``; otherwise it is a weight.
        """
        input_arg_names = []
        weight_arg_names = []
        sig = inspect.signature(graph_module.forward)
        for name in sig.parameters:
            if name == "self":
                continue
            if self.data_input_predicator(model_path, name):
                input_arg_names.append(name)
            else:
                weight_arg_names.append(name)
        return input_arg_names, weight_arg_names

    def _create_rename_map(self, input_arg_names, weight_arg_names):
        """Map inputs to ``in_<idx>`` and weights to ``w_<idx>`` by position."""
        rename_map = {}
        for idx, name in enumerate(input_arg_names):
            rename_map[name] = f"in_{idx}"
        for idx, name in enumerate(weight_arg_names):
            rename_map[name] = f"w_{idx}"
        return rename_map

    def _update_model_py_file(
        self, model_path, rename_map, input_arg_names, weight_arg_names
    ):
        """Rewrite ``model.py`` in place and refresh ``graph_hash.txt``.

        ``rename_map`` is handed to ``AstGraphRenamer`` by reference, so after
        this call it also contains the ``tmp_<i>`` entries added for locals.

        Raises:
            ValueError: if ``model.py`` has no top-level ``GraphModule`` class.
        """
        model_file = Path(model_path) / "model.py"
        source = model_file.read_text(encoding="utf-8")
        tree = ast.parse(source)
        node = self._get_graph_module_ast(tree)
        if node is None:
            # Visiting None would fail deep inside ast with an obscure
            # AttributeError; fail here with an actionable message instead.
            raise ValueError(f"No top-level class GraphModule found in {model_file}")
        graph_renamer = AstGraphRenamer(rename_map, input_arg_names, weight_arg_names)
        graph_renamer.visit(node)
        py_code = ast.unparse(tree)
        model_file.write_text(py_code, encoding="utf-8")
        # The hash is computed over the rewritten source text.
        file_hash = get_sha256_hash(py_code)
        (Path(model_path) / "graph_hash.txt").write_text(file_hash)

    def _get_graph_module_ast(self, tree):
        """Return the top-level ``GraphModule`` ClassDef, or ``None`` if absent."""
        for node in tree.body:
            if isinstance(node, ast.ClassDef) and node.name == "GraphModule":
                return node
        return None

    def _update_meta_file(self, model_path, meta_filename, rename_map):
        """Apply ``rename_map`` to every tensor meta in ``meta_filename``.

        The first rename of an entry stores its pre-rename name in
        ``original_name``; an already-set ``original_name`` is left alone.

        Raises:
            AssertionError: if a meta name has no entry in ``rename_map``.
        """
        meta_file = Path(model_path) / meta_filename
        tensor_metas = TensorMeta.unserialize_from_py_file(str(meta_file))
        for meta in tensor_metas:
            assert (
                meta.name in rename_map
            ), f"[Warning] {meta.name} in {meta_filename} not found in rename_map."
            if meta.original_name is None:
                meta.original_name = meta.name
            meta.name = rename_map[meta.name]

        py_code = "\n\n".join([meta.serialize_to_py_str() for meta in tensor_metas])
        meta_file.write_text(py_code)

    def _try_run(self, model_path):
        """Assert that the rewritten sample passes the runnable predicator.

        Raises:
            AssertionError: if the predicator returns a falsy value.
        """
        # The original built this message but never printed it.
        print(f"[AstGraphVariableRenamer] Try to run {model_path}")
        assert self.model_runnable_predicator(
            model_path
        ), f"{model_path} is not a runnable model"
140+
141+
142+
def load_class_from_file(file_path: str, class_name: str):
143+
print(f"Load {class_name} from {file_path}")
144+
module = load_module(file_path, "unnamed_graph_module")
145+
model_class = getattr(module, class_name, None)
146+
return model_class
147+
148+
149+
class AstGraphRenamer(ast.NodeTransformer):
150+
def __init__(self, rename_map, input_arg_names, weight_arg_names):
151+
self.rename_map = rename_map
152+
self.input_and_weight_arg_names = set(input_arg_names) | set(weight_arg_names)
153+
self.counters = {"tmp": 0}
154+
self.in_forward = False
155+
156+
def visit_FunctionDef(self, node):
157+
if node.name != "forward":
158+
return node
159+
self.in_forward = True
160+
node.args.args = self._rename_function_args(node.args.args)
161+
node.body = self._rename_function_body(node.body)
162+
self.in_forward = False
163+
return node
164+
165+
def _rename_function_args(self, args):
166+
new_function_args = []
167+
for arg in args:
168+
if arg.arg == "self":
169+
new_function_args.append(arg)
170+
else:
171+
new_function_args.append(self._create_renamed_arg(arg))
172+
return new_function_args
173+
174+
def _create_renamed_arg(self, arg):
175+
if arg.arg in self.rename_map:
176+
return ast.arg(arg=self.rename_map[arg.arg], annotation=arg.annotation)
177+
return arg
178+
179+
def _rename_function_body(self, body):
180+
new_function_body = []
181+
for stmt in body:
182+
stmt = self._remove_clear_stmt_of_args(stmt)
183+
if stmt:
184+
stmt = self.visit(stmt)
185+
new_function_body.append(stmt)
186+
return new_function_body
187+
188+
def _remove_clear_stmt_of_args(self, stmt):
189+
# remove stmt like w_0 = None
190+
if self._is_assign_none(stmt):
191+
return self._clean_assign_none(stmt)
192+
# remove stmt like del w_0
193+
elif isinstance(stmt, ast.Delete):
194+
return self._clean_delete(stmt)
195+
else:
196+
pass
197+
return stmt
198+
199+
def _is_assign_none(self, stmt):
200+
return (
201+
isinstance(stmt, ast.Assign)
202+
and isinstance(stmt.value, ast.Constant)
203+
and stmt.value.value is None
204+
)
205+
206+
def _clean_assign_none(self, stmt):
207+
new_targets = [t for t in stmt.targets if not self._is_input_or_weight_var(t)]
208+
if not new_targets:
209+
return None
210+
stmt.targets = new_targets
211+
return stmt
212+
213+
def _is_input_or_weight_var(self, target):
214+
return (
215+
isinstance(target, ast.Name)
216+
and target.id in self.input_and_weight_arg_names
217+
)
218+
219+
def _clean_delete(self, stmt):
220+
new_targets = []
221+
for target in stmt.targets:
222+
kept = self._filter_delete_target(target)
223+
if kept:
224+
new_targets.append(kept)
225+
226+
if not new_targets:
227+
return None
228+
stmt.targets = new_targets
229+
return stmt
230+
231+
def _filter_delete_target(self, target):
232+
if isinstance(target, ast.Tuple): # del (a, b)
233+
kept_elts = [e for e in target.elts if not self._is_protected_var(e)]
234+
return ast.Tuple(elts=kept_elts, ctx=ast.Del()) if kept_elts else None
235+
elif not self._is_protected_var(target): # del a
236+
return target
237+
else:
238+
pass
239+
return None
240+
241+
def visit_Assign(self, node):
242+
if not self.in_forward:
243+
return node
244+
self._register_new_local_variables(node.targets)
245+
self.generic_visit(node)
246+
return node
247+
248+
def _register_new_local_variables(self, targets):
249+
for target in targets:
250+
for name in self._flatten_assignment_target(target):
251+
self._register_if_unknown(name)
252+
253+
def _flatten_assignment_target(self, target):
254+
if isinstance(target, ast.Name):
255+
yield target.id
256+
elif isinstance(target, (ast.Tuple, ast.List)):
257+
for elt in target.elts:
258+
yield from self._flatten_assignment_target(elt)
259+
else:
260+
pass
261+
262+
def _register_if_unknown(self, name):
263+
if name not in self.rename_map:
264+
new_name = f"tmp_{self.counters['tmp']}"
265+
self.counters["tmp"] += 1
266+
self.rename_map[name] = new_name
267+
268+
def visit_Name(self, node):
269+
if not self.in_forward:
270+
return node
271+
if node.id in self.rename_map:
272+
return ast.Name(id=self.rename_map[node.id], ctx=node.ctx)
273+
return node
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#!/bin/bash
# Driver for the AST-based graph variable renamer:
#   1. rename every sample in the model list,
#   2. validate the renamed samples with the test_compiler backend
#      "graph_variable_renamer_validator",
#   3. plot the results from the validation log.
# All expansions are quoted so paths containing spaces or glob characters
# are passed through as single arguments.

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
RENAMED_PATH=/tmp/ast_graph_variable_rename_workspace

mkdir -p "$RENAMED_PATH"
model_list="$GRAPH_NET_ROOT/graph_net/config/small100_torch_samples_list.txt"

# Step 1: run the renamer; the handler config is passed as base64-encoded JSON.
python3 -m graph_net.model_path_handler \
    --model-path-list "$model_list" \
    --handler-config="$(base64 -w 0 <<EOF
{
    "handler_path": "$GRAPH_NET_ROOT/graph_net/sample_pass/ast_graph_variable_renamer.py",
    "handler_class_name": "AstGraphVariableRenamer",
    "handler_config": {
        "device": "cuda",
        "resume": true,
        "model_path_prefix": "$GRAPH_NET_ROOT/",
        "data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
        "data_input_predicator_class_name": "NaiveDataInputPredicator",
        "model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
        "model_runnable_predicator_class_name": "ModelRunnablePredicator",
        "output_dir": "$RENAMED_PATH"
    }
}
EOF
)" \
    2>&1 | tee "$RENAMED_PATH/graph_rename.log"

# Step 2: validate the renamed samples.
python3 -m graph_net.torch.test_compiler \
    --model-path-prefix "$GRAPH_NET_ROOT" \
    --allow-list "$model_list" \
    --compiler graph_variable_renamer_validator \
    --device cuda \
    --config "$(base64 -w 0 <<EOF
{
    "model_path_prefix": "$GRAPH_NET_ROOT",
    "renamed_root": "$RENAMED_PATH"
}
EOF
)" \
    2>&1 | tee "$RENAMED_PATH/validation.log"

# Step 3: plot from the validation log.
python3 -m graph_net.plot_ESt \
    --benchmark-path "$RENAMED_PATH/validation.log" \
    --output-dir "$RENAMED_PATH"

graph_net/torch/graph_variable_renamer.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,13 @@ def _make_config(
4444
output_dir="./tmp/graph_variable_renamer_dir",
4545
filter_path=None,
4646
filter_config=None,
47-
post_extract_process_path=None,
48-
post_extract_process_class_name=None,
49-
post_extract_process_config=None,
5047
data_input_predicator_class_name="DataInputPredicator",
5148
model_runnable_predicator_class_name="ModelRunner",
5249
data_input_predicator_config=None,
5350
model_runnable_predicator_config=None,
5451
model_path_prefix="",
5552
**kwargs,
5653
):
57-
if post_extract_process_config is None:
58-
post_extract_process_config = {}
5954
if data_input_predicator_config is None:
6055
data_input_predicator_config = {}
6156
if model_runnable_predicator_config is None:
@@ -65,9 +60,6 @@ def _make_config(
6560
"output_dir": output_dir,
6661
"filter_path": filter_path,
6762
"filter_config": filter_config if filter_config is not None else {},
68-
"post_extract_process_path": post_extract_process_path,
69-
"post_extract_process_class_name": post_extract_process_class_name,
70-
"post_extract_process_config": post_extract_process_config,
7163
"data_input_predicator_filepath": data_input_predicator_filepath,
7264
"data_input_predicator_class_name": data_input_predicator_class_name,
7365
"data_input_predicator_config": data_input_predicator_config,

graph_net/torch/test_compiler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,10 @@ def test_single_model(args):
204204
compiler = get_compiler_backend(args)
205205
input_dict = get_input_dict(args)
206206
model = get_model(args)
207-
207+
model_path = os.path.normpath(args.model_path)
208+
test_compiler_util.print_with_log_prompt(
209+
"[Processing]", model_path, args.log_prompt
210+
)
208211
test_compiler_util.print_basic_config(
209212
args, get_hardward_name(args), get_compile_framework_version(args)
210213
)

0 commit comments

Comments
 (0)