
Commit babdde5

Improve efficiency of test/fully_fusible_subgraph_extractor_test.sh

1 parent 6df0cd0, commit babdde5

8 files changed: +76 −69 lines

graph_net/config/small_sample_list_for_get_fusible_subgraph.txt
Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@
 #samples/timm/dla46x_c.in1k
 #samples/timm/mobilenetv1_100.ra4_e3600_r224_in1k
 samples/timm/efficientnetv2_rw_s.ra2_in1k
-#samples/timm/vit_base_patch16_rope_ape_224.naver_in1k
+samples/timm/vit_base_patch16_rope_ape_224.naver_in1k
 #samples/timm/fastvit_t8.apple_dist_in1k
 #samples/timm/test_byobnet.r160_in1k
-#samples/timm/mambaout_base.in1k
+#samples/timm/mambaout_base.in1k

graph_net/test/new_graph_decompose_and_look_for_fully_fusible_subgraph_test.sh renamed to graph_net/test/fully_fusible_subgraph_extractor_test.sh
Lines changed: 3 additions & 3 deletions

@@ -15,11 +15,11 @@ config_json_str=$(cat <<EOF
     "handler_path": "$GRAPH_NET_ROOT/torch/fully_fusible_subgraph_extractor.py",
     "handler_class_name":"FullyFusibleSubgraphExtractor",
     "handler_config": {
+        "resume": false,
         "model_path_prefix": "$GRAPH_NET_ROOT/../",
         "output_dir": "$OUTPUT_DIR",
-        "split_positions": [],
-        "group_head_and_tail": false,
-        "chain_style": false,
+        "nn_module_fully_fusible_decorator_path": "$GRAPH_NET_ROOT/torch/count_kernels_util.py",
+        "nn_module_fully_fusible_decorator_class_name": "TorchSubModuleFullyFusibleDecorator",
         "max_step": 3,
         "min_step": 2,
         "max_nodes": 4

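For reference, a minimal sketch of the handler_config this test now builds, written as the Python dict that FullyFusibleSubgraphExtractor's _make_config accepts as keyword arguments; the concrete paths below are placeholders, and "nn_module_fully_fusible_decorator_config" is optional (it defaults to None):

handler_config = {
    "resume": False,  # the test disables resume so every run re-extracts from scratch
    "model_path_prefix": "/path/to/GraphNet/../",         # placeholder for $GRAPH_NET_ROOT/../
    "output_dir": "/tmp/fusible_subgraph_output",         # placeholder for $OUTPUT_DIR
    "nn_module_fully_fusible_decorator_path": "graph_net/torch/count_kernels_util.py",
    "nn_module_fully_fusible_decorator_class_name": "TorchSubModuleFullyFusibleDecorator",
    "max_step": 3,
    "min_step": 2,
    "max_nodes": 4,
}
# The extractor's __call__ then takes one rel_model_path at a time,
# e.g. an entry from the sample list above.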
graph_net/test/naive_check_if_fully_fusionabale.sh
Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
 checker_config_json_str=$(cat <<EOF
 {
     "post_extract_process_config": {
-        "post_extract_process_path":"$GRAPH_NET_ROOT/torch/post_extract_process_count_kernels.py",
+        "post_extract_process_path":"$GRAPH_NET_ROOT/torch/count_kernels_util.py",
         "post_extract_process_class_name": "GraphFullyFusionable"
     }
 }

graph_net/test/naive_decomposer_and_post_extract_process_test.sh
Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ decorator_config_json_str=$(cat <<EOF
         "group_head_and_tail": true,
         "filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
         "filter_config": {},
-        "post_extract_process_path":"$GRAPH_NET_ROOT/torch/post_extract_process_count_kernels.py",
+        "post_extract_process_path":"$GRAPH_NET_ROOT/torch/count_kernels_util.py",
         "post_extract_process_class_name": "GraphFullyFusible"
     }
 }

graph_net/torch/post_extract_process_count_kernels.py renamed to graph_net/torch/count_kernels_util.py
Lines changed: 3 additions & 2 deletions

@@ -12,16 +12,17 @@
 )
 
 
-class TorchNNModuleFullyFusibleDecorator:
+class TorchSubModuleFullyFusibleDecorator:
     def __init__(self, config):
         self.config = config
 
-    def __call__(self, module):
+    def __call__(self, module, sub_module_idx):
         return TorchNNModuleFullyFusiblePredicator(module)
 
 
 class TorchNNModuleFullyFusiblePredicator(torch.nn.Module):
     def __init__(self, module):
+        super().__init__()
         self.module = module
 
     def forward(self, *inputs):
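A minimal usage sketch of the renamed hook, assuming graph_net is importable; the torch.nn.Linear stand-in and the index value are illustrative only, while the (module, sub_module_idx) signature and the nn.Module wrapper come from the diff above:

import torch
from graph_net.torch.count_kernels_util import TorchSubModuleFullyFusibleDecorator

decorator = TorchSubModuleFullyFusibleDecorator(config={})
sub = torch.nn.Linear(4, 4)                  # stand-in for a folded submodule
wrapped = decorator(sub, sub_module_idx=0)   # wraps it in TorchNNModuleFullyFusiblePredicator
assert isinstance(wrapped, torch.nn.Module)  # valid now that the wrapper calls super().__init__()

The added super().__init__() matters: assigning a submodule attribute on an nn.Module subclass that skipped Module.__init__() raises an AttributeError at attribute registration time.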

graph_net/torch/fully_fusible_graph_predicator.py
Lines changed: 3 additions & 3 deletions

@@ -45,7 +45,7 @@ class FullyFusibleSubGraphPredicator:
     def __init__(self, config):
         if config is None:
             config = {}
-        self.config = self._make_config(config)
+        self.config = self._make_config(**config)
         self.nn_module_fully_fusible_decorator = (
             self._make_nn_module_fully_fusible_decorator(config)
         )
@@ -79,10 +79,10 @@ def _make_config(
             "nn_module_fully_fusible_decorator_config": nn_module_fully_fusible_decorator_config,
         }
 
-    def __call__(self, gm: torch.fx.GraphModule, start_node_idx, end_node_idx):
+    def __call__(self, start_node_idx, end_node_idx):
         try:
             rewrited_gm: torch.fx.GraphModule = fold_range_to_submodule(
-                gm,
+                self.traced_module,
                 start_node_idx=start_node_idx,
                 end_node_idx=end_node_idx,
                 submodule_hook=self.nn_module_fully_fusible_decorator,
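A sketch of the new call pattern, mirroring how the extractor below constructs the predicator; the model path and node indices are placeholders, and the tracing detail is an assumption (the diff only shows that __call__ now reads self.traced_module instead of receiving a GraphModule per call):

from graph_net.torch.fully_fusible_graph_predicator import FullyFusibleSubGraphPredicator

config = {
    "model_path": "/path/to/samples/timm/some_model",  # placeholder
    "nn_module_fully_fusible_decorator_path": "graph_net/torch/count_kernels_util.py",
    "nn_module_fully_fusible_decorator_class_name": "TorchSubModuleFullyFusibleDecorator",
    "nn_module_fully_fusible_decorator_config": None,
}
predicator = FullyFusibleSubGraphPredicator(config)

# The FX graph is no longer passed on every call; presumably it is traced once
# from model_path and cached as self.traced_module, so each call only names a node range.
success = predicator(start_node_idx=3, end_node_idx=7)  # placeholder range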
Lines changed: 62 additions & 56 deletions

@@ -1,11 +1,12 @@
 import os
+import torch
 from pathlib import Path
-import graph_net
 import tempfile
 import shutil
-from graph_net.torch import fully_fusible_graph_predicator
-from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
-from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
+from graph_net.torch.graph_decomposer import NaiveDecomposerExtractor
+from graph_net.torch.fully_fusible_graph_predicator import (
+    FullyFusibleSubGraphPredicator,
+)
 import logging
 
 logger = logging.getLogger(__name__)
@@ -19,24 +20,22 @@ def __init__(self, config: dict = None):
 
     def _make_config(
         self,
+        nn_module_fully_fusible_decorator_path,
+        nn_module_fully_fusible_decorator_class_name,
+        nn_module_fully_fusible_decorator_config=None,
         output_dir=None,
-        split_positions=(),
-        group_head_and_tail=False,
-        chain_style=False,
+        resume: bool = True,
         max_step=8,
         min_step=2,
         max_nodes=32,
        model_path_prefix="",
     ):
-        for pos in split_positions:
-            assert isinstance(
-                pos, int
-            ), f"split_positions should be list of int, {split_positions=}"
         return {
             "output_dir": output_dir,
-            "split_positions": split_positions,
-            "group_head_and_tail": group_head_and_tail,
-            "chain_style": chain_style,
+            "resume": resume,
+            "nn_module_fully_fusible_decorator_path": nn_module_fully_fusible_decorator_path,
+            "nn_module_fully_fusible_decorator_class_name": nn_module_fully_fusible_decorator_class_name,
+            "nn_module_fully_fusible_decorator_config": nn_module_fully_fusible_decorator_config,
             "max_step": max_step,
             "min_step": min_step,
             "max_nodes": max_nodes,
@@ -61,7 +60,9 @@ def _get_sub_ranges(self):
             ), f"Invalid range generated: start={start_pos}, end={end_pos}, max={self.config['max_nodes']}"
             yield start_pos, end_pos
 
-    def _handle_success(self, temp_dir: str, rel_model_path: str) -> str:
+    def _copy_from_tmp_dir_to_output_dir(
+        self, temp_dir: str, rel_model_path: str
+    ) -> str:
         subdirs = list(Path(temp_dir).iterdir())
         assert len(subdirs) == 1
         temp_dir = str(subdirs[0])
@@ -74,57 +75,62 @@ def _handle_success(self, temp_dir: str, rel_model_path: str) -> str:
         return target_path
 
     def _build_decompose_config(
-        self, temp_dir: str, start_pos: int, end_pos: int, model_path_prefix
+        self, temp_dir: str, start_pos: int, end_pos: int
     ) -> dict:
-        graph_net_root = os.path.dirname(graph_net.__file__)
+        model_path_prefix = self.config["model_path_prefix"]
+        decomposer_config = {
+            "model_path_prefix": model_path_prefix,
+            "output_dir": temp_dir,
+            "split_positions": [start_pos, end_pos],
+            "group_head_and_tail": False,
+        }
+        return decomposer_config
 
-        check_fusible_config = {
-            "handler_path": f"{graph_net_root}/torch/graph_decomposer.py",
-            "handler_class_name": "NaiveDecomposerExtractor",
-            "handler_config": {
-                "model_path_prefix": model_path_prefix,
-                "output_dir": temp_dir,
-                "split_positions": [start_pos, end_pos],
-                "group_head_and_tail": False,
-                "post_extract_process_path": f"{graph_net_root}/torch/post_extract_process_count_kernels.py",
-                "post_extract_process_class_name": "ThrowExitStatusIfGraphFullyFusible",
-            },
+    def _get_fully_fusible_subgraph_predicator(self, model_path):
+        config = {
+            "model_path": model_path,
+            "nn_module_fully_fusible_decorator_path": self.config[
+                "nn_module_fully_fusible_decorator_path"
+            ],
+            "nn_module_fully_fusible_decorator_class_name": self.config[
+                "nn_module_fully_fusible_decorator_class_name"
+            ],
+            "nn_module_fully_fusible_decorator_config": self.config[
+                "nn_module_fully_fusible_decorator_config"
+            ],
         }
-        return check_fusible_config
+        return FullyFusibleSubGraphPredicator(config)
+
+    def _is_model_path_handled(self, rel_model_path):
+        model_path = Path(self.config["output_dir"]) / rel_model_path
+        return model_path.exists() and len(list(model_path.iterdir())) > 0
 
     def __call__(self, rel_model_path):
+        if self.config["resume"] and self._is_model_path_handled(rel_model_path):
+            return
+        torch.cuda.empty_cache()
         model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
-        module, inputs = get_torch_module_and_inputs(model_path)
-        gm = parse_sole_graph_module(module, inputs)
+        fully_fusible_subgraph_predicator = self._get_fully_fusible_subgraph_predicator(
+            model_path
+        )
        for start_pos, end_pos in self._get_sub_ranges():
+            logger.warning("fully_fusible_subgraph_predicator-begin")
+            success = fully_fusible_subgraph_predicator(start_pos, end_pos)
+            logger.warning("fully_fusible_subgraph_predicator-end")
+            if not success:
+                continue
             with tempfile.TemporaryDirectory(
                 prefix="_find_fusible_subgraph_"
             ) as temp_dir:
-                check_fusible_config = self._build_decompose_config(
-                    temp_dir, start_pos, end_pos, self.config["model_path_prefix"]
-                )
-                predicator_cls = (
-                    fully_fusible_graph_predicator.FullyFusibleGraphPredicator
-                )
-                predicator = predicator_cls(check_fusible_config)
-                logger.warning("fully_fusible_graph_predicator-begin")
-                success = predicator(model_path)
-                logger.warning("fully_fusible_graph_predicator-end")
-                if not success:
-                    continue
                 decomposer_config = self._build_decompose_config(
-                    temp_dir, start_pos, end_pos, self.config["model_path_prefix"]
-                )
-                predicator_cls = (
-                    fully_fusible_graph_predicator.FullyFusibleGraphPredicator
+                    temp_dir, start_pos, end_pos
                 )
-                predicator = predicator_cls(decomposer_config)
-                predicator(model_path)
-                target_path = self._handle_success(temp_dir, rel_model_path)
-                print(
-                    f"SUCCESS in finding the biggest fully fusible subgraph. Result saved to: {target_path}"
+                naive_graph_decomposer = NaiveDecomposerExtractor(decomposer_config)
+                logger.warning("naive_graph_decomposer-begin")
+                naive_graph_decomposer(rel_model_path)
+                logger.warning("naive_graph_decomposer-end")
+                fully_fusible_destination_path = self._copy_from_tmp_dir_to_output_dir(
+                    temp_dir, rel_model_path
                 )
-                break
-        else:
-            logger.warning("fail to find fully fusible subgraph")
-        return gm.forward
+                print(f"{fully_fusible_destination_path=}")
+                return
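The efficiency gains above come from three changes: a resume short-circuit before any work is done, a single predicator per model instead of re-tracing per range, and delegating extraction to NaiveDecomposerExtractor only for ranges that already passed the fusibility check. A self-contained sketch of the resume test, mirroring _is_model_path_handled, with placeholder paths:

from pathlib import Path

OUTPUT_DIR = "/tmp/fusible_subgraph_output"   # placeholder
REL_MODEL_PATH = "samples/timm/some_model"    # placeholder

def is_model_path_handled(output_dir: str, rel_model_path: str) -> bool:
    # A model counts as handled when its per-model output directory exists
    # and already contains at least one entry.
    model_path = Path(output_dir) / rel_model_path
    return model_path.exists() and len(list(model_path.iterdir())) > 0

resume = True  # the extractor's new default; the test script passes "resume": false to force re-extraction
if resume and is_model_path_handled(OUTPUT_DIR, REL_MODEL_PATH):
    print("already extracted; skipping this model")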

graph_net/torch/graph_decomposer.py
Lines changed: 1 addition & 1 deletion

@@ -105,10 +105,10 @@ def __init__(self, config: dict = None):
 
     def _make_config(
         self,
+        output_dir,
         split_positions=(),
         group_head_and_tail=False,
         chain_style=False,
-        output_dir="./tmp/naive_decomposer_dir",
         filter_path=None,
         filter_config=None,
         post_extract_process_path=None,
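With output_dir now a required parameter (the old "./tmp/naive_decomposer_dir" default is gone), callers must state where results land. A minimal sketch matching the decomposer_config the extractor builds; paths and split positions are placeholders:

from graph_net.torch.graph_decomposer import NaiveDecomposerExtractor

decomposer_config = {
    "model_path_prefix": "/path/to/GraphNet/../",   # placeholder
    "output_dir": "/tmp/_find_fusible_subgraph_x",  # must be passed explicitly now
    "split_positions": [3, 7],                      # placeholder node range
    "group_head_and_tail": False,
}
decomposer = NaiveDecomposerExtractor(decomposer_config)
decomposer("samples/timm/some_model")  # placeholder rel_model_path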
