Skip to content

Commit 06b8dc2

Browse files
authored
Naive decompose minor fix (#344)
* support checking model redundancy
* revert change of vision_model_test
* reformat python code.
* reformat bert_model_test.py and utils.py
* minor fix
* fix failed check by comparing directories after os.path.realpath()
* fix bugs in check_validate.sh
* set dynamic=False in single_device_runner.py
* reset graph hash
* minor fix for naive_graph_decomposer
1 parent 44ca1a5 commit 06b8dc2

File tree

6 files changed

+151
-45
lines changed

6 files changed

+151
-45
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#!/bin/bash
2+
set -x
3+
4+
# input model path
5+
MODEL_PATH_IN_SAMPLES=/timm/resnet18
6+
# decompose the graph at the split positions configured below (2 and 4)
7+
read -r -d '' json_str <<'EOF'
8+
{
9+
"output_dir": "/tmp/naive_decompose_workspace",
10+
"split_positions": [2, 4],
11+
"group_head_and_tail": false,
12+
"chain_style": true
13+
}
14+
EOF
15+
CONFIG=$(echo $json_str | base64 -w 0)
16+
17+
mkdir -p /tmp/naive_decompose_workspace
18+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
19+
os.path.dirname(graph_net.__file__))")
20+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --custom-extractor-path=$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py --custom-extractor-config=$CONFIG
Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
#!/bin/bash
2+
23
# input model path
34
MODEL_PATH_IN_SAMPLES=/timm/resnet18
4-
# output model path
5-
OUTPUT_DIR=/tmp/naive_decompose_workspace
5+
read -r -d '' json_str <<'EOF'
6+
{
7+
"output_dir": "/tmp/naive_decompose_workspace",
8+
"split_positions": [8, 32],
9+
"group_head_and_tail": true
10+
}
11+
EOF
12+
CONFIG=$(echo $json_str | base64 -w 0)
613

7-
mkdir -p $OUTPUT_DIR
8-
# extract subgraph 0-8, 8-16
9-
export GRAPH_NET_NAIVE_DECOMPOSER_SPLIT_POS=0,8,16
10-
export GRAPH_NET_EXTRACT_WORKSPACE=$OUTPUT_DIR
14+
mkdir -p /tmp/naive_decompose_workspace
1115
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
1216
os.path.dirname(graph_net.__file__))")
13-
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --custom-extractor-path=$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py
17+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --custom-extractor-path=$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py --custom-extractor-config=$CONFIG

graph_net/torch/decompose_util.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@ def convert_to_submodules_graph(
99
original_gm: torch.fx.GraphModule,
1010
split_positions: list[int],
1111
submodule_hook=None,
12-
submodule_name_prefix="extraced_submodule",
12+
submodule_name_prefix="extracted_submodule",
13+
chain_style=False,
1314
group_head_and_tail=True,
1415
):
16+
"""
17+
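chain_style=True: decompose original_gm into a chain of submodules g0 * g1 * g2 * g3; each submodule's inputs and outputs also include earlier values that are still used after it, so they are passed along the chain.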
18+
"""
1519
original_gm = copy.deepcopy(original_gm)
1620
num_placeholders = len(
1721
[node for node in original_gm.graph.nodes if node.op == "placeholder"]
@@ -68,6 +72,12 @@ def get_end_node_idx(range_idx):
6872
return i + 1
6973
raise NotImplementedError("Dead code.")
7074

75+
def print_submodule_call(prompt, gm):
76+
submodule_call_stmts = [
77+
stmt for stmt in gm.code.split("\n") if "self.extracted_submodule" in stmt
78+
]
79+
print(f"{prompt} ", submodule_call_stmts)
80+
7181
for range_idx in range(len(range_idx2submodule_body_nodes)):
7282
(
7383
submodule_input_nodes,
@@ -76,6 +86,7 @@ def get_end_node_idx(range_idx):
7686
original_gm=original_gm,
7787
start_node_idx=get_start_node_idx(range_idx),
7888
end_node_idx=get_end_node_idx(range_idx),
89+
chain_style=chain_style,
7990
)
8091

8192
def get_input_nodes(range_idx):
@@ -136,9 +147,14 @@ def get_output_nodes(range_idx):
136147
# Erase old nodes
137148
for node in reversed(get_body_nodes(range_idx)):
138149
original_gm.graph.erase_node(node)
150+
# print_submodule_call("(fx) after Erase old nodes", original_gm)
151+
152+
# print_submodule_call("(fx) before recompile", original_gm)
139153

140154
original_gm.recompile()
141155

156+
# print_submodule_call("(fx) after recompile", original_gm)
157+
142158
return original_gm
143159

144160

@@ -147,7 +163,7 @@ def fold_range_to_submodule(
147163
start_node_idx: int,
148164
end_node_idx: int,
149165
submodule_hook=None,
150-
submodule_name="extraced_submodule",
166+
submodule_name="extracted_submodule",
151167
group_head_and_tail=True,
152168
):
153169
return convert_to_submodules_graph(
@@ -170,6 +186,7 @@ def _get_submodule_inputs_and_outputs(
170186
original_gm: torch.fx.GraphModule,
171187
start_node_idx: int,
172188
end_node_idx: int,
189+
chain_style=False,
173190
):
174191
count_ctx = NodeProducedOrConsumedCountCtx(
175192
defaultdict(int),
@@ -179,7 +196,11 @@ def _get_submodule_inputs_and_outputs(
179196
node_list = list(original_gm.graph.nodes)
180197

181198
def get_related_node(node):
182-
yield from node.args
199+
for arg in node.args:
200+
if isinstance(arg, tuple):
201+
yield from arg
202+
else:
203+
yield arg
183204
yield node
184205

185206
for node in node_list[0:start_node_idx]:
@@ -194,19 +215,33 @@ def get_related_node(node):
194215
for related_node in get_related_node(node):
195216
count_ctx.node2after_output[related_node] += 1
196217

197-
input_nodes = [
198-
node
199-
for node in node_list
200-
if count_ctx.node2before_input[node] > 0
201-
if count_ctx.node2body[node] > 0
202-
]
203-
204-
output_nodes = [
205-
node
206-
for node in node_list
207-
if not (count_ctx.node2before_input[node] > 0)
208-
if count_ctx.node2body[node] > 0
209-
if count_ctx.node2after_output[node] > 0
210-
]
218+
if chain_style:
219+
input_nodes = [
220+
node
221+
for node in node_list
222+
if (count_ctx.node2before_input[node] > 0)
223+
if (count_ctx.node2body[node] > 0 or count_ctx.node2after_output[node] > 0)
224+
]
225+
input_nodes_set = set(input_nodes)
226+
output_nodes = [
227+
node
228+
for node in node_list
229+
if (count_ctx.node2before_input[node] > 0 or count_ctx.node2body[node] > 0)
230+
if (count_ctx.node2after_output[node] > 0)
231+
]
232+
else:
233+
input_nodes = [
234+
node
235+
for node in node_list
236+
if count_ctx.node2before_input[node] > 0
237+
if count_ctx.node2body[node] > 0
238+
]
239+
output_nodes = [
240+
node
241+
for node in node_list
242+
if not (count_ctx.node2before_input[node] > 0)
243+
if count_ctx.node2body[node] > 0
244+
if count_ctx.node2after_output[node] > 0
245+
]
211246

212247
return input_nodes, output_nodes

graph_net/torch/extractor.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,23 @@
1515

1616
class GraphExtractor:
1717
def __init__(
18-
self, name, dynamic, mut_graph_codes=None, placeholder_auto_rename=False
18+
self,
19+
name,
20+
dynamic,
21+
mut_graph_codes=None,
22+
placeholder_auto_rename=False,
23+
workspace_path=None,
1924
):
2025
self.subgraph_counter = 0
2126
self.name = name
2227
self.dynamic = dynamic
2328
self.mut_graph_codes = mut_graph_codes
2429
self.placeholder_auto_rename = placeholder_auto_rename
25-
self.workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
30+
self.workspace_path = (
31+
workspace_path
32+
if workspace_path is not None
33+
else os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
34+
)
2635
if not self.workspace_path:
2736
raise EnvironmentError(
2837
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
@@ -130,7 +139,8 @@ def extract(
130139
dynamic=True,
131140
mut_graph_codes=None,
132141
placeholder_auto_rename=False,
133-
custom_extractor_path=None,
142+
custom_extractor_path: str = None,
143+
custom_extractor_config: str = None,
134144
):
135145
"""
136146
Extract computation graphs from PyTorch nn.Module.
@@ -200,19 +210,20 @@ def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
200210
>>>
201211
"""
202212

203-
def get_graph_extractor_cls():
213+
def get_graph_extractor_maker():
204214
if custom_extractor_path is None:
205215
return GraphExtractor
206216
import importlib.util as imp
207217

208218
spec = imp.spec_from_file_location("graph_extractor", custom_extractor_path)
209219
graph_extractor = imp.module_from_spec(spec)
210220
spec.loader.exec_module(graph_extractor)
211-
return graph_extractor.GraphExtractor
221+
cls = graph_extractor.GraphExtractor
222+
return lambda *args, **kwargs: cls(custom_extractor_config, *args, **kwargs)
212223

213224
def wrapper(model: torch.nn.Module):
214225
assert isinstance(model, torch.nn.Module), f"{type(model)=}"
215-
extractor = get_graph_extractor_cls()(
226+
extractor = get_graph_extractor_maker()(
216227
name, dynamic, mut_graph_codes, placeholder_auto_rename
217228
)
218229
# return torch.compile(backend=extractor, dynamic=dynamic)

graph_net/torch/naive_graph_decomposer.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import torch
33
import json
4+
import base64
45
import shutil
56
from typing import Union, Callable
67
from graph_net.torch import utils
@@ -10,36 +11,62 @@
1011

1112
class GraphExtractor:
1213
def __init__(
13-
self, name, dynamic, mut_graph_codes=None, placeholder_auto_rename=False
14+
self,
15+
config_str: str,
16+
name,
17+
dynamic,
18+
mut_graph_codes=None,
19+
placeholder_auto_rename=False,
1420
):
1521
self.subgraph_counter = 0
1622
self.name = name
1723
self.dynamic = dynamic
1824
self.mut_graph_codes = mut_graph_codes
1925
self.placeholder_auto_rename = placeholder_auto_rename
20-
self.workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
21-
if not self.workspace_path:
22-
raise EnvironmentError(
23-
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
24-
)
25-
split_pos_str = os.environ.get("GRAPH_NET_NAIVE_DECOMPOSER_SPLIT_POS")
26-
if split_pos_str is None:
27-
raise EnvironmentError(
28-
"Environment variable 'GRAPH_NET_NAIVE_DECOMPOSER_SPLIT_POS' is not set."
29-
)
30-
self.split_positions = [int(pos) for pos in split_pos_str.split(",")]
26+
self.config = self.make_config(**self.convert_to_dict(config_str))
27+
28+
def make_config(
29+
self,
30+
split_positions=(),
31+
group_head_and_tail=False,
32+
chain_style=False,
33+
output_dir="./tmp/naive_decomposer_dir",
34+
):
35+
for pos in split_positions:
36+
assert isinstance(
37+
pos, int
38+
), f"split_positions should be list of int, {split_positions=}"
39+
return {
40+
"split_positions": split_positions,
41+
"group_head_and_tail": group_head_and_tail,
42+
"chain_style": chain_style,
43+
"output_dir": output_dir,
44+
}
3145

3246
def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
33-
return convert_to_submodules_graph(
47+
config = {
48+
k: v
49+
for k, v in self.config.items()
50+
if k in {"split_positions", "group_head_and_tail", "chain_style"}
51+
}
52+
rewrited_gm = convert_to_submodules_graph(
3453
gm,
35-
split_positions=self.split_positions,
3654
submodule_hook=self.get_naive_decomposer_extractor,
37-
group_head_and_tail=False,
55+
**config,
3856
)
57+
return rewrited_gm
3958

4059
def get_naive_decomposer_extractor(self, submodule, seq_no):
4160
return NaiveDecomposerExtractor(self, submodule, seq_no)
4261

62+
def convert_to_dict(self, config_str):
63+
if config_str is None:
64+
return {}
65+
config_str = base64.b64decode(config_str).decode("utf-8")
66+
config = json.loads(config_str)
67+
assert isinstance(config, dict), f"config should be a dict. {config_str=}"
68+
return config
69+
4370

4471
class NaiveDecomposerExtractor(torch.nn.Module):
4572
def __init__(self, parent_graph_extractor, submodule, seq_no):
@@ -54,6 +81,7 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
5481
dynamic=False,
5582
mut_graph_codes=[],
5683
placeholder_auto_rename=parent_graph_extractor.placeholder_auto_rename,
84+
workspace_path=self.parent_graph_extractor.config["output_dir"],
5785
)
5886

5987
def forward(self, *args):

graph_net/torch/single_device_runner.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def main(args):
6464
name=args.extract_name,
6565
dynamic=False,
6666
custom_extractor_path=args.custom_extractor_path,
67+
custom_extractor_config=args.custom_extractor_config,
6768
**dump_graph_options,
6869
)
6970
model = extract(**kwargs)(model)
@@ -123,5 +124,12 @@ def main(args):
123124
default=None,
124125
help="Custom extractor python file path",
125126
)
127+
parser.add_argument(
128+
"--custom-extractor-config",
129+
type=str,
130+
required=False,
131+
default=None,
132+
help="Custom extractor configuration string",
133+
)
126134
args = parser.parse_args()
127135
main(args=args)

0 commit comments

Comments
 (0)