Skip to content

Commit 646f4c8

Browse files
committed
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into device_rewrite_sample_pass
2 parents 138e1c0 + 017dcb3 commit 646f4c8

File tree

4 files changed

+159
-40
lines changed

4 files changed

+159
-40
lines changed

graph_net/test/graph_variable_rename_test.sh

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
44
os.path.dirname(graph_net.__file__))")
5+
WORKSPACE=/tmp/graph_variable_rename_workspace
56

67
# input model path
78
MODEL_NAME=resnet18
@@ -16,7 +17,7 @@ config_json_str=$(cat <<EOF
1617
"data_input_predicator_class_name": "NaiveDataInputPredicator",
1718
"model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
1819
"model_runnable_predicator_class_name": "ModelRunnablePredicator",
19-
"output_dir": "/tmp/graph_variable_rename_workspace"
20+
"output_dir": "$WORKSPACE"
2021
}
2122
}
2223
EOF
@@ -25,3 +26,23 @@ CONFIG=$(echo $config_json_str | base64 -w 0)
2526

2627
python3 -m graph_net.model_path_handler --model-path samples/$MODEL_PATH_IN_SAMPLES --handler-config=$CONFIG
2728
# python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/decomposition_error_tmp_torch_samples_list.txt --handler-config=$CONFIG
29+
30+
test_compiler_config_json_str=$(cat <<EOF
31+
{
32+
"model_path_prefix": "$GRAPH_NET_ROOT",
33+
"renamed_root": "$WORKSPACE"
34+
}
35+
EOF
36+
)
37+
TEST_COMPILER_CONFIG=$(echo $test_compiler_config_json_str | base64 -w 0)
38+
39+
python3 -m graph_net.torch.test_compiler \
40+
--model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES \
41+
--compiler graph_variable_renamer_validator \
42+
--device cuda \
43+
--config $TEST_COMPILER_CONFIG \
44+
> "$WORKSPACE/validation.log" 2>&1
45+
46+
python3 -m graph_net.plot_ESt \
47+
--benchmark-path "$WORKSPACE/validation.log" \
48+
--output-dir "$WORKSPACE"
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import torch
2+
from pathlib import Path
3+
from typing import Dict
4+
from graph_net.tensor_meta import TensorMeta
5+
import os
6+
import importlib.util
7+
8+
9+
class RenamedModelAdapter(torch.nn.Module):
10+
def __init__(self, renamed_model: torch.nn.Module, mapping: Dict[str, str]):
11+
super().__init__()
12+
self.model = renamed_model
13+
self.mapping = mapping
14+
if hasattr(renamed_model, "__graph_net_file_path__"):
15+
self.__graph_net_file_path__ = renamed_model.__graph_net_file_path__
16+
17+
def forward(self, **kwargs):
18+
new_kwargs = self._convert_by_name_mapping(kwargs)
19+
return self.model(**new_kwargs)
20+
21+
def _convert_by_name_mapping(self, kwargs):
22+
new_kwargs = {}
23+
for old_name, value in kwargs.items():
24+
if old_name in self.mapping:
25+
new_name = self.mapping[old_name]
26+
new_kwargs[new_name] = value
27+
return new_kwargs
28+
29+
30+
class GraphVariableRenamerValidatorBackend:
31+
def _get_rename_mapping(self, model_dir: Path):
32+
mapping = {}
33+
for meta_file in ["input_meta.py", "weight_meta.py"]:
34+
meta_path = model_dir / meta_file
35+
if not meta_path.exists():
36+
continue
37+
metas = TensorMeta.unserialize_from_py_file(str(meta_path))
38+
for m in metas:
39+
if m.original_name:
40+
mapping[m.original_name] = m.name
41+
return mapping
42+
43+
def _load_model_instance(self, path: str, device: str) -> torch.nn.Module:
44+
class_name = "GraphModule"
45+
model_file = os.path.join(path, "model.py")
46+
47+
spec = importlib.util.spec_from_file_location(class_name, model_file)
48+
module = importlib.util.module_from_spec(spec)
49+
spec.loader.exec_module(module)
50+
51+
ModelClass = getattr(module, class_name)
52+
instance = ModelClass().to(device)
53+
return instance
54+
55+
def _make_config(
56+
self,
57+
model_path_prefix: str,
58+
renamed_root: str,
59+
renamed_dentry: str = "_renamed",
60+
):
61+
return {
62+
"model_path_prefix": model_path_prefix,
63+
"renamed_root": renamed_root,
64+
"renamed_dentry": renamed_dentry,
65+
}
66+
67+
def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
68+
config = self._make_config(**self.config)
69+
model_path = os.path.dirname(model.__class__.__graph_net_file_path__)
70+
model_name = os.path.basename(model_path)
71+
renamed_dir_name = f"{model_name}_renamed"
72+
renamed_model_dir = os.path.join(config["renamed_root"], renamed_dir_name)
73+
74+
print(f"[GraphVariableRenamerValidatorBackend] Processing: {model_name}")
75+
print(
76+
f"[GraphVariableRenamerValidatorBackend] Loading from: {renamed_model_dir}"
77+
)
78+
79+
device = model.__class__.__graph_net_device__
80+
renamed_model = self._load_model_instance(renamed_model_dir, device)
81+
mapping = self._get_rename_mapping(Path(renamed_model_dir))
82+
assert (
83+
mapping
84+
), f"Mapping is empty for {renamed_dir_name} at {renamed_model_dir}"
85+
adapter = RenamedModelAdapter(renamed_model, mapping)
86+
return adapter.eval()
87+
88+
def synchronize(self):
89+
if torch.cuda.is_available():
90+
torch.cuda.synchronize()

