
Commit 0eb02b6

Look for the biggest fully fusible subgraph (#406)
* 1119
* 1120
* 1120.2
* model_path
* remove unnecessary files and pre-committed
* remove unnecessary files and pre-committed
* 1121 remove unnecessary files
* modify rev version
* modify rev version
* modify rev version
* accuracy issues targeted
* test script and modify feature
* return set[str]
* add logfile for test
* filter can get the number of kernels in naive_graph_decomposer
* post extract process feature
* remove unnecessary code blocks and variables
* modify the way of counting kernels used
* modify the way of counting kernels used
* modify script, rename files and variables
* add failure protection and log output when removing directories
* add a script to check fusability of a given model
* add a script to check if a given model is fully fusable
* add a script to check if a given model is fully fusable
* a script to check if a given model is fully fusable
* add a script to check if a given model is fully fusionable
* add a script to find fully fusionable subgraph
* find the biggest fully fusionable subgraph
* find the biggest fusionable subgraph
* add a script to get the biggest fully fusable subgraph
* use tempfile, fix sys problem, remove useless configs
1 parent 0aa1827 commit 0eb02b6

9 files changed: +269 −6 lines

graph_net/test/dimension_generalization_test.sh — file mode changed 100644 → 100755.
New test script (filename not preserved in this capture; it drives fully_fusable_subgraph_extractor.py) — Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
#!/bin/bash

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
os.path.dirname(graph_net.__file__))")

# input model path
MODEL_NAME=resnet18
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
decorator_config_json_str=$(cat <<EOF
{
  "decorator_path": "$GRAPH_NET_ROOT/torch/extractor.py",
  "decorator_config": {
    "name": "$MODEL_NAME",
    "custom_extractor_path": "$GRAPH_NET_ROOT/torch/fully_fusable_subgraph_extractor.py",
    "custom_extractor_config": {
      "split_positions": [],
      "group_head_and_tail": true,
      "max_step": 5,
      "min_step": 2,
      "max_nodes": 6
    }
  }
}
EOF
)
DECORATOR_CONFIG=$(echo $decorator_config_json_str | base64 -w 0)

python3 -m graph_net.torch.run_model --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --decorator-config=$DECORATOR_CONFIG
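
For reference, the same payload can be assembled in Python. The sketch below is an illustration, not a file from this commit; it assumes an installed graph_net and mirrors the JSON keys of the heredoc above, using only the standard library for encoding.

import base64
import json
import os

import graph_net

# Mirror of the JSON document the heredoc above assembles.
graph_net_root = os.path.dirname(graph_net.__file__)
config = {
    "decorator_path": f"{graph_net_root}/torch/extractor.py",
    "decorator_config": {
        "name": "resnet18",
        "custom_extractor_path": f"{graph_net_root}/torch/fully_fusable_subgraph_extractor.py",
        "custom_extractor_config": {
            "split_positions": [],
            "group_head_and_tail": True,
            "max_step": 5,
            "min_step": 2,
            "max_nodes": 6,
        },
    },
}
# Equivalent of `echo "$json" | base64 -w 0`: a single unwrapped base64 line.
encoded = base64.b64encode(json.dumps(config).encode("utf-8")).decode("ascii")
print(encoded)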
graph_net/test/naive_decomposer_and_post_extract_process_test.sh (named in the script's own usage comment) — Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
#!/bin/bash
# bash graph_net/test/naive_decomposer_and_post_extract_process_test.sh

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
os.path.dirname(graph_net.__file__))")

# input model path
MODEL_NAME=resnet18
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
decorator_config_json_str=$(cat <<EOF
{
  "decorator_path": "$GRAPH_NET_ROOT/torch/extractor.py",
  "decorator_config": {
    "name": "$MODEL_NAME",
    "custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
    "custom_extractor_config": {
      "output_dir": "/tmp/naive_decompose_workspace",
      "split_positions": [8, 16, 32],
      "group_head_and_tail": true,
      "filter_path": "$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
      "filter_config": {},
      "post_extract_process_path": "$GRAPH_NET_ROOT/torch/post_extract_process_count_kernels.py",
      "post_extract_process_class_name": "GraphFullyFusable"
    }
  }
}
EOF
)
DECORATOR_CONFIG=$(echo $decorator_config_json_str | base64 -w 0)

python3 -m graph_net.torch.run_model --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --decorator-config=$DECORATOR_CONFIG
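
The split_positions [8, 16, 32] above appear to mark node indices at which the traced graph is cut into consecutive submodules (with group_head_and_tail controlling how the ends are grouped). A toy sketch of that presumed partitioning, not graph_net code:

# Toy illustration (not from this commit): how split positions such as
# [8, 16, 32] would partition a sequence of N graph nodes into
# consecutive subgraphs.
def partition(num_nodes, split_positions):
    bounds = [0, *split_positions, num_nodes]
    return [(lo, hi) for lo, hi in zip(bounds, bounds[1:]) if lo < hi]

print(partition(40, [8, 16, 32]))  # [(0, 8), (8, 16), (16, 32), (32, 40)]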

graph_net/test/naive_graph_decomposer_test.sh — file mode changed 100644 → 100755. Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 #!/bin/bash
 
-
 GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
 os.path.dirname(graph_net.__file__))")
 
New extractor module (filename not preserved in this capture; the first test script above references graph_net/torch/fully_fusable_subgraph_extractor.py) — Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
import os
import torch
import graph_net
import tempfile
from graph_net.torch import constraint_util


class GraphExtractor:
    def __init__(
        self,
        config: dict,
        name,
        dynamic,
        mut_graph_codes=None,
        placeholder_auto_rename=False,
    ):
        self.subgraph_counter = 0
        self.name = name
        self.dynamic = dynamic
        self.mut_graph_codes = mut_graph_codes
        self.placeholder_auto_rename = placeholder_auto_rename
        self.config = self.make_config(**config)

    def make_config(
        self,
        split_positions=(),
        group_head_and_tail=False,
        chain_style=False,
        max_step=8,
        min_step=2,
        max_nodes=32,
    ):
        for pos in split_positions:
            assert isinstance(
                pos, int
            ), f"split_positions should be list of int, {split_positions=}"
        return {
            "split_positions": split_positions,
            "group_head_and_tail": group_head_and_tail,
            "chain_style": chain_style,
            "max_step": max_step,
            "min_step": min_step,
            "max_nodes": max_nodes,
        }

    def _get_sub_ranges(self):
        # Yield candidate [start_pos, end_pos) windows, widest first, so the
        # first window that passes the check is the biggest fusable subgraph.
        for step in reversed(
            range(self.config["min_step"], self.config["max_step"] + 1)
        ):
            for start_pos in range(self.config["max_nodes"] - step):
                end_pos = start_pos + step
                yield start_pos, end_pos

    def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
        temp_dir_obj = tempfile.TemporaryDirectory(prefix="_check_fusable_")
        temp_output_dir = temp_dir_obj.name
        found_fusable_subgraph = False
        print(f"Using temp output dir: {temp_output_dir}")
        for start_pos, end_pos in self._get_sub_ranges():
            self.config["split_positions"] = [start_pos, end_pos]
            print("current split_positions:", self.config["split_positions"])
            graph_net_root = os.path.dirname(graph_net.__file__)
            model_path = os.path.join(
                graph_net_root, "..", "samples", "timm", self.name
            )
            check_fusable_config = {
                "decorator_path": f"{graph_net_root}/torch/extractor.py",
                "decorator_config": {
                    "name": f"{self.name}",
                    "custom_extractor_path": f"{graph_net_root}/torch/naive_graph_decomposer.py",
                    "custom_extractor_config": {
                        "output_dir": temp_output_dir,
                        "split_positions": self.config["split_positions"],
                        "group_head_and_tail": False,
                        "filter_path": f"{graph_net_root}/torch/naive_subgraph_filter.py",
                        "filter_config": {},
                        "post_extract_process_path": f"{graph_net_root}/torch/post_extract_process_count_kernels.py",
                        "post_extract_process_class_name": "GraphFullyFusable",
                    },
                },
            }
            success = constraint_util.RunModelPredicator(check_fusable_config)(
                model_path
            )
            if success:
                found_fusable_subgraph = True
                # Disable cleanup so the winning extraction survives this run.
                temp_dir_obj.cleanup = lambda: None
                print(
                    f"SUCCESS in finding the biggest fully fusable subgraph saved in: {temp_output_dir}."
                )
                break
            else:
                print("Failed attempt; cleaning up the workspace and continuing the search.")
                temp_dir_obj.cleanup()
                continue
        if not found_fusable_subgraph:
            print("No fusable subgraph found")
        return gm
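
Note the search order: _get_sub_ranges yields the widest windows first, so the first range that passes the fully-fusable check is the largest one. A standalone replay of the generator with the test script's settings (min_step=2, max_step=5, max_nodes=6):

# Standalone replay of _get_sub_ranges, outside the class, with the
# settings from the test script above.
def get_sub_ranges(min_step, max_step, max_nodes):
    for step in reversed(range(min_step, max_step + 1)):
        for start_pos in range(max_nodes - step):
            yield start_pos, start_pos + step

print(list(get_sub_ranges(min_step=2, max_step=5, max_nodes=6)))
# [(0, 5), (0, 4), (1, 5), (0, 3), (1, 4), (2, 5), (0, 2), (1, 3), (2, 4), (3, 5)]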

graph_net/torch/naive_graph_decomposer.py — Lines changed: 22 additions & 3 deletions
@@ -1,8 +1,5 @@
 import os
 import torch
-import shutil
-from typing import Union, Callable
-from graph_net.torch import utils
 from graph_net.torch.decompose_util import convert_to_submodules_graph
 from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
 import graph_net.imp_util as imp_util
@@ -32,6 +29,8 @@ def make_config(
         output_dir="./tmp/naive_decomposer_dir",
         filter_path=None,
         filter_config=None,
+        post_extract_process_path=None,
+        post_extract_process_class_name=None,
     ):
         for pos in split_positions:
             assert isinstance(
@@ -44,6 +43,8 @@ def make_config(
             "output_dir": output_dir,
             "filter_path": filter_path,
             "filter_config": filter_config if filter_config is not None else {},
+            "post_extract_process_path": post_extract_process_path,
+            "post_extract_process_class_name": post_extract_process_class_name,
         }
 
     def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
@@ -71,6 +72,7 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
         self.seq_no = seq_no
         self.extracted = False
         name = f"{parent_graph_extractor.name}_{self.seq_no}"
+        self.model_name = name
         self.builtin_extractor = BuiltinGraphExtractor(
             name=name,
             dynamic=False,
@@ -79,21 +81,38 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
             workspace_path=self.parent_graph_extractor.config["output_dir"],
         )
         self.filter = self.make_filter(self.parent_graph_extractor.config)
+        self.post_extract_process = self.make_post_extract_process(
+            self.parent_graph_extractor.config
+        )
 
     def forward(self, *args):
         if not self.extracted:
             if self.need_extract(self.submodule, args):
                 self.builtin_extractor(self.submodule, args)
                 self.extracted = True
+                self._post_extract_process()
         return self.submodule(*args)
 
     def need_extract(self, gm, sample_inputs):
         if self.filter is None:
             return True
         return self.filter(gm, sample_inputs)
 
+    def _post_extract_process(self):
+        model_path = os.path.join(
+            self.parent_graph_extractor.config["output_dir"], self.model_name
+        )
+        return self.post_extract_process(model_path)
+
     def make_filter(self, config):
         if config["filter_path"] is None:
             return None
         module = imp_util.load_module(config["filter_path"])
         return module.GraphFilter(config["filter_config"])
+
+    def make_post_extract_process(self, config):
+        if config["post_extract_process_path"] is None:
+            return None
+        module = imp_util.load_module(config["post_extract_process_path"])
+        cls = getattr(module, config["post_extract_process_class_name"])
+        return cls(config["post_extract_process_path"])
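
Taken together, these hooks define an informal plug-in interface: the class named by post_extract_process_class_name is loaded from the module at post_extract_process_path, constructed with a single config argument (currently the path itself, as the last line above shows), and later called with the extracted submodel's directory. A minimal hypothetical plug-in written against that interface:

import os


# Hypothetical plug-in (not part of this commit), matching the interface
# above: constructed with one argument, called with the directory of a
# freshly extracted submodel.
class ListExtractedFiles:
    def __init__(self, config):
        self.config = config

    def __call__(self, model_path):
        # Any per-submodel check fits here; GraphFullyFusable (below)
        # compiles the submodel and counts its CUDA kernel launches.
        print(f"extracted {model_path}: {sorted(os.listdir(model_path))}")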

graph_net/torch/naive_subgraph_filter.py — Lines changed: 0 additions & 1 deletion
@@ -3,5 +3,4 @@ def __init__(self, config):
         self.config = config
 
     def __call__(self, gm, sample_inputs):
-        print(f"GraphFilter\n{gm.code}")
         return True
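
The filter follows the same convention: module.GraphFilter is constructed with filter_config and called with (gm, sample_inputs), and returning False skips extraction of that subgraph. A hypothetical variant that enforces a node budget:

# Hypothetical filter variant (not part of this commit): accept a
# subgraph only while it stays under a configurable node budget.
class GraphFilter:
    def __init__(self, config):
        self.max_nodes = config.get("max_nodes", 32)

    def __call__(self, gm, sample_inputs):
        return len(list(gm.graph.nodes)) <= self.max_nodes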
New post-extract check (filename not preserved in this capture; the configs above reference graph_net/torch/post_extract_process_count_kernels.py) — Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
from graph_net.torch import utils
import importlib.util
import torch
import sys
from typing import Type
from torch.profiler import profile, record_function, ProfilerActivity


class GraphFullyFusable:
    def __init__(self, config):
        self.config = config

    def __call__(self, model_path=None):
        torch._dynamo.reset()
        if model_path is None:
            sys.exit(1)
        # model
        model_class = load_class_from_file(
            f"{model_path}/model.py", class_name="GraphModule"
        )
        assert model_class is not None
        model = model_class()

        inputs_params = utils.load_converted_from_text(f"{model_path}")
        params = inputs_params["weight_info"]
        state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}

        # try to run the model
        try:
            model(**state_dict)
        except Exception as e:
            print(f"failed in running model: {e}")
            sys.exit(1)
        # try to compile the model
        try:
            compiled_model = torch.compile(model)
        except Exception as e:
            print(f"failed in compiling model: {e}")
            sys.exit(1)
        # The exit code reports the result to the calling process.
        compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
        if compiled_num_of_kernels == 1:
            print(model_path, "can be fully fused!")
            sys.exit(0)
        else:
            print(f"{model_path} can not be fully fused, to be removed...")
            sys.exit(1)


def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
    spec = importlib.util.spec_from_file_location("unnamed", file_path)
    unnamed = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(unnamed)
    model_class = getattr(unnamed, class_name, None)
    return model_class


def count_kernels(model, sample_inputs) -> int:
    """
    Count the number of CUDA kernel launches performed during a model's forward pass.

    Args:
        model: the (possibly compiled) graph model to profile.
        sample_inputs: keyword tensor inputs for the forward pass.

    Returns:
        int: the number of kernel launches.

    Behavior:
        - Runs the model once inside a PyTorch profiler context.
        - Sums the counts of the 'cuLaunchKernel' and 'cudaLaunchKernel'
          events, which correspond to CUDA kernel launches.
    """
    model.eval()
    # Use PyTorch Profiler
    with profile(
        activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
        record_shapes=True,
    ) as prof:
        with record_function("model_inference"):
            _ = model(**sample_inputs)
    events = prof.key_averages()

    total_count = 0
    for e in events:
        if e.key == "cuLaunchKernel" or e.key == "cudaLaunchKernel":
            total_count += e.count
    return total_count
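
A usage sketch for count_kernels on a toy model. The import path below is an assumption, since the new file's name is not preserved in this view, and a CUDA device is needed for nonzero launch counts:

import torch

# Hypothetical module path; the new file's real name is not shown above.
from graph_net.torch.post_extract_process_count_kernels import count_kernels


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = TinyModel().cuda()
compiled = torch.compile(model)
# A fully fused forward pass would report a single kernel launch.
print(count_kernels(compiled, {"x": torch.randn(4, 8, device="cuda")}))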

samples/timm/resnet18/input_tensor_constraints.py — Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from sympy import Symbol, Expr, Rel, Eq
+from sympy import Symbol
 
 S0 = Symbol("S0")
 S1 = Symbol("S1")
