Skip to content

Commit 9b9e4b4

Browse files
authored
Fix: Introduce rename_map to ensure correct graph variable mapping (#457)
1 parent cb5d2b2 commit 9b9e4b4

File tree

1 file changed

+40
-39
lines changed

1 file changed

+40
-39
lines changed

graph_net/torch/graph_variable_renamer.py

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import torch
33
import shutil
4-
import inspect
54
import tempfile
65
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
76
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
@@ -81,7 +80,7 @@ def __call__(self, rel_model_path):
8180
src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
8281
module, inputs = get_torch_module_and_inputs(src_model_path)
8382
gm = parse_sole_graph_module(module, inputs)
84-
gm = self.rename_graph_variables(gm, inputs, src_model_path)
83+
gm, rename_map = self.rename_graph_variables(gm, inputs, src_model_path)
8584
dst_model_path = os.path.realpath(
8685
os.path.join(self.config["output_dir"], rel_model_path)
8786
)
@@ -94,8 +93,10 @@ def __call__(self, rel_model_path):
9493
temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
9594
shutil.copytree(src_model_path, temp_model_path, dirs_exist_ok=True)
9695
self._update_model_py_file(gm, temp_model_path)
97-
self._update_weight_meta_py_file(src_model_path, temp_model_path)
98-
self._update_input_meta_py_file(src_model_path, temp_model_path)
96+
self._update_weight_meta_py_file(
97+
src_model_path, temp_model_path, rename_map
98+
)
99+
self._update_input_meta_py_file(src_model_path, temp_model_path, rename_map)
99100
print("Try to run renamed model...")
100101
self._try_run(temp_model_path)
101102
shutil.copytree(temp_model_path, dst_model_path)
@@ -111,99 +112,99 @@ def _update_model_py_file(self, graph_module, model_path):
111112
file_hash = get_sha256_hash(py_code)
112113
(Path(model_path) / "graph_hash.txt").write_text(file_hash)
113114

114-
def _update_weight_meta_py_file(self, src_model_path, dst_model_path):
115-
old_name_to_new_name = self._get_original_name_to_new_name(
116-
src_model_path, dst_model_path
117-
)
115+
def _update_weight_meta_py_file(self, src_model_path, dst_model_path, rename_map):
118116
tensor_metas = TensorMeta.unserialize_from_py_file(
119117
os.path.join(src_model_path, "weight_meta.py"),
120118
)
121119
for weight_meta in tensor_metas:
122-
assert weight_meta.name in old_name_to_new_name
120+
meta_name = self._find_name_in_rename_map(weight_meta.name, rename_map)
121+
assert (
122+
meta_name is not None
123+
), f"{weight_meta.name} is not found in rename_map"
123124
if weight_meta.original_name is None:
124125
weight_meta.original_name = weight_meta.name
125-
weight_meta.name = old_name_to_new_name[weight_meta.name]
126+
weight_meta.name = rename_map[meta_name]
127+
126128
py_code = "\n\n".join(
127129
[weight_meta.serialize_to_py_str() for weight_meta in tensor_metas]
128130
)
129131
(Path(dst_model_path) / "weight_meta.py").write_text(py_code)
130132

131-
def _update_input_meta_py_file(self, src_model_path, dst_model_path):
132-
old_name_to_new_name = self._get_original_name_to_new_name(
133-
src_model_path, dst_model_path
134-
)
133+
def _update_input_meta_py_file(self, src_model_path, dst_model_path, rename_map):
135134
tensor_metas = TensorMeta.unserialize_from_py_file(
136135
os.path.join(src_model_path, "input_meta.py"),
137136
)
138137
for input_meta in tensor_metas:
139-
assert input_meta.name in old_name_to_new_name
138+
meta_name = self._find_name_in_rename_map(input_meta.name, rename_map)
139+
assert (
140+
meta_name is not None
141+
), f"{input_meta.name} is not found in rename_map"
140142
if input_meta.original_name is None:
141143
input_meta.original_name = input_meta.name
142-
input_meta.name = old_name_to_new_name[input_meta.name]
144+
input_meta.name = rename_map[meta_name]
145+
143146
py_code = "\n\n".join(
144147
[input_meta.serialize_to_py_str() for input_meta in tensor_metas]
145148
)
146149
(Path(dst_model_path) / "input_meta.py").write_text(py_code)
147150

148-
def _get_original_name_to_new_name(self, src_model_path, dst_model_path):
149-
src_model = self._get_model(src_model_path)
150-
dst_model = self._get_model(dst_model_path)
151-
old_name_and_new_name_pairs = zip(
152-
self._get_input_names_from_signature(src_model),
153-
self._get_input_names_from_signature(dst_model),
154-
strict=True,
155-
)
156-
return {
157-
old_name: new_name for old_name, new_name in old_name_and_new_name_pairs
158-
}
151+
def _find_name_in_rename_map(self, raw_name, rename_map):
152+
if raw_name in rename_map:
153+
return raw_name
154+
# s1 -> s1_
155+
elif (raw_name + "_") in rename_map:
156+
return raw_name + "_"
157+
else:
158+
return None
159159

160160
def _get_model(self, model_path):
161161
py_module = load_module(os.path.join(model_path, "model.py"))
162162
GraphModule = getattr(py_module, "GraphModule")
163163
GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__
164164
return GraphModule()
165165

166-
def _get_input_names_from_signature(self, module):
167-
return inspect.signature(module.forward).parameters
168-
169166
def rename_graph_variables(
170167
self, gm: torch.fx.GraphModule, sample_inputs, model_path
171168
):
172169
counters = {"in": 0, "w": 0, "tmp": 0}
170+
rename_map = {}
173171
# graph may not have input, only contain weights
174172
arg_iter = iter(sample_inputs) if sample_inputs else iter([])
175173
for node in gm.graph.nodes:
176-
self._process_single_node(node, arg_iter, counters, model_path)
174+
self._process_single_node(node, arg_iter, counters, model_path, rename_map)
177175
gm.graph.lint()
178176
gm.recompile()
179-
return gm
177+
return gm, rename_map
180178

181-
def _process_single_node(self, node, arg_iter, counters, model_path):
179+
def _process_single_node(self, node, arg_iter, counters, model_path, rename_map):
182180
if "original_name" not in node.meta:
183181
node.meta["original_name"] = node.name
184182
if node.op == "placeholder":
185-
self._handle_placeholder(node, arg_iter, counters, model_path)
183+
self._handle_placeholder(node, arg_iter, counters, model_path, rename_map)
186184
elif node.op == "get_attr":
187-
self._apply_rename(node, "w", counters)
185+
self._apply_rename(node, "w", counters, rename_map)
188186
elif node.op != "output":
189-
self._apply_rename(node, "tmp", counters)
187+
self._apply_rename(node, "tmp", counters, rename_map)
190188
else:
191189
# Do nothing
192190
pass
193191

194-
def _handle_placeholder(self, node, arg_iter, counters, model_path):
192+
def _handle_placeholder(self, node, arg_iter, counters, model_path, rename_map):
195193
real_arg = next(arg_iter, None)
196194
is_weight = self._is_weight_node(node, real_arg, model_path)
197195
prefix = "w" if is_weight else "in"
198-
self._apply_rename(node, prefix, counters, update_target=True)
196+
self._apply_rename(node, prefix, counters, rename_map, update_target=True)
199197

200-
def _apply_rename(self, node, prefix, counters, update_target=False):
198+
def _apply_rename(self, node, prefix, counters, rename_map, update_target=False):
199+
old_name = node.name
201200
new_name = f"{prefix}_{counters[prefix]}"
202201
counters[prefix] += 1
203202
node.name = new_name
204203
if update_target:
205204
node.target = new_name
206205

206+
rename_map[old_name] = new_name
207+
207208
def _is_weight_node(self, node, real_arg, model_path):
208209
is_not_data_input = not self.data_input_predicator(model_path, node.name)
209210
is_parameter_type = (

0 commit comments

Comments
 (0)