graph_net/torch/graph_variable_renamer.py

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,10 @@ def __call__(self, rel_model_path):
7979
module, inputs = get_torch_module_and_inputs(src_model_path)
8080
gm = parse_sole_graph_module(module, inputs)
8181
gm = self.rename_graph_variables(gm, inputs, src_model_path)
82+
model_name = os.path.basename(rel_model_path.rstrip(os.sep))
83+
new_rel_path = f"{model_name}_renamed"
8284
dst_model_path = os.path.realpath(
83-
os.path.join(self.config["output_dir"], rel_model_path)
85+
os.path.join(self.config["output_dir"], new_rel_path)
8486
)
8587
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
8688
shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True)
@@ -158,45 +160,47 @@ def _get_input_names_from_signature(self, module):
158160
def rename_graph_variables(
159161
self, gm: torch.fx.GraphModule, sample_inputs, model_path
160162
):
161-
in_cnt = 0
162-
w_cnt = 0
163-
tmp_cnt = 0
164-
165-
arg_iter = iter(sample_inputs)
163+
counters = {"in": 0, "w": 0, "tmp": 0}
164+
# graph may not have input, only contain weights
165+
arg_iter = iter(sample_inputs) if sample_inputs else iter([])
166166
for node in gm.graph.nodes:
167-
if "original_name" not in node.meta:
168-
node.meta["original_name"] = node.name
169-
170-
if node.op == "placeholder":
171-
real_arg = next(arg_iter)
172-
is_weight = not self.data_input_predicator(model_path, node.name)
173-
if node.type is not None:
174-
if isinstance(node.type, type) and issubclass(
175-
node.type, torch.nn.parameter.Parameter
176-
):
177-
is_weight = True
178-
elif real_arg is not None:
179-
if isinstance(real_arg, torch.nn.Parameter):
180-
is_weight = True
181-
182-
if is_weight:
183-
new_name = f"w_{w_cnt}"
184-
w_cnt += 1
185-
else:
186-
new_name = f"in_{in_cnt}"
187-
in_cnt += 1
188-
189-
node.name = new_name
190-
node.target = new_name
191-
192-
elif node.op == "get_attr":
193-
node.name = f"w_{w_cnt}"
194-
w_cnt += 1
195-
196-
elif node.op != "output":
197-
node.name = f"tmp_{tmp_cnt}"
198-
tmp_cnt += 1
199-
167+
self._process_single_node(node, arg_iter, counters, model_path)
200168
gm.graph.lint()
201169
gm.recompile()
202170
return gm
171+
172+
def _process_single_node(self, node, arg_iter, counters, model_path):
173+
if "original_name" not in node.meta:
174+
node.meta["original_name"] = node.name
175+
if node.op == "placeholder":
176+
self._handle_placeholder(node, arg_iter, counters, model_path)
177+
elif node.op == "get_attr":
178+
self._apply_rename(node, "w", counters)
179+
elif node.op != "output":
180+
self._apply_rename(node, "tmp", counters)
181+
else:
182+
# Do nothing
183+
pass
184+
185+
def _handle_placeholder(self, node, arg_iter, counters, model_path):
186+
real_arg = next(arg_iter, None)
187+
is_weight = self._is_weight_node(node, real_arg, model_path)
188+
prefix = "w" if is_weight else "in"
189+
self._apply_rename(node, prefix, counters, update_target=True)
190+
191+
def _apply_rename(self, node, prefix, counters, update_target=False):
192+
new_name = f"{prefix}_{counters[prefix]}"
193+
counters[prefix] += 1
194+
node.name = new_name
195+
if update_target:
196+
node.target = new_name
197+
198+
def _is_weight_node(self, node, real_arg, model_path):
199+
is_not_data_input = not self.data_input_predicator(model_path, node.name)
200+
is_parameter_type = (
201+
node.type is not None
202+
and isinstance(node.type, type)
203+
and issubclass(node.type, torch.nn.parameter.Parameter)
204+
)
205+
is_parameter_value = isinstance(real_arg, torch.nn.Parameter)
206+
return is_not_data_input or is_parameter_type or is_parameter_value

graph_net/torch/test_compiler.py

100644100755
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
from graph_net.torch.backend.range_decomposer_validator_backend import (
2626
RangeDecomposerValidatorBackend,
2727
)
28+
from graph_net.torch.backend.graph_variable_renamer_validator_backend import (
29+
GraphVariableRenamerValidatorBackend,
30+
)
2831
from graph_net import test_compiler_util
2932
from graph_net import path_utils
3033

@@ -38,6 +41,7 @@
3841
"nope": NopeBackend(),
3942
"unstable_to_stable": UnstableToStableBackend(),
4043
"range_decomposer_validator": RangeDecomposerValidatorBackend(),
44+
"graph_variable_renamer_validator": GraphVariableRenamerValidatorBackend(),
4145
}
4246

4347

0 commit comments

Comments
 (0)