Skip to content

Commit 517b86e

Browse files
committed
change the entry of naive_graph_decomposer from graph_net.torch.run_model to graph_net.model_path_handler
1 parent 0b1648a commit 517b86e

File tree

5 files changed

+183
-68
lines changed

5 files changed

+183
-68
lines changed

graph_net/test/naive_graph_decomposer_test.sh

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,20 @@ os.path.dirname(graph_net.__file__))")
66
# input model path
77
MODEL_NAME=resnet18
88
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
9-
decorator_config_json_str=$(cat <<EOF
9+
config_json_str=$(cat <<EOF
1010
{
11-
"decorator_path": "$GRAPH_NET_ROOT/torch/extractor.py",
12-
"decorator_config": {
13-
"name": "$MODEL_NAME",
14-
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
15-
"custom_extractor_config": {
16-
"output_dir": "/tmp/naive_decompose_workspace",
17-
"split_positions": [8, 16, 32],
18-
"group_head_and_tail": true,
19-
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
20-
"filter_config": {}
21-
}
11+
"handler_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
12+
"handler_class_name": "NaiveDecomposerExtractor",
13+
"handler_config": {
14+
"output_dir": "/tmp/naive_decompose_workspace",
15+
"split_positions": [8, 16, 32],
16+
"group_head_and_tail": true,
17+
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
18+
"filter_config": {}
2219
}
2320
}
2421
EOF
2522
)
26-
DECORATOR_CONFIG=$(echo $decorator_config_json_str | base64 -w 0)
23+
CONFIG=$(echo $config_json_str | base64 -w 0)
2724

28-
python3 -m graph_net.torch.run_model --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --decorator-config=$DECORATOR_CONFIG
25+
python3 -m graph_net.model_path_handler --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --handler-config=$CONFIG
Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,15 @@
1-
import logging
2-
import torch
31
import copy
42
import os
5-
import inspect
6-
from graph_net.tensor_meta import TensorMeta
73
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
8-
from graph_net.imp_util import load_module
9-
from dataclasses import asdict
4+
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
105

116

127
def parse_immutable_model_path_into_sole_graph_module(model_path):
138
model_path = os.path.realpath(model_path)
149
if model_path not in g_model_path2graph_module:
15-
module = _get_torch_module(model_path)
16-
tensor_metas = _get_tensor_metas(model_path)
17-
logging.warning("before _create_inputs_by_metas")
18-
inputs = _create_inputs_by_metas(module, tensor_metas)
19-
logging.warning("after _create_inputs_by_metas")
20-
logging.warning("before parse_sole_graph_module")
10+
module, inputs = get_torch_module_and_inputs(model_path)
2111
g_model_path2graph_module[model_path] = parse_sole_graph_module(module, inputs)
22-
logging.warning("after parse_sole_graph_module")
2312
return copy.deepcopy(g_model_path2graph_module[model_path])
2413

2514

26-
def _get_torch_module(model_path):
27-
py_module = load_module(f"{model_path}/model.py")
28-
torch_module_cls = py_module.GraphModule
29-
return torch_module_cls()
30-
31-
32-
def _get_tensor_metas(model_path):
33-
make = TensorMeta.unserialize_from_py_file
34-
return [
35-
*make(os.path.join(model_path, "input_meta.py")),
36-
*make(os.path.join(model_path, "weight_meta.py")),
37-
]
38-
39-
40-
def _create_inputs_by_metas(module, tensor_metas):
41-
tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
42-
from graph_net.torch.utils import get_dummy_named_tensors
43-
44-
named_tensors = get_dummy_named_tensors(tensor_meta_attrs_list)
45-
name2tensor = {k: v for k, v in named_tensors}
46-
return tuple(
47-
name2tensor[name] for name in inspect.signature(module.forward).parameters
48-
)
49-
50-
5115
g_model_path2graph_module = {}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import os
2+
import inspect
3+
from graph_net.tensor_meta import TensorMeta
4+
from graph_net.imp_util import load_module
5+
from dataclasses import asdict
6+
7+
8+
def get_torch_module_and_inputs(model_path):
9+
module = _get_torch_module(model_path)
10+
tensor_metas = _get_tensor_metas(model_path)
11+
inputs = _create_inputs_by_metas(module, tensor_metas)
12+
return module, inputs
13+
14+
15+
def _get_torch_module(model_path):
16+
py_module = load_module(f"{model_path}/model.py")
17+
torch_module_cls = py_module.GraphModule
18+
return torch_module_cls()
19+
20+
21+
def _get_tensor_metas(model_path):
22+
make = TensorMeta.unserialize_from_py_file
23+
return [
24+
*make(os.path.join(model_path, "input_meta.py")),
25+
*make(os.path.join(model_path, "weight_meta.py")),
26+
]
27+
28+
29+
def _create_inputs_by_metas(module, tensor_metas):
30+
tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
31+
from graph_net.torch.utils import get_dummy_named_tensors
32+
33+
named_tensors = get_dummy_named_tensors(tensor_meta_attrs_list)
34+
name2tensor = {k: v for k, v in named_tensors}
35+
return tuple(
36+
name2tensor[name] for name in inspect.signature(module.forward).parameters
37+
)

graph_net/torch/fx_graph_parse_util.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,45 @@ def get_zip_filter_names():
7676
if name_from_signature != name_from_placeholder
7777
)
7878

79+
if len(get_zip_filter_names()) > 0 and set(get_input_names_from_signature()) == set(
80+
get_input_names_from_placeholder()
81+
):
82+
traced_module = _reorder_placeholders(
83+
traced_module, get_input_names_from_signature()
84+
)
85+
7986
zip_filter_names = get_zip_filter_names()
8087

8188
def zip_filter_names_str():
8289
for triple in zip_filter_names:
8390
print(triple)
8491
return "<printed before>"
8592

93+
from pathlib import Path
94+
95+
Path("/tmp/a.py").write_text(traced_module.code)
8696
assert len(zip_filter_names) == 0, f"{zip_filter_names_str()=}"
8797
return traced_module
98+
99+
100+
def _reorder_placeholders(gm, sorted_names):
101+
sorted_names = list(sorted_names)
102+
name2placeholder = {
103+
node.name: node for node in gm.graph.nodes if node.op == "placeholder"
104+
}
105+
for i, current_placeholder_name in enumerate(sorted_names):
106+
if i == 0:
107+
continue
108+
prev_node = name2placeholder[sorted_names[i - 1]]
109+
current_node = name2placeholder[current_placeholder_name]
110+
with gm.graph.inserting_after(prev_node):
111+
new_node = gm.graph.placeholder(current_node.name)
112+
# force rename
113+
new_node.name = current_node.name
114+
new_node.target = current_node.target
115+
current_node.replace_all_uses_with(new_node)
116+
name2placeholder[current_placeholder_name] = new_node
117+
gm.graph.erase_node(current_node)
118+
119+
gm.recompile()
120+
return gm

graph_net/torch/naive_graph_decomposer.py

Lines changed: 100 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,15 @@
33
from graph_net.torch.decompose_util import convert_to_submodules_graph
44
from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
55
import graph_net.imp_util as imp_util
6+
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
7+
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
68

79

810
class GraphExtractor:
11+
"""
12+
Used by graph_net.torch.run_model
13+
"""
14+
915
def __init__(
1016
self,
1117
config: dict,
@@ -66,29 +72,109 @@ def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
6672
return rewrited_gm
6773

6874
def get_naive_decomposer_extractor(self, submodule, seq_no):
69-
return NaiveDecomposerExtractor(self, submodule, seq_no)
75+
return NaiveDecomposerExtractorModule(
76+
config=self.config,
77+
parent_graph_name=self.name,
78+
submodule=submodule,
79+
seq_no=seq_no,
80+
)
81+
82+
83+
class NaiveDecomposerExtractor:
84+
"""
85+
Used by graph_net.model_path_handler
86+
"""
87+
88+
def __init__(self, config: dict = None):
89+
if config is None:
90+
config = {}
91+
self.config = self._make_config(**config)
92+
93+
def _make_config(
94+
self,
95+
split_positions=(),
96+
group_head_and_tail=False,
97+
chain_style=False,
98+
output_dir="./tmp/naive_decomposer_dir",
99+
filter_path=None,
100+
filter_config=None,
101+
post_extract_process_path=None,
102+
post_extract_process_class_name=None,
103+
post_extract_process_config=None,
104+
**kwargs,
105+
):
106+
if post_extract_process_config is None:
107+
post_extract_process_config = {}
108+
for pos in split_positions:
109+
assert isinstance(
110+
pos, int
111+
), f"split_positions should be list of int, {split_positions=}"
112+
return {
113+
"split_positions": split_positions,
114+
"group_head_and_tail": group_head_and_tail,
115+
"chain_style": chain_style,
116+
"output_dir": output_dir,
117+
"filter_path": filter_path,
118+
"filter_config": filter_config if filter_config is not None else {},
119+
"post_extract_process_path": post_extract_process_path,
120+
"post_extract_process_class_name": post_extract_process_class_name,
121+
"post_extract_process_config": post_extract_process_config,
122+
}
123+
124+
def __call__(self, model_path):
125+
config = {
126+
k: v
127+
for k, v in self.config.items()
128+
if k in {"split_positions", "group_head_and_tail", "chain_style"}
129+
}
130+
module, inputs = get_torch_module_and_inputs(model_path)
131+
gm = parse_sole_graph_module(module, inputs)
132+
rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph(
133+
gm,
134+
submodule_hook=self.get_naive_decomposer_extractor(model_path),
135+
**config,
136+
)
137+
rewrited_gm(*inputs)
138+
139+
def get_naive_decomposer_extractor(self, model_path):
140+
def fn(submodule, seq_no):
141+
return NaiveDecomposerExtractorModule(
142+
config=self.config,
143+
parent_graph_name=os.path.basename(model_path),
144+
submodule=submodule,
145+
seq_no=seq_no,
146+
)
147+
148+
return fn
70149

71150

72-
class NaiveDecomposerExtractor(torch.nn.Module):
73-
def __init__(self, parent_graph_extractor, submodule, seq_no):
151+
class NaiveDecomposerExtractorModule(torch.nn.Module):
152+
def __init__(
153+
self,
154+
config: dict,
155+
parent_graph_name: str,
156+
submodule: torch.nn.Module,
157+
seq_no: int,
158+
):
74159
super().__init__()
75-
self.parent_graph_extractor = parent_graph_extractor
160+
self.config = config
76161
self.submodule = submodule
77162
self.seq_no = seq_no
78163
self.extracted = False
79-
name = f"{parent_graph_extractor.name}_{self.seq_no}"
80-
self.model_name = name
164+
if self.seq_no is None:
165+
self.model_name = parent_graph_name
166+
else:
167+
submodule_name = f"{parent_graph_name}_{self.seq_no}"
168+
self.model_name = submodule_name
81169
self.builtin_extractor = BuiltinGraphExtractor(
82-
name=name,
170+
name=submodule_name,
83171
dynamic=False,
84172
mut_graph_codes=[],
85-
placeholder_auto_rename=parent_graph_extractor.placeholder_auto_rename,
86-
workspace_path=self.parent_graph_extractor.config["output_dir"],
87-
)
88-
self.filter = self.make_filter(self.parent_graph_extractor.config)
89-
self.post_extract_process = self.make_post_extract_process(
90-
self.parent_graph_extractor.config
173+
placeholder_auto_rename=False,
174+
workspace_path=self.config["output_dir"],
91175
)
176+
self.filter = self.make_filter(self.config)
177+
self.post_extract_process = self.make_post_extract_process(self.config)
92178

93179
def forward(self, *args):
94180
if not self.extracted:
@@ -104,9 +190,7 @@ def need_extract(self, gm, sample_inputs):
104190
return self.filter(gm, sample_inputs)
105191

106192
def _post_extract_process(self):
107-
model_path = os.path.join(
108-
self.parent_graph_extractor.config["output_dir"], self.model_name
109-
)
193+
model_path = os.path.join(self.config["output_dir"], self.model_name)
110194
return self.post_extract_process(model_path)
111195

112196
def make_filter(self, config):

0 commit comments

Comments
 (0)