Skip to content

Commit d83dc33

Browse files
committed
Merge
2 parents 2a49397 + 9b9e4b4 commit d83dc33

File tree

1 file changed

+40
-39
lines changed

1 file changed

+40
-39
lines changed

graph_net/torch/graph_variable_renamer.py

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import torch
33
import shutil
4-
import inspect
54
import tempfile
65

76
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
@@ -92,16 +91,18 @@ def __call__(self, rel_model_path):
9291
with cuda_gc(enabled=self.config["release_gpu_memory"]):
9392
module, inputs = get_torch_module_and_inputs(src_model_path)
9493
gm = parse_sole_graph_module(module, inputs)
95-
gm = self.rename_graph_variables(gm, inputs, src_model_path)
94+
gm, rename_map = self.rename_graph_variables(gm, inputs, src_model_path)
9695
del module, inputs
9796

9897
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
9998
with tempfile.TemporaryDirectory(prefix="graph_variable_renamer_") as temp_dir:
10099
temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
101100
shutil.copytree(src_model_path, temp_model_path, dirs_exist_ok=True)
102101
self._update_model_py_file(gm, temp_model_path)
103-
self._update_weight_meta_py_file(src_model_path, temp_model_path)
104-
self._update_input_meta_py_file(src_model_path, temp_model_path)
102+
self._update_weight_meta_py_file(
103+
src_model_path, temp_model_path, rename_map
104+
)
105+
self._update_input_meta_py_file(src_model_path, temp_model_path, rename_map)
105106
print("Try to run renamed model...")
106107
self._try_run(temp_model_path)
107108
shutil.copytree(temp_model_path, dst_model_path)
@@ -117,99 +118,99 @@ def _update_model_py_file(self, graph_module, model_path):
117118
file_hash = get_sha256_hash(py_code)
118119
(Path(model_path) / "graph_hash.txt").write_text(file_hash)
119120

120-
def _update_weight_meta_py_file(self, src_model_path, dst_model_path):
121-
old_name_to_new_name = self._get_original_name_to_new_name(
122-
src_model_path, dst_model_path
123-
)
121+
def _update_weight_meta_py_file(self, src_model_path, dst_model_path, rename_map):
124122
tensor_metas = TensorMeta.unserialize_from_py_file(
125123
os.path.join(src_model_path, "weight_meta.py"),
126124
)
127125
for weight_meta in tensor_metas:
128-
assert weight_meta.name in old_name_to_new_name
126+
meta_name = self._find_name_in_rename_map(weight_meta.name, rename_map)
127+
assert (
128+
meta_name is not None
129+
), f"{weight_meta.name} is not found in rename_map"
129130
if weight_meta.original_name is None:
130131
weight_meta.original_name = weight_meta.name
131-
weight_meta.name = old_name_to_new_name[weight_meta.name]
132+
weight_meta.name = rename_map[meta_name]
133+
132134
py_code = "\n\n".join(
133135
[weight_meta.serialize_to_py_str() for weight_meta in tensor_metas]
134136
)
135137
(Path(dst_model_path) / "weight_meta.py").write_text(py_code)
136138

137-
def _update_input_meta_py_file(self, src_model_path, dst_model_path):
138-
old_name_to_new_name = self._get_original_name_to_new_name(
139-
src_model_path, dst_model_path
140-
)
139+
def _update_input_meta_py_file(self, src_model_path, dst_model_path, rename_map):
141140
tensor_metas = TensorMeta.unserialize_from_py_file(
142141
os.path.join(src_model_path, "input_meta.py"),
143142
)
144143
for input_meta in tensor_metas:
145-
assert input_meta.name in old_name_to_new_name
144+
meta_name = self._find_name_in_rename_map(input_meta.name, rename_map)
145+
assert (
146+
meta_name is not None
147+
), f"{input_meta.name} is not found in rename_map"
146148
if input_meta.original_name is None:
147149
input_meta.original_name = input_meta.name
148-
input_meta.name = old_name_to_new_name[input_meta.name]
150+
input_meta.name = rename_map[meta_name]
151+
149152
py_code = "\n\n".join(
150153
[input_meta.serialize_to_py_str() for input_meta in tensor_metas]
151154
)
152155
(Path(dst_model_path) / "input_meta.py").write_text(py_code)
153156

