Skip to content

Commit 353e7bd

Browse files
committed
find the biggest fully fusionable subgraph
1 parent f7f3d2a commit 353e7bd

File tree

2 files changed

+43
-45
lines changed

2 files changed

+43
-45
lines changed

graph_net/test/graph_decompose_and_look_for_fully_fusionable_subgraph_test.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@ decorator_config_json_str=$(cat <<EOF
1414
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/graph_decompose_and_look_for_fully_fusionable_subgraph.py",
1515
"custom_extractor_config": {
1616
"output_dir": "/tmp/naive_decompose_workspace",
17-
"split_positions": [8,16,32],
17+
"split_positions": [8,16,17,18,19,20,32],
1818
"group_head_and_tail": true,
1919
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
2020
"filter_config": {},
2121
"post_extract_process_path":"$GRAPH_NET_ROOT/torch/post_extract_process_count_kernels.py",
22-
"post_extract_process_class_name": "GraphFullyFusionable"
22+
"post_extract_process_class_name": "GraphFullyFusionable",
23+
"max_step": 4
2324
}
2425
}
2526
}

graph_net/torch/graph_decompose_and_look_for_fully_fusionable_subgraph.py

Lines changed: 40 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,11 @@
11
import os
22
import torch
33
import copy
4-
import random
54
from graph_net.torch.decompose_util import convert_to_submodules_graph
65
from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
76
import graph_net.imp_util as imp_util
87

98

10-
def generate_split_positions(max_pos=32, max_splits=8):
11-
num_splits = random.randint(3, max_splits)
12-
positions = random.sample(range(1, max_pos), num_splits)
13-
positions.sort()
14-
return positions
15-
16-
179
class GraphExtractor:
1810
def __init__(
1911
self,
@@ -30,6 +22,7 @@ def __init__(
3022
self.placeholder_auto_rename = placeholder_auto_rename
3123
self.config = self.make_config(**config)
3224
self.last_post_process_result = False
25+
self.decompose_finished = False
3326

3427
def make_config(
3528
self,
@@ -41,6 +34,7 @@ def make_config(
4134
filter_config=None,
4235
post_extract_process_path=None,
4336
post_extract_process_class_name=None,
37+
max_step=8,
4438
):
4539
for pos in split_positions:
4640
assert isinstance(
@@ -55,45 +49,46 @@ def make_config(
5549
"filter_config": filter_config if filter_config is not None else {},
5650
"post_extract_process_path": post_extract_process_path,
5751
"post_extract_process_class_name": post_extract_process_class_name,
52+
"max_step": max_step,
5853
}
5954

6055
def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
61-
max_retries = 20
62-
for i in range(max_retries):
63-
print(f"--- Attempt {i+1} ---")
64-
self.last_post_process_result = False
65-
config = {
66-
k: v
67-
for k, v in self.config.items()
68-
if k in {"split_positions", "group_head_and_tail", "chain_style"}
69-
}
70-
print(f"Current Config: {config['split_positions']}")
71-
72-
gm_to_process = copy.deepcopy(gm)
73-
74-
rewrited_gm = convert_to_submodules_graph(
75-
gm_to_process,
76-
submodule_hook=self.get_naive_decomposer_extractor,
77-
**config,
78-
)
79-
80-
try:
81-
rewrited_gm(*sample_inputs)
82-
except Exception as e:
83-
print(f"Run failed: {e}")
56+
for i in range(self.config["max_step"], -1, -1):
57+
start_pos = 0
58+
for start_pos in range(32 - i):
59+
end_pos = start_pos + i
60+
self.config["split_positions"] = [start_pos, end_pos]
61+
torch._dynamo.reset()
8462
self.last_post_process_result = False
63+
config = {
64+
k: v
65+
for k, v in self.config.items()
66+
if k in {"split_positions", "group_head_and_tail", "chain_style"}
67+
}
68+
print(f"Current Config: {config['split_positions']}")
69+
gm_to_process = copy.deepcopy(gm)
70+
rewrited_gm = convert_to_submodules_graph(
71+
gm_to_process,
72+
submodule_hook=self.get_naive_decomposer_extractor,
73+
**config,
74+
)
75+
try:
76+
rewrited_gm(*sample_inputs)
77+
except Exception as e:
78+
print(f"Run failed: {e}")
79+
self.last_post_process_result = False
80+
if self.last_post_process_result and self.decompose_finished:
81+
print("Success! Subgraph is fully fusionable.")
82+
break
8583
if self.last_post_process_result:
86-
print("Success! Subgraph is fully fusionable.")
8784
break
88-
else:
89-
print("Failed. Generating new split positions...")
90-
self.config["split_positions"] = generate_split_positions()
91-
92-
if i == max_retries - 1:
93-
print("failed to find a fully fusionable subgraph")
9485
return rewrited_gm
9586

96-
def get_naive_decomposer_extractor(self, submodule, seq_no):
87+
def get_naive_decomposer_extractor(
88+
self,
89+
submodule,
90+
seq_no,
91+
):
9792
return NaiveDecomposerExtractor(self, submodule, seq_no)
9893

9994

@@ -119,14 +114,16 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
119114
)
120115

121116
def forward(self, *args):
122-
print("forward")
123117
if not self.extracted:
124118
if self.need_extract(self.submodule, args):
125119
self.builtin_extractor(self.submodule, args)
126-
success = self._post_extract_process()
127-
if success:
128-
print(f"Submodule {self.seq_no} failed fusion check.")
120+
if self._post_extract_process() and self.seq_no == 1:
129121
self.parent_graph_extractor.last_post_process_result = True
122+
print("biggest fully fusionable subgraph found!!", self.model_name)
123+
if self.seq_no == len(
124+
self.parent_graph_extractor.config["split_positions"]
125+
):
126+
self.parent_graph_extractor.decompose_finished = True
130127
self.extracted = True
131128
return self.submodule(*args)
132129

0 commit comments

Comments
 (0)