Skip to content

Commit f7f3d2a

Browse files
committed
add a script to find fully fusionable subgraph
1 parent f131cfb commit f7f3d2a

File tree

4 files changed

+197
-18
lines changed

4 files changed

+197
-18
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#!/bin/bash
# Decompose a sample model and search for a fully fusionable subgraph.
# Builds the extractor's decorator config as JSON, base64-encodes it, and
# hands it to graph_net.torch.run_model.

# Locate the installed graph_net package directory.
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
os.path.dirname(graph_net.__file__))")

# input model path
MODEL_NAME=resnet18
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME

decorator_config_json_str=$(cat <<EOF
{
  "decorator_path": "$GRAPH_NET_ROOT/torch/extractor.py",
  "decorator_config": {
    "name": "$MODEL_NAME",
    "custom_extractor_path": "$GRAPH_NET_ROOT/torch/graph_decompose_and_look_for_fully_fusionable_subgraph.py",
    "custom_extractor_config": {
      "output_dir": "/tmp/naive_decompose_workspace",
      "split_positions": [8,16,32],
      "group_head_and_tail": true,
      "filter_path": "$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
      "filter_config": {},
      "post_extract_process_path": "$GRAPH_NET_ROOT/torch/post_extract_process_count_kernels.py",
      "post_extract_process_class_name": "GraphFullyFusionable"
    }
  }
}
EOF
)

# BUG FIX: the variable must be quoted. Unquoted `echo $decorator_config_json_str`
# subjects the JSON to word splitting and filename expansion — the token
# "[8,16,32]," is a glob pattern and could expand against files in the cwd.
# printf '%s' also avoids echo's portability pitfalls.
DECORATOR_CONFIG=$(printf '%s' "$decorator_config_json_str" | base64 -w 0)

python3 -m graph_net.torch.run_model \
    --model-path "$GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES" \
    --decorator-config="$DECORATOR_CONFIG"

graph_net/test/naive_graph_decomposer_test.sh

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,4 @@
11
#!/bin/bash
2-
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
3-
GRAPH_NET_DIR=$(dirname "$SCRIPT_DIR")
4-
PROJECT_ROOT=$(dirname "$GRAPH_NET_DIR")
5-
6-
# 将项目根目录加入Python路径
7-
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"
8-
9-
10-
11-
122

133
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
144
os.path.dirname(graph_net.__file__))")
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
import os
2+
import torch
3+
import copy
4+
import random
5+
from graph_net.torch.decompose_util import convert_to_submodules_graph
6+
from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
7+
import graph_net.imp_util as imp_util
8+
9+
10+
def generate_split_positions(max_pos=32, max_splits=8):
    """Return a sorted list of 3..max_splits distinct split positions in [1, max_pos).

    Uses the module-level ``random`` state (call order: randint then sample),
    so seeded runs are reproducible.
    """
    how_many = random.randint(3, max_splits)
    chosen = random.sample(range(1, max_pos), how_many)
    return sorted(chosen)
15+
16+
17+
class GraphExtractor:
    """Retry-driven search for a fully fusionable subgraph.

    Repeatedly decomposes ``gm`` at the configured split positions, runs the
    decomposed graph, and checks whether any submodule passed the
    post-extract fusion check (reported back by the per-submodule extractors
    through ``last_post_process_result``). On failure, fresh random split
    positions are generated and the search retries, up to 20 attempts.
    """

    def __init__(
        self,
        config: dict,
        name,
        dynamic,
        mut_graph_codes=None,
        placeholder_auto_rename=False,
    ):
        self.subgraph_counter = 0
        self.name = name
        self.dynamic = dynamic
        self.mut_graph_codes = mut_graph_codes
        self.placeholder_auto_rename = placeholder_auto_rename
        self.config = self.make_config(**config)
        # Flipped to True by a NaiveDecomposerExtractor when its submodule
        # passes the post-extract fusion check.
        self.last_post_process_result = False

    def make_config(
        self,
        split_positions=(),
        group_head_and_tail=False,
        chain_style=False,
        output_dir="./tmp/naive_decomposer_dir",
        filter_path=None,
        filter_config=None,
        post_extract_process_path=None,
        post_extract_process_class_name=None,
    ):
        """Validate and normalize keyword config into a plain dict."""
        for position in split_positions:
            assert isinstance(
                position, int
            ), f"split_positions should be list of int, {split_positions=}"
        normalized = dict(
            split_positions=split_positions,
            group_head_and_tail=group_head_and_tail,
            chain_style=chain_style,
            output_dir=output_dir,
            filter_config={} if filter_config is None else filter_config,
            filter_path=filter_path,
            post_extract_process_path=post_extract_process_path,
            post_extract_process_class_name=post_extract_process_class_name,
        )
        return normalized

    def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
        """Search for a fully fusionable decomposition of ``gm``.

        Returns the last rewritten graph module, whether or not the search
        succeeded.
        """
        max_retries = 20
        decompose_keys = ("split_positions", "group_head_and_tail", "chain_style")
        for attempt in range(max_retries):
            print(f"--- Attempt {attempt+1} ---")
            self.last_post_process_result = False
            decompose_config = {key: self.config[key] for key in decompose_keys}
            print(f"Current Config: {decompose_config['split_positions']}")

            # Work on a copy so each attempt decomposes the pristine graph.
            graph_copy = copy.deepcopy(gm)

            rewrited_gm = convert_to_submodules_graph(
                graph_copy,
                submodule_hook=self.get_naive_decomposer_extractor,
                **decompose_config,
            )

            try:
                # Running the graph triggers extraction + fusion checks.
                rewrited_gm(*sample_inputs)
            except Exception as e:
                print(f"Run failed: {e}")
                self.last_post_process_result = False

            if self.last_post_process_result:
                print("Success! Subgraph is fully fusionable.")
                break
            print("Failed. Generating new split positions...")
            self.config["split_positions"] = generate_split_positions()

            if attempt == max_retries - 1:
                print("failed to find a fully fusionable subgraph")
        return rewrited_gm

    def get_naive_decomposer_extractor(self, submodule, seq_no):
        """Hook handed to the decomposer: wrap each submodule in an extractor."""
        return NaiveDecomposerExtractor(self, submodule, seq_no)
