Skip to content

Commit 93fabbf

Browse files
authored
Merge pull request #1 from lixinqi/lxq_fusibletest
improve fully_fusible_subgraph_extractor.py efficiency
2 parents 7dbb6e9 + f71b56b commit 93fabbf

File tree

5 files changed

+91
-40
lines changed

5 files changed

+91
-40
lines changed

graph_net/torch/constraint_util.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,3 @@ def __call__(self, model_path):
6161
decorator_config = b64_encoded_bytes.decode("utf-8")
6262
cmd = f"{sys.executable} -m graph_net.torch.run_model --model-path {model_path} --decorator-config {decorator_config}"
6363
return os.system(cmd) == 0
64-
65-
66-
class FusibleSubgraphPredicator:
67-
def __init__(self, config=None):
68-
if config is None:
69-
config = {}
70-
self.config = config
71-
72-
def __call__(self, model_path):
73-
import json
74-
import base64
75-
76-
json_string = json.dumps(self.config)
77-
json_bytes = json_string.encode("utf-8")
78-
b64_encoded_bytes = base64.b64encode(json_bytes)
79-
predicator_config = b64_encoded_bytes.decode("utf-8")
80-
cmd = f"{sys.executable} -m graph_net.model_path_handler --model-path {model_path} --handler-config {predicator_config}"
81-
return os.system(cmd) == 0
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import traceback
2+
import logging
3+
from graph_net.torch.graph_decomposer import NaiveDecomposerExtractor
4+
from graph_net.torch.graph_fusibility_status import (
5+
GraphFusibilityStatus,
6+
GraphFusibility,
7+
)
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class FullyFusibleGraphPredicator:
13+
def __init__(self, config=None):
14+
if config is None:
15+
config = {}
16+
self.config = config
17+
handler_config = self.config["handler_config"]
18+
self.decomposer_extractor = NaiveDecomposerExtractor(handler_config)
19+
20+
def __call__(self, model_path):
21+
try:
22+
self.decomposer_extractor(model_path)
23+
except GraphFusibilityStatus as status:
24+
if status.graph_fusibility == GraphFusibility.kFullyFusible:
25+
return True
26+
elif status.graph_fusibility == GraphFusibility.kNotFullyFusible:
27+
return False
28+
else:
29+
raise NotImplementedError(f"{status.graph_fusibility=}")
30+
except Exception:
31+
print("\n--- Custom Error Handler ---")
32+
traceback.print_exc()
33+
print("--------------------------\n")
34+
return False

graph_net/torch/fully_fusible_subgraph_extractor.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import os
2+
from pathlib import Path
23
import graph_net
34
import tempfile
45
import shutil
5-
from graph_net.torch import constraint_util
6+
from graph_net.torch import fully_fusible_graph_predicator
67
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
78
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
89
import logging
@@ -60,20 +61,16 @@ def _get_sub_ranges(self):
6061
), f"Invalid range generated: start={start_pos}, end={end_pos}, max={self.config['max_nodes']}"
6162
yield start_pos, end_pos
6263

63-
def _handle_success(
64-
self, temp_dir: str, start_pos: int, end_pos: int, model_name
65-
) -> str:
66-
target_name = f"{model_name}_start{start_pos}_end{end_pos}"
64+
def _handle_success(self, temp_dir: str, rel_model_path: str) -> str:
65+
subdirs = list(Path(temp_dir).iterdir())
66+
assert len(subdirs) == 1
67+
temp_dir = str(subdirs[0])
6768
target_path = os.path.join(
6869
self.config["output_dir"],
69-
target_name,
70+
rel_model_path,
7071
)
7172
os.makedirs(target_path, exist_ok=True)
72-
# shutil.move(temp_dir, target_path)
73-
for item in os.listdir(temp_dir):
74-
source = os.path.join(temp_dir, item)
75-
destination = os.path.join(target_path, item)
76-
shutil.move(source, destination)
73+
shutil.copytree(temp_dir, target_path, dirs_exist_ok=True)
7774
return target_path
7875

