Skip to content

Commit 4321b73

Browse files
committed
code backup
2 parents c05b3c7 + df56c31 commit 4321b73

File tree

8 files changed

+428
-10
lines changed

8 files changed

+428
-10
lines changed
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
from graph_net.sample_pass.sample_pass import SamplePass
2+
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
3+
from pathlib import Path
4+
import os
5+
import shutil
6+
import tempfile
7+
import ast
8+
import inspect
9+
import torch
10+
from graph_net.imp_util import load_module
11+
from graph_net.tensor_meta import TensorMeta
12+
from graph_net.hash_util import get_sha256_hash
13+
14+
15+
class AstGraphVariableRenamer(SamplePass, ResumableSamplePassMixin):
    """Sample pass that rewrites a model's ``forward`` variables to canonical names.

    For every sample the pass copies the model directory into a temporary
    directory, renames the ``forward`` arguments of ``GraphModule`` to
    ``in_<i>`` (data inputs) or ``w_<i>`` (everything else), lets
    ``AstGraphRenamer`` rename the local variables, rewrites
    ``weight_meta.py`` / ``input_meta.py`` with the new names, optionally
    checks that the result is runnable, and finally copies the result into
    ``output_dir``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.data_input_predicator = self._make_data_input_predicator(self.config)
        self.model_runnable_predicator = self._make_model_runnable_predicator(
            self.config
        )

    def _load_configured_object(self, config, prefix):
        """Instantiate the class described by the ``<prefix>_*`` config keys.

        Reads ``<prefix>_filepath``, ``<prefix>_class_name`` and
        ``<prefix>_config`` from ``config``, loads the module, and returns
        ``cls(<prefix>_config)``.
        """
        module = load_module(config[f"{prefix}_filepath"])
        cls = getattr(module, config[f"{prefix}_class_name"])
        return cls(config[f"{prefix}_config"])

    def _make_data_input_predicator(self, config):
        """Build the callable deciding whether a ``forward`` argument is a data input."""
        return self._load_configured_object(config, "data_input_predicator")

    def _make_model_runnable_predicator(self, config):
        """Build the callable deciding whether a model directory is runnable."""
        return self._load_configured_object(config, "model_runnable_predicator")

    def declare_config(
        self,
        model_path_prefix: str,
        output_dir: str,
        device: str,
        resume: bool = False,
        try_run: bool = False,
        limits_handled_models: int = None,
        data_input_predicator_filepath: str = None,
        data_input_predicator_class_name: str = None,
        data_input_predicator_config: dict = None,
        model_runnable_predicator_filepath: str = None,
        model_runnable_predicator_class_name: str = None,
        model_runnable_predicator_config: dict = None,
    ):
        """Declare the configuration keys accepted by this pass.

        The body is intentionally empty; only the signature matters.
        NOTE(review): presumably the base class inspects this signature to
        validate / default the config -- confirm against ``SamplePass``.
        """
        pass

    def __call__(self, rel_model_path: str):
        """Handle one sample, identified by its path relative to ``model_path_prefix``."""
        self.resumable_handle_sample(rel_model_path)

    def sample_handled(self, rel_model_path: str) -> bool:
        """Return whether this sample was already processed (looks for ``model.py``)."""
        return self.naive_sample_handled(rel_model_path, search_file_name="model.py")

    def resume(self, rel_model_path: str):
        """Rename the variables of one sample and write the result to ``output_dir``.

        Raises:
            ValueError: if the sample's ``model.py`` defines no ``GraphModule``.
        """
        torch.cuda.empty_cache()
        dst_model_path = os.path.realpath(
            os.path.join(self.config["output_dir"], rel_model_path)
        )
        src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
        model_py_path = os.path.join(src_model_path, "model.py")
        graph_module_class = load_class_from_file(
            model_py_path, class_name="GraphModule"
        )
        if graph_module_class is None:
            # Fail with a clear message instead of an AttributeError on `.forward`.
            raise ValueError(f"GraphModule class not found in {model_py_path}")
        input_arg_names, weight_arg_names = self._get_input_and_weight_arg_names(
            graph_module_class, src_model_path
        )
        rename_map = self._create_rename_map(input_arg_names, weight_arg_names)
        # All edits happen on a temporary copy so that a failure (including a
        # failed try-run) never leaves a half-written sample in output_dir.
        with tempfile.TemporaryDirectory(prefix="graph_variable_renamer_") as temp_dir:
            temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
            shutil.copytree(src_model_path, temp_model_path, dirs_exist_ok=True)
            self._update_model_py_file(
                temp_model_path, rename_map, input_arg_names, weight_arg_names
            )
            self._update_meta_file(temp_model_path, "weight_meta.py", rename_map)
            self._update_meta_file(temp_model_path, "input_meta.py", rename_map)
            if self.config["try_run"]:
                self._try_run(temp_model_path)
            shutil.copytree(temp_model_path, dst_model_path, dirs_exist_ok=True)

    def _get_input_and_weight_arg_names(self, graph_module, model_path):
        """Split the ``forward`` parameters into data inputs and weights.

        Returns:
            ``(input_arg_names, weight_arg_names)``, each in signature order.
            A parameter is an input when ``data_input_predicator`` returns a
            truthy value for it, otherwise it is treated as a weight.
        """
        input_arg_names = []
        weight_arg_names = []
        for name in inspect.signature(graph_module.forward).parameters:
            if name == "self":
                continue
            if self.data_input_predicator(model_path, name):
                input_arg_names.append(name)
            else:
                weight_arg_names.append(name)
        return input_arg_names, weight_arg_names

    def _create_rename_map(self, input_arg_names, weight_arg_names):
        """Map inputs to ``in_<idx>`` and weights to ``w_<idx>``."""
        rename_map = {name: f"in_{idx}" for idx, name in enumerate(input_arg_names)}
        rename_map.update(
            {name: f"w_{idx}" for idx, name in enumerate(weight_arg_names)}
        )
        return rename_map

    def _update_model_py_file(
        self, model_path, rename_map, input_arg_names, weight_arg_names
    ):
        """Rewrite ``model.py`` with renamed variables and refresh ``graph_hash.txt``.

        ``rename_map`` is extended in place by ``AstGraphRenamer`` with the
        names it generates for local variables.

        Raises:
            ValueError: if ``model.py`` has no top-level ``GraphModule`` class.
        """
        model_file = Path(model_path) / "model.py"
        source = model_file.read_text(encoding="utf-8")
        tree = ast.parse(source)
        node = self._get_graph_module_ast(tree)
        if node is None:
            # NodeTransformer.visit(None) would fail with an obscure AttributeError.
            raise ValueError(f"GraphModule class not found in {model_file}")
        graph_renamer = AstGraphRenamer(rename_map, input_arg_names, weight_arg_names)
        graph_renamer.visit(node)
        py_code = ast.unparse(tree)
        model_file.write_text(py_code, encoding="utf-8")
        file_hash = get_sha256_hash(py_code)
        (Path(model_path) / "graph_hash.txt").write_text(file_hash)

    def _get_graph_module_ast(self, tree):
        """Return the top-level ``GraphModule`` class node of ``tree``, or ``None``."""
        for node in tree.body:
            if isinstance(node, ast.ClassDef) and node.name == "GraphModule":
                return node
        return None

    def _update_meta_file(self, model_path, meta_filename, rename_map):
        """Apply ``rename_map`` to every tensor meta stored in ``meta_filename``.

        The first rename records the previous name in ``original_name``; an
        already-set ``original_name`` is left untouched.
        """
        meta_file = Path(model_path) / meta_filename
        tensor_metas = TensorMeta.unserialize_from_py_file(str(meta_file))
        for meta in tensor_metas:
            assert (
                meta.name in rename_map
            ), f"[Warning] {meta.name} in {meta_filename} not found in rename_map."
            if meta.original_name is None:
                meta.original_name = meta.name
            meta.name = rename_map[meta.name]

        py_code = "\n\n".join([meta.serialize_to_py_str() for meta in tensor_metas])
        # Explicit encoding, consistent with how model.py is written.
        meta_file.write_text(py_code, encoding="utf-8")

    def _try_run(self, model_path):
        """Assert that the renamed model at ``model_path`` is still runnable."""
        print(f"[AstGraphVariableRenamer] Try to run {model_path}")
        assert self.model_runnable_predicator(
            model_path
        ), f"{model_path} is not a runnable model"
143+
144+
def load_class_from_file(file_path: str, class_name: str):
    """Load the module at ``file_path`` and return its ``class_name`` attribute.

    Returns ``None`` when the loaded module has no attribute of that name.
    """
    print(f"Load {class_name} from {file_path}")
    loaded_module = load_module(file_path, "unnamed_graph_module")
    return getattr(loaded_module, class_name, None)
149+
150+
151+
class AstGraphRenamer(ast.NodeTransformer):
    """AST transformer that canonicalises variable names inside ``forward``.

    Names found in ``rename_map`` are replaced by their mapped names. Every
    other name that is assigned inside ``forward`` receives a fresh
    ``tmp_<n>`` name. Top-level statements of ``forward`` that merely clear an
    input/weight argument (``arg = None`` or ``del arg``) are removed.
    Functions not named ``forward`` are left untouched.

    Note:
        ``rename_map`` is not copied; the generated ``tmp_<n>`` entries are
        added to the dict passed in by the caller.
    """

    def __init__(self, rename_map, input_arg_names, weight_arg_names):
        self.rename_map = rename_map
        self.input_and_weight_arg_names = set(input_arg_names) | set(weight_arg_names)
        self.counters = {"tmp": 0}
        # True only while the body of ``forward`` is being transformed.
        self.in_forward = False

    def visit_FunctionDef(self, node):
        """Rename the arguments and body of ``forward``; skip any other function."""
        if node.name != "forward":
            return node
        self.in_forward = True
        try:
            node.args.args = self._rename_function_args(node.args.args)
            node.body = self._rename_function_body(node.body)
        finally:
            # Reset the flag even when a visitor raises.
            self.in_forward = False
        return node

    def _rename_function_args(self, args):
        """Return the argument list with every non-``self`` argument renamed."""
        return [
            arg if arg.arg == "self" else self._create_renamed_arg(arg) for arg in args
        ]

    def _create_renamed_arg(self, arg):
        """Return a renamed copy of ``arg``, or ``arg`` itself if it is not mapped."""
        if arg.arg not in self.rename_map:
            return arg
        renamed = ast.arg(arg=self.rename_map[arg.arg], annotation=arg.annotation)
        return ast.copy_location(renamed, arg)

    def _rename_function_body(self, body):
        """Drop argument-clearing statements and rename the remaining ones."""
        renamed_body = []
        for stmt in body:
            stmt = self._remove_clear_stmt_of_args(stmt)
            if stmt is None:
                continue
            renamed_body.append(self.visit(stmt))
        return renamed_body

    def _remove_clear_stmt_of_args(self, stmt):
        """Strip input/weight targets from ``x = None`` and ``del x`` statements.

        Returns the (possibly trimmed) statement, or ``None`` when nothing is
        left of it. Any other kind of statement is returned unchanged.
        """
        # remove stmt like w_0 = None
        if self._is_assign_none(stmt):
            return self._clean_assign_none(stmt)
        # remove stmt like del w_0
        if isinstance(stmt, ast.Delete):
            return self._clean_delete(stmt)
        return stmt

    def _is_assign_none(self, stmt):
        """Return True for a plain assignment whose value is the constant ``None``."""
        return (
            isinstance(stmt, ast.Assign)
            and isinstance(stmt.value, ast.Constant)
            and stmt.value.value is None
        )

    def _clean_assign_none(self, stmt):
        """Remove input/weight targets from ``... = None``; ``None`` if none remain."""
        new_targets = [t for t in stmt.targets if not self._is_input_or_weight_var(t)]
        if not new_targets:
            return None
        stmt.targets = new_targets
        return stmt

    def _is_input_or_weight_var(self, target):
        """Return True when ``target`` is a bare name of an input or weight argument."""
        return (
            isinstance(target, ast.Name)
            and target.id in self.input_and_weight_arg_names
        )

    def _clean_delete(self, stmt):
        """Remove input/weight targets from a ``del`` statement; ``None`` if none remain."""
        new_targets = []
        for target in stmt.targets:
            kept = self._filter_delete_target(target)
            if kept is not None:
                new_targets.append(kept)

        if not new_targets:
            return None
        stmt.targets = new_targets
        return stmt

    def _filter_delete_target(self, target):
        """Return the part of a ``del`` target that must be kept, or ``None``.

        The check goes through ``_is_input_or_weight_var``: the previously
        referenced ``_is_protected_var`` does not exist on this class, so
        every ``del`` statement raised ``AttributeError``.
        """
        if isinstance(target, ast.Tuple):  # del (a, b)
            kept_elts = [
                e for e in target.elts if not self._is_input_or_weight_var(e)
            ]
            return ast.Tuple(elts=kept_elts, ctx=ast.Del()) if kept_elts else None
        if self._is_input_or_weight_var(target):  # del a
            return None
        return target

    def visit_Assign(self, node):
        """Register newly assigned names as ``tmp_<n>`` before renaming the statement."""
        if not self.in_forward:
            return node
        self._register_new_local_variables(node.targets)
        self.generic_visit(node)
        return node

    def _register_new_local_variables(self, targets):
        """Give a ``tmp_<n>`` name to every not-yet-mapped name bound by ``targets``."""
        for target in targets:
            for name in self._flatten_assignment_target(target):
                self._register_if_unknown(name)

    def _flatten_assignment_target(self, target):
        """Yield the plain names bound by ``target``, descending into tuples/lists.

        Other target kinds (attributes, subscripts, ...) yield nothing.
        """
        if isinstance(target, ast.Name):
            yield target.id
        elif isinstance(target, (ast.Tuple, ast.List)):
            for elt in target.elts:
                yield from self._flatten_assignment_target(elt)

    def _register_if_unknown(self, name):
        """Map ``name`` to the next ``tmp_<n>`` unless it is already in ``rename_map``."""
        if name not in self.rename_map:
            new_name = f"tmp_{self.counters['tmp']}"
            self.counters["tmp"] += 1
            self.rename_map[name] = new_name

    def visit_Name(self, node):
        """Replace a mapped name inside ``forward``; leave every other name alone."""
        if not self.in_forward:
            return node
        if node.id in self.rename_map:
            renamed = ast.Name(id=self.rename_map[node.id], ctx=node.ctx)
            return ast.copy_location(renamed, node)
        return node
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#!/bin/bash
# Run AstGraphVariableRenamer over the sample list, validate the renamed
# models with test_compiler, then plot the validation results.
# Logs and outputs are written under $RENAMED_PATH.

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
RENAMED_PATH=/tmp/ast_graph_variable_rename_workspace

mkdir -p "$RENAMED_PATH"
model_list="$GRAPH_NET_ROOT/graph_net/config/small100_torch_samples_list.txt"

# Step 1: rename variables of every model in the list.
# Expansions are quoted so that an install path containing spaces is not
# split into several arguments.
python3 -m graph_net.model_path_handler \
    --model-path-list "$model_list" \
    --handler-config="$(base64 -w 0 <<EOF
{
    "handler_path": "$GRAPH_NET_ROOT/graph_net/sample_pass/ast_graph_variable_renamer.py",
    "handler_class_name": "AstGraphVariableRenamer",
    "handler_config": {
        "device": "cuda",
        "resume": true,
        "try_run": true,
        "model_path_prefix": "$GRAPH_NET_ROOT/",
        "data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
        "data_input_predicator_class_name": "NaiveDataInputPredicator",
        "model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
        "model_runnable_predicator_class_name": "ModelRunnablePredicator",
        "output_dir": "$RENAMED_PATH"
    }
}
EOF
)" \
    2>&1 | tee "$RENAMED_PATH/graph_rename.log"

# Step 2: validate the renamed models against the originals.
python3 -m graph_net.torch.test_compiler \
    --model-path-prefix "$GRAPH_NET_ROOT" \
    --allow-list "$model_list" \
    --compiler graph_variable_renamer_validator \
    --device cuda \
    --config "$(base64 -w 0 <<EOF
{
    "model_path_prefix": "$GRAPH_NET_ROOT",
    "renamed_root": "$RENAMED_PATH"
}
EOF
)" \
    2>&1 | tee "$RENAMED_PATH/validation.log"

# Step 3: plot the results recorded in the validation log.
python3 -m graph_net.plot_ESt \
    --benchmark-path "$RENAMED_PATH/validation.log" \
    --output-dir "$RENAMED_PATH"

graph_net/test/graph_variable_rename_test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ python3 -m graph_net.model_path_handler \
1515
"handler_config": {
1616
"device": "cuda",
1717
"resume": true,
18+
"try_run": true,
1819
"model_path_prefix": "$GRAPH_NET_ROOT/",
1920
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
2021
"data_input_predicator_class_name": "NaiveDataInputPredicator",

graph_net/tools/dimension_symbolizer.sh

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,25 @@
11
#!/bin/bash
22

3+
# GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
4+
# os.path.dirname(os.path.dirname(graph_net.__file__)))")
5+
6+
# python3 -m graph_net.model_path_handler \
7+
# --model-path-list $GRAPH_NET_ROOT/graph_net/config/small_sample_list_for_get_fusible_subgraph.txt \
8+
# --handler-config=$(base64 -w 0 <<EOF
9+
# {
10+
# "handler_path": "$GRAPH_NET_ROOT/graph_net/torch/sample_passes/dimension_symbolizer.py",
11+
# "handler_class_name": "DimensionSymbolizer",
12+
# "handler_config": {
13+
# "resume": false,
14+
# "output_dir": "/tmp/workspace_dimension_symbolizer",
15+
# "model_path_prefix": "$GRAPH_NET_ROOT",
16+
# "limits_handled_models": 10,
17+
# "last_model_log_file": "/tmp/a.py"
18+
# }
19+
# }
20+
# EOF
21+
# )
22+
323
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
424
os.path.dirname(os.path.dirname(graph_net.__file__)))")
525

0 commit comments

Comments
 (0)