98+
99+
100+
class NaiveDecomposerExtractor(torch.nn.Module):
    """Wraps one decomposed submodule.

    On the first forward pass it (optionally filtered) extracts the submodule
    to disk via the builtin extractor and runs the configured post-extract
    fusion check, reporting success back to the parent ``GraphExtractor``
    through ``last_post_process_result``.
    """

    def __init__(self, parent_graph_extractor, submodule, seq_no):
        super().__init__()
        self.parent_graph_extractor = parent_graph_extractor
        self.submodule = submodule
        self.seq_no = seq_no
        self.extracted = False
        name = f"{parent_graph_extractor.name}_{self.seq_no}"
        self.model_name = name
        self.builtin_extractor = BuiltinGraphExtractor(
            name=name,
            dynamic=False,
            mut_graph_codes=[],
            placeholder_auto_rename=parent_graph_extractor.placeholder_auto_rename,
            workspace_path=self.parent_graph_extractor.config["output_dir"],
        )
        self.filter = self.make_filter(self.parent_graph_extractor.config)
        self.post_extract_process = self.make_post_extract_process(
            self.parent_graph_extractor.config
        )

    def forward(self, *args):
        print("forward")
        if not self.extracted:
            if self.need_extract(self.submodule, args):
                self.builtin_extractor(self.submodule, args)
                success = self._post_extract_process()
                if success:
                    # BUG FIX: the old message claimed the submodule *failed*
                    # the fusion check on the success path.
                    print(f"Submodule {self.seq_no} passed fusion check.")
                    self.parent_graph_extractor.last_post_process_result = True
                self.extracted = True
        return self.submodule(*args)

    def need_extract(self, gm, sample_inputs):
        """Return True when no filter is configured or the filter accepts."""
        if self.filter is None:
            return True
        return self.filter(gm, sample_inputs)

    def _post_extract_process(self):
        """Run the configured post-extract check on the dumped model dir."""
        model_path = os.path.join(
            self.parent_graph_extractor.config["output_dir"], self.model_name
        )
        if self.post_extract_process:
            result = self.post_extract_process(model_path)
        else:
            result = True  # no post-process configured: treat as passing
        return result

    def make_filter(self, config):
        """Load the optional subgraph filter from ``filter_path``."""
        if config["filter_path"] is None:
            return None
        module = imp_util.load_module(config["filter_path"])
        return module.GraphFilter(config["filter_config"])

    def make_post_extract_process(self, config):
        """Load the optional post-extract processor.

        BUG FIX: the configured ``post_extract_process_class_name`` was
        ignored and the class name was hard-coded; honor it now (keeping the
        old name as a fallback for configs that omit it).
        """
        if config["post_extract_process_path"] is None:
            return None
        module = imp_util.load_module(config["post_extract_process_path"])
        class_name = (
            config["post_extract_process_class_name"] or "GraphFullyFusionable"
        )
        process_cls = getattr(module, class_name)
        # NOTE(review): the original passed the module *path* to the
        # constructor; preserved for compatibility — confirm the expected
        # constructor argument (a config dict would mirror make_filter).
        return process_cls(config["post_extract_process_path"])

graph_net/torch/post_extract_process_count_kernels.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from graph_net.torch import utils
22
import importlib.util
33
import torch
4+
import shutil
45
from typing import Type
56
from torch.profiler import profile, record_function, ProfilerActivity
67

@@ -29,25 +30,25 @@ def __call__(self, model_path=None):
2930
model(**state_dict)
3031
except Exception as e:
3132
print(f"failed in running model:{e}")
32-
# print(f"removing: {model_path}")
33-
# shutil.rmtree(model_path)
33+
print(f"removing: {model_path}")
34+
shutil.rmtree(model_path)
3435
return False
3536
# try to compile the model
3637
try:
3738
compiled_model = torch.compile(model)
3839
except Exception as e:
3940
print(f"failed in compiling model:{e}")
40-
# print(f"removing: {model_path}")
41-
# shutil.rmtree(model_path)
41+
print(f"removing: {model_path}")
42+
shutil.rmtree(model_path)
4243
return False
4344
compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
4445
if compiled_num_of_kernels == 1:
45-
print(model_path, "can be fully integrated")
46+
print(model_path, "can be fully integrated!!!!!!!!!!!")
4647
return True
4748
else:
48-
print(f"{model_path} can not be fully integrated")
49-
# print(f"removing: {model_path}")
50-
# shutil.rmtree(model_path)
49+
print(f"{model_path} can not be fully integrated, to be removed...")
50+
print(f"removing: {model_path}")
51+
shutil.rmtree(model_path)
5152
return False
5253

5354

0 commit comments

Comments
 (0)