154-
def _get_original_name_to_new_name(self, src_model_path, dst_model_path):
155-
src_model = self._get_model(src_model_path)
156-
dst_model = self._get_model(dst_model_path)
157-
old_name_and_new_name_pairs = zip(
158-
self._get_input_names_from_signature(src_model),
159-
self._get_input_names_from_signature(dst_model),
160-
strict=True,
161-
)
162-
return {
163-
old_name: new_name for old_name, new_name in old_name_and_new_name_pairs
164-
}
157+
def _find_name_in_rename_map(self, raw_name, rename_map):
158+
if raw_name in rename_map:
159+
return raw_name
160+
# s1 -> s1_
161+
elif (raw_name + "_") in rename_map:
162+
return raw_name + "_"
163+
else:
164+
return None
165165

166166
def _get_model(self, model_path):
167167
py_module = load_module(os.path.join(model_path, "model.py"))
168168
GraphModule = getattr(py_module, "GraphModule")
169169
GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__
170170
return GraphModule()
171171

172-
def _get_input_names_from_signature(self, module):
173-
return inspect.signature(module.forward).parameters
174-
175172
def rename_graph_variables(
176173
self, gm: torch.fx.GraphModule, sample_inputs, model_path
177174
):
178175
counters = {"in": 0, "w": 0, "tmp": 0}
176+
rename_map = {}
179177
# graph may not have input, only contain weights
180178
arg_iter = iter(sample_inputs) if sample_inputs else iter([])
181179
for node in gm.graph.nodes:
182-
self._process_single_node(node, arg_iter, counters, model_path)
180+
self._process_single_node(node, arg_iter, counters, model_path, rename_map)
183181
gm.graph.lint()
184182
gm.recompile()
185-
return gm
183+
return gm, rename_map
186184

187-
def _process_single_node(self, node, arg_iter, counters, model_path):
185+
def _process_single_node(self, node, arg_iter, counters, model_path, rename_map):
188186
if "original_name" not in node.meta:
189187
node.meta["original_name"] = node.name
190188
if node.op == "placeholder":
191-
self._handle_placeholder(node, arg_iter, counters, model_path)
189+
self._handle_placeholder(node, arg_iter, counters, model_path, rename_map)
192190
elif node.op == "get_attr":
193-
self._apply_rename(node, "w", counters)
191+
self._apply_rename(node, "w", counters, rename_map)
194192
elif node.op != "output":
195-
self._apply_rename(node, "tmp", counters)
193+
self._apply_rename(node, "tmp", counters, rename_map)
196194
else:
197195
# Do nothing
198196
pass
199197

200-
def _handle_placeholder(self, node, arg_iter, counters, model_path):
198+
def _handle_placeholder(self, node, arg_iter, counters, model_path, rename_map):
201199
real_arg = next(arg_iter, None)
202200
is_weight = self._is_weight_node(node, real_arg, model_path)
203201
prefix = "w" if is_weight else "in"
204-
self._apply_rename(node, prefix, counters, update_target=True)
202+
self._apply_rename(node, prefix, counters, rename_map, update_target=True)
205203

206-
def _apply_rename(self, node, prefix, counters, update_target=False):
204+
def _apply_rename(self, node, prefix, counters, rename_map, update_target=False):
205+
old_name = node.name
207206
new_name = f"{prefix}_{counters[prefix]}"
208207
counters[prefix] += 1
209208
node.name = new_name
210209
if update_target:
211210
node.target = new_name
212211

212+
rename_map[old_name] = new_name
213+
213214
def _is_weight_node(self, node, real_arg, model_path):
214215
is_not_data_input = not self.data_input_predicator(model_path, node.name)
215216
is_parameter_type = (

0 commit comments

Comments
 (0)