Skip to content

Commit 8e928e2

Browse files
authored
[Feature Enhancement] Rename torch graph variable. (#428)
* [Feature Enhancement] Rename torch graph variable
* fix
1 parent eef75f1 commit 8e928e2

File tree

3 files changed

+231
-0
lines changed

3 files changed

+231
-0
lines changed

graph_net/tensor_meta.py

100644 → 100755
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ def serialize_to_py_str(self) -> str:
5858
lines = [
5959
(f"class {self.record_class_name}:"),
6060
(f'\tname = "{self.name}"'),
61+
*(
62+
[f'\toriginal_name = "{self.original_name}"']
63+
if self.original_name is not None
64+
else []
65+
),
6166
(f"\tshape = {self.shape}"),
6267
(f'\tdtype = "{self.dtype}"'),
6368
(f'\tdevice = "{self.device}"'),
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#!/bin/bash
# Runs graph_net.model_path_handler with the GraphVariableRenamer handler on a
# single sample model. The handler configuration is passed as base64-encoded
# JSON through --handler-config.

# Abort on the first failure (for example when graph_net cannot be imported)
# instead of continuing with an empty GRAPH_NET_ROOT.
set -euo pipefail

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
os.path.dirname(graph_net.__file__))")

# input model path
MODEL_NAME=resnet18
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
config_json_str=$(cat <<EOF
{
    "handler_path": "$GRAPH_NET_ROOT/torch/graph_variable_renamer.py",
    "handler_class_name": "GraphVariableRenamer",
    "handler_config": {
        "model_path_prefix": "$GRAPH_NET_ROOT/../",
        "data_input_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
        "data_input_predicator_class_name": "NaiveDataInputPredicator",
        "model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
        "model_runnable_predicator_class_name": "ModelRunnablePredicator",
        "output_dir": "/tmp/graph_variable_rename_workspace"
    }
}
EOF
)
# Quote the JSON so the shell neither word-splits nor glob-expands it before it
# is encoded. NOTE: `base64 -w 0` (disable line wrapping) is the GNU coreutils
# spelling.
CONFIG=$(printf '%s\n' "$config_json_str" | base64 -w 0)

python3 -m graph_net.model_path_handler --model-path "samples/$MODEL_PATH_IN_SAMPLES" --handler-config="$CONFIG"
# python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/decomposition_error_tmp_torch_samples_list.txt --handler-config=$CONFIG
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import os
2+
import torch
3+
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
4+
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
5+
from graph_net.tensor_meta import TensorMeta
6+
from pathlib import Path
7+
import shutil
8+
from graph_net.torch.utils import apply_templates
9+
from graph_net.imp_util import load_module
10+
import inspect
11+
12+
13+
class GraphVariableRenamer:
    """
    Used by graph_net.model_path_handler

    Rewrites one sample model so that its graph variables follow a uniform
    naming scheme: ``in_<i>`` for data-input placeholders, ``w_<i>`` for
    weight placeholders and ``get_attr`` nodes, and ``tmp_<i>`` for every
    other node except the output. The renamed sample is written under
    ``config["output_dir"]`` together with updated ``model.py``,
    ``weight_meta.py`` and ``input_meta.py`` files.
    """

    def __init__(self, config: dict = None):
        """
        Build the handler from a plain configuration dict.

        ``config`` is expanded into keyword arguments of ``_make_config``.
        That method requires ``data_input_predicator_filepath`` and
        ``model_runnable_predicator_filepath``, so passing ``None`` or an
        empty dict raises ``TypeError``.
        """
        if config is None:
            config = {}
        self.config = self._make_config(**config)
        self.data_input_predicator = self._make_data_input_predicator(self.config)
        self.model_runnable_predicator = self._make_model_runnable_predicator(
            self.config
        )

    @staticmethod
    def _load_configured_object(config, prefix):
        """Instantiate the class described by the ``<prefix>_*`` config keys."""
        module = load_module(config[f"{prefix}_filepath"])
        cls = getattr(module, config[f"{prefix}_class_name"])
        return cls(config[f"{prefix}_config"])

    def _make_data_input_predicator(self, config):
        """Create the callable that decides whether a placeholder is a data input."""
        return self._load_configured_object(config, "data_input_predicator")

    def _make_model_runnable_predicator(self, config):
        """Create the callable that decides whether a written model is runnable."""
        return self._load_configured_object(config, "model_runnable_predicator")

    def _make_config(
        self,
        data_input_predicator_filepath,
        model_runnable_predicator_filepath,
        output_dir="./tmp/graph_variable_renamer_dir",
        filter_path=None,
        filter_config=None,
        post_extract_process_path=None,
        post_extract_process_class_name=None,
        post_extract_process_config=None,
        data_input_predicator_class_name="DataInputPredicator",
        model_runnable_predicator_class_name="ModelRunner",
        data_input_predicator_config=None,
        model_runnable_predicator_config=None,
        model_path_prefix="",
        **kwargs,
    ):
        """
        Normalize handler options into a complete configuration dict.

        Every ``*_config`` option left as ``None`` becomes a fresh empty dict.
        Unrecognized keyword arguments are accepted and ignored.
        """

        def dict_or_empty(value):
            return {} if value is None else value

        return {
            "output_dir": output_dir,
            "filter_path": filter_path,
            "filter_config": dict_or_empty(filter_config),
            "post_extract_process_path": post_extract_process_path,
            "post_extract_process_class_name": post_extract_process_class_name,
            "post_extract_process_config": dict_or_empty(post_extract_process_config),
            "data_input_predicator_filepath": data_input_predicator_filepath,
            "data_input_predicator_class_name": data_input_predicator_class_name,
            "data_input_predicator_config": dict_or_empty(
                data_input_predicator_config
            ),
            "model_runnable_predicator_filepath": model_runnable_predicator_filepath,
            "model_runnable_predicator_class_name": model_runnable_predicator_class_name,
            "model_runnable_predicator_config": dict_or_empty(
                model_runnable_predicator_config
            ),
            "model_path_prefix": model_path_prefix,
        }

    def __call__(self, rel_model_path):
        """
        Rename the graph variables of the model at ``rel_model_path``.

        The source model is read from ``model_path_prefix / rel_model_path``.
        Its directory is copied to ``output_dir / rel_model_path``, the
        generated files are overwritten with the renamed versions, and the
        result is finally checked with the runnable predicator.
        """
        src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
        module, inputs = get_torch_module_and_inputs(src_model_path)
        gm = parse_sole_graph_module(module, inputs)
        gm = self.rename_graph_variables(gm, inputs, src_model_path)
        dst_model_path = os.path.realpath(
            os.path.join(self.config["output_dir"], rel_model_path)
        )
        Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
        shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True)
        # model.py must be rewritten first: the meta updates derive the
        # old-name -> new-name mapping from the destination model.py.
        self._update_model_py_file(gm, dst_model_path)
        self._update_weight_meta_py_file(src_model_path, dst_model_path)
        self._update_input_meta_py_file(src_model_path, dst_model_path)
        self._try_run(dst_model_path)

    def _try_run(self, model_path):
        """
        Raise ``AssertionError`` unless the runnable predicator accepts the model.

        An explicit raise is used instead of an ``assert`` statement so the
        predicator is still invoked when Python runs with ``-O``.
        """
        if not self.model_runnable_predicator(model_path):
            raise AssertionError(f"{model_path} is not a runnable model")

    def _update_model_py_file(self, graph_module, model_path):
        """Write the templated code of ``graph_module`` to ``model_path/model.py``."""
        py_code = apply_templates(graph_module.code)
        (Path(model_path) / "model.py").write_text(py_code)

    def _rename_tensor_metas_in_py_file(self, src_model_path, dst_model_path, file_name):
        """
        Rewrite one tensor-meta file with the renamed variable names.

        Metas are read from ``src_model_path/file_name``. Each meta keeps its
        first known name in ``original_name`` and receives its new name, then
        all metas are serialized to ``dst_model_path/file_name``.

        Raises ``AssertionError`` when a meta name does not appear in the
        source model's ``forward`` signature.
        """
        old_name_to_new_name = self._get_original_name_to_new_name(
            src_model_path, dst_model_path
        )
        tensor_metas = TensorMeta.unserialize_from_py_file(
            os.path.join(src_model_path, file_name),
        )
        for tensor_meta in tensor_metas:
            if tensor_meta.name not in old_name_to_new_name:
                raise AssertionError(
                    f"{tensor_meta.name} from {file_name} is not an input of "
                    f"{src_model_path}"
                )
            # Keep the earliest name so repeated renames stay traceable.
            if tensor_meta.original_name is None:
                tensor_meta.original_name = tensor_meta.name
            tensor_meta.name = old_name_to_new_name[tensor_meta.name]
        py_code = "\n\n".join(
            [tensor_meta.serialize_to_py_str() for tensor_meta in tensor_metas]
        )
        (Path(dst_model_path) / file_name).write_text(py_code)

    def _update_weight_meta_py_file(self, src_model_path, dst_model_path):
        """Rewrite ``weight_meta.py`` in the destination with the new names."""
        self._rename_tensor_metas_in_py_file(
            src_model_path, dst_model_path, "weight_meta.py"
        )

    def _update_input_meta_py_file(self, src_model_path, dst_model_path):
        """Rewrite ``input_meta.py`` in the destination with the new names."""
        self._rename_tensor_metas_in_py_file(
            src_model_path, dst_model_path, "input_meta.py"
        )

    def _get_original_name_to_new_name(self, src_model_path, dst_model_path):
        """
        Map each source ``forward`` parameter name to the destination one.

        Parameters are paired by position; ``strict=True`` makes a differing
        parameter count raise ``ValueError``.
        """
        src_model = self._get_model(src_model_path)
        dst_model = self._get_model(dst_model_path)
        return dict(
            zip(
                self._get_input_names_from_signature(src_model),
                self._get_input_names_from_signature(dst_model),
                strict=True,
            )
        )

    def _get_model(self, model_path):
        """Load ``model_path/model.py`` and return an instance of its ``GraphModule``."""
        py_module = load_module(os.path.join(model_path, "model.py"))
        GraphModule = getattr(py_module, "GraphModule")
        GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__
        return GraphModule()

    def _get_input_names_from_signature(self, module):
        """Return the parameters of ``module.forward``; iterating yields their names."""
        return inspect.signature(module.forward).parameters

    def _is_weight_placeholder(self, node, real_arg, model_path):
        """
        Decide whether a placeholder node is a weight rather than a data input.

        A placeholder counts as a weight when the data-input predicator rejects
        its current name. It also counts when the node is annotated with a
        ``Parameter`` subclass, or when it has no annotation and the matching
        sample input is a ``torch.nn.Parameter``.
        """
        is_weight = not self.data_input_predicator(model_path, node.name)
        if node.type is not None:
            if isinstance(node.type, type) and issubclass(
                node.type, torch.nn.parameter.Parameter
            ):
                is_weight = True
        elif isinstance(real_arg, torch.nn.Parameter):
            is_weight = True
        return is_weight

    def rename_graph_variables(
        self, gm: torch.fx.GraphModule, sample_inputs, model_path
    ):
        """
        Rename every node of ``gm`` in place and return the recompiled module.

        Placeholders become ``in_<i>`` or ``w_<i>``, ``get_attr`` nodes become
        ``w_<i>``, and every other node except the output becomes ``tmp_<i>``.
        The pre-rename name is stored in ``node.meta["original_name"]`` unless
        that key is already present.

        ``sample_inputs`` are consumed in order, one per placeholder.

        Raises ``ValueError`` when ``sample_inputs`` holds fewer entries than
        the graph has placeholders.
        """
        in_cnt = 0
        w_cnt = 0
        tmp_cnt = 0

        exhausted = object()
        arg_iter = iter(sample_inputs)
        for node in gm.graph.nodes:
            node.meta.setdefault("original_name", node.name)

            if node.op == "placeholder":
                real_arg = next(arg_iter, exhausted)
                if real_arg is exhausted:
                    # A bare next() would leak StopIteration to the caller.
                    raise ValueError(
                        f"sample_inputs has no entry for placeholder {node.name!r}"
                    )

                if self._is_weight_placeholder(node, real_arg, model_path):
                    new_name = f"w_{w_cnt}"
                    w_cnt += 1
                else:
                    new_name = f"in_{in_cnt}"
                    in_cnt += 1

                # The target is updated as well so that the regenerated
                # forward() signature uses the new name.
                node.name = new_name
                node.target = new_name

            elif node.op == "get_attr":
                node.name = f"w_{w_cnt}"
                w_cnt += 1

            elif node.op != "output":
                node.name = f"tmp_{tmp_cnt}"
                tmp_cnt += 1

        gm.graph.lint()
        gm.recompile()
        return gm

0 commit comments

Comments
 (0)