Skip to content

Commit 5f01912

Browse files
authored
Fix Check fusible (#410)
* 1119 * 1120 * 1120.2 * model_path * remove unnecessary files and pre-committed * remove unnecessary files and pre-committed * 1121 remove unnecessary files * modify rev version * modify rev version * modify rev version * accuracy issues targeted * test script and modify feature * return set[str] * add logfile for test * filter can get the number of kernels in naive_graph_decomposer * post extract process feature * remove unnecessary code blocks and variables * modify the way of counting kernels used * modify the way of counting kernels used * modify script, rename files and variables * add failure protection and log output when removing directories * add a script to check fusability of a given model * add a script to check if a given model is fully fusable * add a script to check if a given model is fully fusable * a script to check if a given model is fully fusable * add a script to check if a given model is fully fusionable * add a script to find fully fusionable subgraph * find the biggest fully fusionable subgraph * find the biggest fusionable subgraph * add a script to get the biggest fully fusable subgraph * use tempfile, fix sys problem, remove useless configs * find the biggest fully fusible subgraph * find the biggest fully fusable subgraph in a given graph * correct 'fusable' -> 'fusible' * remove a useless swp file
1 parent 95fc800 commit 5f01912

8 files changed

+142
-122
lines changed

graph_net/test/graph_decompose_and_look_for_fully_fusable_subgraph_test.sh renamed to graph_net/test/graph_decompose_and_look_for_fully_fusible_subgraph_test.sh

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,19 @@ decorator_config_json_str=$(cat <<EOF
1111
"decorator_path": "$GRAPH_NET_ROOT/torch/extractor.py",
1212
"decorator_config": {
1313
"name": "$MODEL_NAME",
14-
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/fully_fusable_subgraph_extractor.py",
14+
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/fully_fusible_subgraph_extractor.py",
1515
"custom_extractor_config": {
16+
"output_dir": "/tmp/find_fully_fusible_output",
1617
"split_positions": [],
1718
"group_head_and_tail": true,
18-
"max_step": 5,
19+
"max_step": 3,
1920
"min_step": 2,
20-
"max_nodes": 6
21+
"max_nodes": 5
2122
}
2223
}
2324
}
2425
EOF
2526
)
2627
DECORATOR_CONFIG=$(echo $decorator_config_json_str | base64 -w 0)
2728

28-
python3 -m graph_net.torch.run_model --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --decorator-config=$DECORATOR_CONFIG
29+
python3 -m graph_net.torch.run_model --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --decorator-config=$DECORATOR_CONFIG

graph_net/test/naive_decomposer_and_post_extract_process_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ decorator_config_json_str=$(cat <<EOF
2020
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
2121
"filter_config": {},
2222
"post_extract_process_path":"$GRAPH_NET_ROOT/torch/post_extract_process_count_kernels.py",
23-
"post_extract_process_class_name": "GraphFullyFusable"
23+
"post_extract_process_class_name": "GraphFullyFusible"
2424
}
2525
}
2626
}

graph_net/torch/extractor.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import torch
33
import json
44
import shutil
5-
from typing import Union, Callable
65
from graph_net.torch import utils
76
from graph_net.torch.fx_graph_serialize_util import serialize_graph_module_to_str
87

@@ -82,7 +81,7 @@ def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
8281
subgraph_path = model_path
8382
else:
8483
if self.subgraph_counter == 1:
85-
subgraph_0_path = os.path.join(model_path, f"subgraph_0")
84+
subgraph_0_path = os.path.join(model_path, "subgraph_0")
8685
self.move_files(model_path, subgraph_0_path)
8786

