Skip to content

Commit 90b622d

Browse files
committed
[Feature Enhancement] Rename torch graph variable
1 parent bea41c2 commit 90b622d

File tree

2 files changed

+225
-0
lines changed

2 files changed

+225
-0
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
4+
os.path.dirname(graph_net.__file__))")
5+
6+
# input model path
7+
MODEL_NAME=resnet18
8+
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
9+
config_json_str=$(cat <<EOF
10+
{
11+
"handler_path": "$GRAPH_NET_ROOT/torch/graph_variable_renamer.py",
12+
"handler_class_name": "GraphVariableRenamer",
13+
"handler_config": {
14+
"model_path_prefix": "$GRAPH_NET_ROOT/../",
15+
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
16+
"data_input_predicator_class_name": "NaiveDataInputPredicator",
17+
"model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
18+
"model_runnable_predicator_class_name": "ModelRunnablePredicator",
19+
"output_dir": "/tmp/graph_variable_rename_workspace"
20+
}
21+
}
22+
EOF
23+
)
24+
CONFIG=$(echo $config_json_str | base64 -w 0)
25+
26+
# python3 -m graph_net.model_path_handler --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --handler-config=$CONFIG
27+
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/decomposition_error_tmp_torch_samples_list.txt --handler-config=$CONFIG
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
import os
2+
import torch
3+
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
4+
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
5+
from graph_net.tensor_meta import TensorMeta
6+
from pathlib import Path
7+
import shutil
8+
from graph_net.torch.utils import apply_templates
9+
from graph_net.imp_util import load_module
10+
import inspect
11+
12+
13+
class GraphVariableRenamer:
14+
"""
15+
Used by graph_net.model_path_handler
16+
"""
17+
18+
def __init__(self, config: dict = None):
19+
if config is None:
20+
config = {}
21+
self.config = self._make_config(**config)
22+
self.data_input_predicator = self._make_data_input_predicator(self.config)
23+
self.model_runnable_predicator = self._make_model_runnable_predicator(
24+
self.config
25+
)
26+
27+
def _make_data_input_predicator(self, config):
28+
module = load_module(config["data_input_predicator_filepath"])
29+
cls = getattr(module, config["data_input_predicator_class_name"])
30+
return cls(config["data_input_predicator_config"])
31+
32+
def _make_model_runnable_predicator(self, config):
33+
module = load_module(config["model_runnable_predicator_filepath"])
34+
cls = getattr(module, config["model_runnable_predicator_class_name"])
35+
return cls(config["model_runnable_predicator_config"])
36+
37+
def _make_config(
38+
self,
39+
data_input_predicator_filepath,
40+
model_runnable_predicator_filepath,
41+
output_dir="./tmp/graph_variable_renamer_dir",
42+
filter_path=None,
43+
filter_config=None,
44+
post_extract_process_path=None,
45+
post_extract_process_class_name=None,
46+
post_extract_process_config=None,
47+
data_input_predicator_class_name="DataInputPredicator",
48+
model_runnable_predicator_class_name="ModelRunner",
49+
data_input_predicator_config=None,
50+
model_runnable_predicator_config=None,
51+
model_path_prefix="",
52+
**kwargs,
53+
):
54+
if post_extract_process_config is None:
55+
post_extract_process_config = {}
56+
if data_input_predicator_config is None:
57+
data_input_predicator_config = {}
58+
if model_runnable_predicator_config is None:
59+
model_runnable_predicator_config = {}
60+
return {
61+
"output_dir": output_dir,
62+
"filter_path": filter_path,
63+
"filter_config": filter_config if filter_config is not None else {},
64+
"post_extract_process_path": post_extract_process_path,
65+
"post_extract_process_class_name": post_extract_process_class_name,
66+
"post_extract_process_config": post_extract_process_config,
67+
"data_input_predicator_filepath": data_input_predicator_filepath,
68+
"data_input_predicator_class_name": data_input_predicator_class_name,
69+
"data_input_predicator_config": data_input_predicator_config,
70+
"model_runnable_predicator_filepath": model_runnable_predicator_filepath,
71+
"model_runnable_predicator_class_name": model_runnable_predicator_class_name,
72+
"model_runnable_predicator_config": model_runnable_predicator_config,
73+
"model_path_prefix": model_path_prefix,
74+
}
75+
76+
def __call__(self, rel_model_path):
77+
src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
78+
module, inputs = get_torch_module_and_inputs(src_model_path)
79+
gm = parse_sole_graph_module(module, inputs)
80+
gm = self.rename_graph_variables(gm, inputs, src_model_path)
81+
# print(gm)
82+
dst_model_path = os.path.join(self.config["output_dir"], rel_model_path)
83+
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
84+
shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True)
85+
self._update_model_py_file(gm, dst_model_path)
86+
self._update_weight_meta_py_file(src_model_path, dst_model_path)
87+
self._update_input_meta_py_file(src_model_path, dst_model_path)
88+
self._try_run(dst_model_path)
89+
90+
def _try_run(self, model_path):
91+
assert self.model_runnable_predicator(
92+
model_path
93+
), f"{model_path} is not a runnable model"
94+
95+
def _update_model_py_file(self, graph_module, model_path):
96+
py_code = apply_templates(graph_module.code)
97+
(Path(model_path) / "model.py").write_text(py_code)
98+
99+
def _update_weight_meta_py_file(self, src_model_path, dst_model_path):
100+
old_name_to_new_name = self._get_original_name_to_new_name(
101+
src_model_path, dst_model_path
102+
)
103+
tensor_metas = TensorMeta.unserialize_from_py_file(
104+
os.path.join(src_model_path, "weight_meta.py"),
105+
)
106+
for weight_meta in tensor_metas:
107+
assert weight_meta.name in old_name_to_new_name
108+
if weight_meta.original_name is None:
109+
weight_meta.original_name = weight_meta.name
110+
weight_meta.name = old_name_to_new_name[weight_meta.name]
111+
py_code = "\n\n".join(
112+
[weight_meta.serialize_to_py_str() for weight_meta in tensor_metas]
113+
)
114+
(Path(dst_model_path) / "weight_meta.py").write_text(py_code)
115+
116+
def _update_input_meta_py_file(self, src_model_path, dst_model_path):
117+
old_name_to_new_name = self._get_original_name_to_new_name(
118+
src_model_path, dst_model_path
119+
)
120+
tensor_metas = TensorMeta.unserialize_from_py_file(
121+
os.path.join(src_model_path, "input_meta.py"),
122+
)
123+
for input_meta in tensor_metas:
124+
assert input_meta.name in old_name_to_new_name
125+
if input_meta.original_name is None:
126+
input_meta.original_name = input_meta.name
127+
input_meta.name = old_name_to_new_name[input_meta.name]
128+
py_code = "\n\n".join(
129+
[input_meta.serialize_to_py_str() for input_meta in tensor_metas]
130+
)
131+
(Path(dst_model_path) / "input_meta.py").write_text(py_code)
132+
133+
def _get_original_name_to_new_name(self, src_model_path, dst_model_path):
134+
src_model = self._get_model(src_model_path)
135+
dst_model = self._get_model(dst_model_path)
136+
old_name_and_new_name_pairs = zip(
137+
self._get_input_names_from_signature(src_model),
138+
self._get_input_names_from_signature(dst_model),
139+
strict=True,
140+
)
141+
return {
142+
old_name: new_name for old_name, new_name in old_name_and_new_name_pairs
143+
}
144+
145+
def _get_model(self, model_path):
146+
py_module = load_module(os.path.join(model_path, "model.py"))
147+
GraphModule = getattr(py_module, "GraphModule")
148+
GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__
149+
return GraphModule()
150+
151+
def _get_input_names_from_signature(self, module):
152+
return inspect.signature(module.forward).parameters
153+
154+
def rename_graph_variables(
155+
self, gm: torch.fx.GraphModule, sample_inputs, model_path
156+
):
157+
in_cnt = 0
158+
w_cnt = 0
159+
tmp_cnt = 0
160+
161+
arg_iter = iter(sample_inputs)
162+
for node in gm.graph.nodes:
163+
if "original_name" not in node.meta:
164+
node.meta["original_name"] = node.name
165+
166+
if node.op == "placeholder":
167+
real_arg = next(arg_iter)
168+
is_weight = not self.data_input_predicator(model_path, node.name)
169+
if node.type is not None:
170+
if isinstance(node.type, type) and issubclass(
171+
node.type, torch.nn.parameter.Parameter
172+
):
173+
is_weight = True
174+
elif real_arg is not None:
175+
if isinstance(real_arg, torch.nn.Parameter):
176+
is_weight = True
177+
178+
if is_weight:
179+
new_name = f"w_{w_cnt}"
180+
w_cnt += 1
181+
else:
182+
new_name = f"in_{in_cnt}"
183+
in_cnt += 1
184+
185+
node.name = new_name
186+
node.target = new_name
187+
188+
elif node.op == "get_attr":
189+
node.name = f"w_{w_cnt}"
190+
w_cnt += 1
191+
192+
elif node.op != "output":
193+
node.name = f"tmp_{tmp_cnt}"
194+
tmp_cnt += 1
195+
196+
gm.graph.lint()
197+
gm.recompile()
198+
return gm

0 commit comments

Comments
 (0)