7976
def _build_decompose_config(
@@ -90,7 +87,7 @@ def _build_decompose_config(
9087
"split_positions": [start_pos, end_pos],
9188
"group_head_and_tail": False,
9289
"post_extract_process_path": f"{graph_net_root}/torch/post_extract_process_count_kernels.py",
93-
"post_extract_process_class_name": "GraphFullyFusible",
90+
"post_extract_process_class_name": "ThrowExitStatusIfGraphFullyFusible",
9491
},
9592
}
9693
return check_fusible_config
@@ -106,14 +103,14 @@ def __call__(self, rel_model_path):
106103
check_fusible_config = self._build_decompose_config(
107104
temp_dir, start_pos, end_pos, self.config["model_path_prefix"]
108105
)
109-
predicator = constraint_util.FusibleSubgraphPredicator(
106+
predicator = fully_fusible_graph_predicator.FullyFusibleGraphPredicator(
110107
check_fusible_config
111108
)
109+
logger.warning("fully_fusible_graph_predicator-begin")
112110
success = predicator(model_path)
111+
logger.warning("fully_fusible_graph_predicator-end")
113112
if success:
114-
target_path = self._handle_success(
115-
temp_dir, start_pos, end_pos, os.path.basename(model_path)
116-
)
113+
target_path = self._handle_success(temp_dir, rel_model_path)
117114
print(
118115
f"SUCCESS in finding the biggest fully fusible subgraph. Result saved to: {target_path}"
119116
)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from enum import Enum
2+
3+
4+
class GraphFusibility(Enum):
5+
kFullyFusible = "fully_fusible"
6+
kNotFullyFusible = "not_fully_fusible"
7+
8+
9+
class GraphFusibilityStatus(Exception):
10+
def __init__(self, graph_fusibility: GraphFusibility):
11+
message = f"{graph_fusibility=}"
12+
super().__init__(message)
13+
self.graph_fusibility = graph_fusibility

graph_net/torch/post_extract_process_count_kernels.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1+
import traceback
12
from graph_net.torch import utils
23
import importlib.util
34
import torch
45
import sys
56
from typing import Type
67
from torch.profiler import profile, record_function, ProfilerActivity
78

9+
from graph_net.torch.graph_fusibility_status import (
10+
GraphFusibilityStatus,
11+
GraphFusibility,
12+
)
813

9-
class GraphFullyFusible:
14+
15+
class ThrowExitStatusIfGraphFullyFusible:
1016
def __init__(self, config):
1117
self.config = config
1218

@@ -16,7 +22,7 @@ def __call__(self, model_path=None):
1622
# atexit.register(callback)
1723
torch._dynamo.reset()
1824
if model_path is None:
19-
sys.exit(1)
25+
raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
2026
# model
2127
model_class = load_class_from_file(
2228
f"{model_path}/model.py", class_name="GraphModule"
@@ -33,17 +39,36 @@ def __call__(self, model_path=None):
3339
try:
3440
model(**state_dict)
3541
except Exception:
36-
sys.exit(1)
42+
raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
3743
# try to compile the model
3844
try:
3945
compiled_model = torch.compile(model)
4046
except Exception:
41-
sys.exit(1)
47+
raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
4248
compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
4349
if compiled_num_of_kernels == 1:
44-
sys.exit(0)
50+
raise GraphFusibilityStatus(GraphFusibility.kFullyFusible)
4551
else:
46-
sys.exit(1)
52+
raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
53+
54+
55+
class GraphFullyFusible:
56+
def __init__(self, config):
57+
self.predicator = ThrowExitStatusIfGraphFullyFusible(config)
58+
59+
def __call__(self, model_path=None):
60+
try:
61+
self.predicator(model_path)
62+
except GraphFusibilityStatus as status:
63+
if status.graph_fusibility == GraphFusibility.kFullyFusible:
64+
sys.exit(0)
65+
elif status.graph_fusibility == GraphFusibility.kNotFullyFusible:
66+
sys.exit(1)
67+
else:
68+
raise NotImplementedError(f"{status.graph_fusibility=}")
69+
except Exception:
70+
traceback.print_exc()
71+
sys.exit(1)
4772

4873

4974
def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:

0 commit comments

Comments
 (0)