8887
subgraph_path = os.path.join(
@@ -239,9 +238,12 @@ def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
239238

240239
extractor_config = make_extractor_config(extractor_config)
241240

242-
def get_graph_extractor_maker():
241+
def get_graph_extractor_maker(model_path):
243242
custom_extractor_path = extractor_config["custom_extractor_path"]
244243
custom_extractor_config = extractor_config["custom_extractor_config"]
244+
if custom_extractor_config is None:
245+
custom_extractor_config = {}
246+
custom_extractor_config["model_path"] = model_path
245247
if custom_extractor_path is None:
246248
return GraphExtractor
247249
import importlib.util as imp
@@ -254,7 +256,10 @@ def get_graph_extractor_maker():
254256

255257
def wrapper(model: torch.nn.Module):
256258
assert isinstance(model, torch.nn.Module), f"{type(model)=}"
257-
extractor = get_graph_extractor_maker()(
259+
model_path = None
260+
if hasattr(model, "__graph_net_file_path__"):
261+
model_path = os.path.dirname(model.__graph_net_file_path__)
262+
extractor = get_graph_extractor_maker(model_path)(
258263
name, dynamic, mut_graph_codes, placeholder_auto_rename
259264
)
260265
# return torch.compile(backend=extractor, dynamic=dynamic)

graph_net/torch/fully_fusable_subgraph_extractor.py

Lines changed: 0 additions & 98 deletions
This file was deleted.
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import os
2+
import torch
3+
import graph_net
4+
import tempfile
5+
import shutil
6+
from graph_net.torch import constraint_util
7+
8+
9+
class GraphExtractor:
    """Search a model for the largest node window that passes the fusibility check.

    Candidate windows ``[start_pos, end_pos)`` are tried from the widest step
    down to the narrowest. For each candidate a decomposition config is built
    and handed to ``constraint_util.RunModelPredicator``; the first candidate
    for which the predicator returns a truthy value is saved under
    ``config["output_dir"]`` and the search stops.
    """

    def __init__(
        self,
        config: dict,
        name,
        dynamic,
        mut_graph_codes=None,
        placeholder_auto_rename=False,
    ):
        """Store the extractor settings and normalise ``config`` via ``make_config``.

        Args:
            config: Keyword arguments accepted by ``make_config``.
            name: Model name; used in the decomposer config and in the name of
                the result directory.
            dynamic: Stored on the instance; not read by this class.
            mut_graph_codes: Stored on the instance; not read by this class.
            placeholder_auto_rename: Stored on the instance; not read by this
                class.
        """
        self.subgraph_counter = 0
        self.name = name
        self.dynamic = dynamic
        self.mut_graph_codes = mut_graph_codes
        self.placeholder_auto_rename = placeholder_auto_rename
        self.config = self.make_config(**config)

    def make_config(
        self,
        output_dir=None,
        split_positions=(),
        group_head_and_tail=False,
        chain_style=False,
        max_step=8,
        min_step=2,
        max_nodes=32,
        model_path=None,
    ):
        """Return the normalised configuration dict with defaults filled in.

        Raises:
            AssertionError: If any entry of ``split_positions`` is not an int.
        """
        for pos in split_positions:
            assert isinstance(
                pos, int
            ), f"split_positions should be list of int, {split_positions=}"
        return {
            "output_dir": output_dir,
            "split_positions": split_positions,
            "group_head_and_tail": group_head_and_tail,
            "chain_style": chain_style,
            "max_step": max_step,
            "min_step": min_step,
            "max_nodes": max_nodes,
            "model_path": model_path,
        }

    def _get_sub_ranges(self):
        """Yield ``(start_pos, end_pos)`` candidate windows, widest step first.

        For each step from ``max_step`` down to ``min_step`` the window start
        slides from 0 while ``start_pos < max_nodes - step``.

        Raises:
            AssertionError: If ``min_step < 1`` or ``max_step < min_step``.
        """
        min_step = self.config["min_step"]
        max_step = self.config["max_step"]
        max_nodes = self.config["max_nodes"]
        assert min_step >= 1, "min_step must be at least 1."
        assert max_step >= min_step, "max_step must not be less than min_step."
        for step in range(max_step, min_step - 1, -1):
            # NOTE(review): the window ending exactly at max_nodes
            # (start_pos == max_nodes - step) is never generated; this is
            # kept as in the original code -- confirm that it is intended.
            for start_pos in range(max_nodes - step):
                yield start_pos, start_pos + step

    def _handle_success(self, temp_dir: str, start_pos: int, end_pos: int) -> str:
        """Save the contents of ``temp_dir`` to the result directory.

        The result directory is ``<output_dir>/<name>_start<S>_end<E>``.

        Returns:
            The path of the result directory.

        Raises:
            ValueError: If ``output_dir`` was not configured.
        """
        output_dir = self.config["output_dir"]
        if output_dir is None:
            raise ValueError(
                "output_dir must be configured to save the fusible subgraph."
            )
        target_name = f"{self.name}_start{start_pos}_end{end_pos}"
        target_path = os.path.join(output_dir, target_name)
        # Copy the *contents* into target_path. shutil.move() into an already
        # existing directory would nest the randomly named temp dir inside it,
        # and would also remove the directory from under the
        # TemporaryDirectory context manager that still owns it.
        shutil.copytree(temp_dir, target_path, dirs_exist_ok=True)
        return target_path

    def _build_decompose_config(
        self, temp_dir: str, start_pos: int, end_pos: int
    ) -> dict:
        """Build the decorator config that checks one candidate window.

        Side effect: ``self.config["split_positions"]`` is overwritten with
        ``[start_pos, end_pos]``.
        """
        self.config["split_positions"] = [start_pos, end_pos]
        graph_net_root = os.path.dirname(graph_net.__file__)

        check_fusible_config = {
            "decorator_path": f"{graph_net_root}/torch/extractor.py",
            "decorator_config": {
                "name": f"{self.name}",
                "custom_extractor_path": f"{graph_net_root}/torch/naive_graph_decomposer.py",
                "custom_extractor_config": {
                    "output_dir": temp_dir,
                    "split_positions": self.config["split_positions"],
                    "group_head_and_tail": False,
                    "filter_path": f"{graph_net_root}/torch/naive_subgraph_filter.py",
                    "filter_config": {},
                    "post_extract_process_path": f"{graph_net_root}/torch/post_extract_process_count_kernels.py",
                    "post_extract_process_class_name": "GraphFullyFusible",
                },
            },
        }
        return check_fusible_config

    def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
        """Run the search and return ``gm.forward`` unchanged.

        Each candidate is evaluated in its own temporary directory, which is
        removed when the candidate has been processed.
        """
        for start_pos, end_pos in self._get_sub_ranges():
            with tempfile.TemporaryDirectory(
                prefix="_find_fusible_subgraph_"
            ) as temp_dir:
                check_fusible_config = self._build_decompose_config(
                    temp_dir, start_pos, end_pos
                )
                print("current split_positions:", self.config["split_positions"])
                # presumably the predicator runs the model at model_path under
                # the given decorator config and reports success -- verify
                # against constraint_util.
                success = constraint_util.RunModelPredicator(check_fusible_config)(
                    self.config["model_path"]
                )
                if success:
                    target_path = self._handle_success(temp_dir, start_pos, end_pos)
                    print(
                        f"SUCCESS in finding the biggest fully fusible subgraph. Result saved to: {target_path}"
                    )
                    break
        return gm.forward

graph_net/torch/naive_graph_decomposer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def make_config(
3131
filter_config=None,
3232
post_extract_process_path=None,
3333
post_extract_process_class_name=None,
34+
**kwargs,
3435
):
3536
for pos in split_positions:
3637
assert isinstance(

graph_net/torch/post_extract_process_count_kernels.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch.profiler import profile, record_function, ProfilerActivity
77

88

9-
class GraphFullyFusable:
9+
class GraphFullyFusible:
1010
def __init__(self, config):
1111
self.config = config
1212

@@ -29,21 +29,17 @@ def __call__(self, model_path=None):
2929
# try to run the model
3030
try:
3131
model(**state_dict)
32-
except Exception as e:
33-
print(f"failed in running model:{e}")
32+
except Exception:
3433
sys.exit(1)
3534
# try to compile the model
3635
try:
3736
compiled_model = torch.compile(model)
38-
except Exception as e:
39-
print(f"failed in compiling model:{e}")
37+
except Exception:
4038
sys.exit(1)
4139
compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
4240
if compiled_num_of_kernels == 1:
43-
print(model_path, "can be fully integrated!!!!!!!!!!!")
4441
sys.exit(0)
4542
else:
46-
print(f"{model_path} can not be fully integrated, to be removed...")
4743
sys.exit(1)
4844

4945

graph_net/torch/run_model.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
11
from . import utils
22
import argparse
33
import importlib.util
4-
import inspect
54
import torch
6-
import logging
7-
from pathlib import Path
8-
from typing import Type, Any
9-
import sys
5+
from typing import Type
106
import json
117
import base64
12-
from contextlib import contextmanager
138

149

1510
def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
@@ -60,7 +55,6 @@ def main(args):
6055
assert model_class is not None
6156
model = model_class()
6257
print(f"{model_path=}")
63-
6458
decorator_config = _convert_to_dict(args.decorator_config)
6559
if "decorator_path" in decorator_config:
6660
model = _get_decorator(decorator_config)(model)
@@ -70,7 +64,7 @@ def main(args):
7064
use_dummy_inputs = get_flag_use_dummy_inputs(decorator_config)
7165
print(f"{use_dummy_inputs=}")
7266
state_dict = {k: replay_tensor(v, use_dummy_inputs) for k, v in params.items()}
73-
67+
model.__graph_net_file_path__ = model_path
7468
model(**state_dict)
7569

7670

0 commit comments

Comments
 (0)