Commit 6df0cd0: "backup code"
Parent: 93fabbf

6 files changed: +136 additions, -30 deletions

graph_net/test/torch_extractor_test.py
Lines changed: 1 addition & 1 deletion

@@ -76,7 +76,7 @@ def forward(self, x):
         start_node_idx=0,
         end_node_idx=2,
         submodule_hook=submodule_hook,
-        # group_head_and_tail=False,
+        group_head_and_tail=True,
     )
     folded_output = folded(inp)

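Note: this test edit pairs with the decompose_util.py change below. The test previously relied on the library default (with group_head_and_tail=False left commented out); now that the default flips to False, the test opts in to head/tail grouping explicitly.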
graph_net/torch/decompose_util.py
Lines changed: 1 addition & 1 deletion

@@ -181,7 +181,7 @@ def fold_range_to_submodule(
     end_node_idx: int,
     submodule_hook=None,
     submodule_name="extracted_submodule",
-    group_head_and_tail=True,
+    group_head_and_tail=False,
 ):
     return convert_to_submodules_graph(
         gm,

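Since the default is now False, existing callers that relied on head/tail grouping must pass the flag explicitly. A minimal call-site sketch, assuming a toy traced module (MyModule and the literal node indices are placeholders for illustration, not part of this commit):

    import torch
    from graph_net.torch.decompose_util import fold_range_to_submodule

    class MyModule(torch.nn.Module):  # placeholder module for illustration
        def forward(self, x):
            return torch.relu(x + 1) * 2

    gm = torch.fx.symbolic_trace(MyModule())

    folded = fold_range_to_submodule(
        gm,
        start_node_idx=0,
        end_node_idx=2,
        submodule_hook=None,
        group_head_and_tail=True,  # no longer the default after this commit
    )
    folded_output = folded(torch.randn(4))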
graph_net/torch/fully_fusible_graph_predicator.py
Lines changed: 68 additions & 0 deletions

@@ -1,10 +1,17 @@
+import torch
 import traceback
 import logging
+from graph_net.imp_util import load_module
+from graph_net.torch.decompose_util import fold_range_to_submodule
 from graph_net.torch.graph_decomposer import NaiveDecomposerExtractor
 from graph_net.torch.graph_fusibility_status import (
     GraphFusibilityStatus,
     GraphFusibility,
 )
+from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
+from graph_net.torch.fx_graph_cache_util import (
+    parse_immutable_model_path_into_sole_graph_module,
+)

 logger = logging.getLogger(__name__)

@@ -32,3 +39,64 @@ def __call__(self, model_path):
         traceback.print_exc()
         print("--------------------------\n")
         return False
+
+
+class FullyFusibleSubGraphPredicator:
+    def __init__(self, config):
+        if config is None:
+            config = {}
+        self.config = self._make_config(config)
+        self.nn_module_fully_fusible_decorator = (
+            self._make_nn_module_fully_fusible_decorator(config)
+        )
+        model_path = self.config["model_path"]
+        module, inputs = get_torch_module_and_inputs(model_path)
+        self.traced_module = parse_immutable_model_path_into_sole_graph_module(
+            model_path
+        )
+        self.inputs = inputs
+
+    def _make_nn_module_fully_fusible_decorator(self, config):
+        py_module = load_module(self.config["nn_module_fully_fusible_decorator_path"])
+        decorator_cls = getattr(
+            py_module, self.config["nn_module_fully_fusible_decorator_class_name"]
+        )
+        return decorator_cls(self.config["nn_module_fully_fusible_decorator_config"])
+
+    def _make_config(
+        self,
+        model_path,
+        nn_module_fully_fusible_decorator_path,
+        nn_module_fully_fusible_decorator_class_name,
+        nn_module_fully_fusible_decorator_config=None,
+    ):
+        if nn_module_fully_fusible_decorator_config is None:
+            nn_module_fully_fusible_decorator_config = {}
+        return {
+            "model_path": model_path,
+            "nn_module_fully_fusible_decorator_path": nn_module_fully_fusible_decorator_path,
+            "nn_module_fully_fusible_decorator_class_name": nn_module_fully_fusible_decorator_class_name,
+            "nn_module_fully_fusible_decorator_config": nn_module_fully_fusible_decorator_config,
+        }
+
+    def __call__(self, gm: torch.fx.GraphModule, start_node_idx, end_node_idx):
+        try:
+            rewrited_gm: torch.fx.GraphModule = fold_range_to_submodule(
+                gm,
+                start_node_idx=start_node_idx,
+                end_node_idx=end_node_idx,
+                submodule_hook=self.nn_module_fully_fusible_decorator,
+            )
+            rewrited_gm(*self.inputs)
+        except GraphFusibilityStatus as status:
+            if status.graph_fusibility == GraphFusibility.kFullyFusible:
+                return True
+            elif status.graph_fusibility == GraphFusibility.kNotFullyFusible:
+                return False
+            else:
+                raise NotImplementedError(f"{status.graph_fusibility=}")
+        except Exception:
+            print("\n--- Custom Error Handler ---")
+            traceback.print_exc()
+            print("--------------------------\n")
+            return False

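Review note: as committed, __init__ passes a single dict to self._make_config(config), while _make_config declares individual keyword parameters, so construction would raise a TypeError; unpacking the dict (self._make_config(**config)) looks like the intent. Also, module from get_torch_module_and_inputs is bound but unused; only inputs is kept. The class itself reports fusibility through the GraphFusibilityStatus exception raised inside the hooked submodule's forward. A minimal self-contained sketch of that exception-as-verdict control flow (the enum and exception below are stand-ins mirroring graph_fusibility_status, not the library's definitions):

    import enum


    class GraphFusibility(enum.Enum):  # stand-in for graph_net.torch.graph_fusibility_status
        kFullyFusible = enum.auto()
        kNotFullyFusible = enum.auto()


    class GraphFusibilityStatus(Exception):  # carries the verdict out of forward()
        def __init__(self, graph_fusibility):
            super().__init__(graph_fusibility)
            self.graph_fusibility = graph_fusibility


    def probe(run_folded_graph):
        # Mirrors FullyFusibleSubGraphPredicator.__call__: run the rewritten
        # graph, catch the status raised by the hooked submodule, map it to a bool.
        try:
            run_folded_graph()
        except GraphFusibilityStatus as status:
            return status.graph_fusibility == GraphFusibility.kFullyFusible
        return False


    def fully_fusible_run():
        raise GraphFusibilityStatus(GraphFusibility.kFullyFusible)


    print(probe(fully_fusible_run))  # True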
graph_net/torch/fully_fusible_subgraph_extractor.py
Lines changed: 18 additions & 8 deletions

@@ -103,18 +103,28 @@ def __call__(self, rel_model_path):
             check_fusible_config = self._build_decompose_config(
                 temp_dir, start_pos, end_pos, self.config["model_path_prefix"]
             )
-            predicator = fully_fusible_graph_predicator.FullyFusibleGraphPredicator(
-                check_fusible_config
+            predicator_cls = (
+                fully_fusible_graph_predicator.FullyFusibleGraphPredicator
             )
+            predicator = predicator_cls(check_fusible_config)
             logger.warning("fully_fusible_graph_predicator-begin")
             success = predicator(model_path)
             logger.warning("fully_fusible_graph_predicator-end")
-            if success:
-                target_path = self._handle_success(temp_dir, rel_model_path)
-                print(
-                    f"SUCCESS in finding the biggest fully fusible subgraph. Result saved to: {target_path}"
-                )
-                break
+            if not success:
+                continue
+            decomposer_config = self._build_decompose_config(
+                temp_dir, start_pos, end_pos, self.config["model_path_prefix"]
+            )
+            predicator_cls = (
+                fully_fusible_graph_predicator.FullyFusibleGraphPredicator
+            )
+            predicator = predicator_cls(decomposer_config)
+            predicator(model_path)
+            target_path = self._handle_success(temp_dir, rel_model_path)
+            print(
+                f"SUCCESS in finding the biggest fully fusible subgraph. Result saved to: {target_path}"
+            )
+            break
         else:
             logger.warning("fail to find fully fusible subgraph")
         return gm.forward

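Review note: the rewrite converts the success branch into an early continue, then builds decomposer_config and runs the predicator a second time before saving. As committed, decomposer_config comes from the same _build_decompose_config call with identical arguments as check_fusible_config, so the second run repeats the same check; a distinct decomposer configuration may have been intended. Note also that the trailing else: belongs to the enclosing for loop (for/else), so the "fail to find" warning fires only when no window reached break.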
graph_net/torch/graph_decomposer.py
Lines changed: 14 additions & 17 deletions

@@ -5,6 +5,9 @@
 from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
 import graph_net.imp_util as imp_util
 from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
+from graph_net.torch.fx_graph_cache_util import (
+    parse_immutable_model_path_into_sole_graph_module,
+)
 from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
 import logging

@@ -134,30 +137,24 @@ def _make_config(
         }

     def __call__(self, rel_model_path):
-        # callback = lambda: logger.warning("NaiveDecomposerExtractor-call-end")
-        # logger.warning("NaiveDecomposerExtractor-call-begin")
-        # atexit.register(callback)
         model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
         config = {
             k: v
             for k, v in self.config.items()
             if k in {"split_positions", "group_head_and_tail", "chain_style"}
         }
-        # logger.warning("get_torch_module_and_inputs_begin")
         module, inputs = get_torch_module_and_inputs(model_path)
-        # logger.warning("get_torch_module_and_inputs_end")
-        # logger.warning("parse_sole_graph_module_begin")
-        gm = parse_sole_graph_module(module, inputs)
-        # logger.warning("parse_sole_graph_module_end")
-        # callback = lambda: logger.warning("convert_to_submodules_graph-call-end")
-        # logger.warning("convert_to_submodules_graph-call-begin")
-        # atexit.register(callback)
-        rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph(
-            gm,
-            submodule_hook=self.get_naive_decomposer_extractor(model_path),
-            **config,
-        )
-        rewrited_gm(*inputs)
+        gm = parse_immutable_model_path_into_sole_graph_module(model_path)
+        try:
+            logger.warning("convert_to_submodules_graph-call-begin")
+            rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph(
+                gm,
+                submodule_hook=self.get_naive_decomposer_extractor(model_path),
+                **config,
+            )
+            rewrited_gm(*inputs)
+        finally:
+            logger.warning("convert_to_submodules_graph-call-end")

     def get_naive_decomposer_extractor(self, model_path):
         def fn(submodule, seq_no):

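Review note: the block of commented-out logger/atexit calls is dropped, and the begin/end markers move into a try/finally so the end marker is emitted even when convert_to_submodules_graph or the replay raises. parse_sole_graph_module(module, inputs) is replaced by the cached parse_immutable_model_path_into_sole_graph_module(model_path); after this change, module from get_torch_module_and_inputs is used only for its inputs, so the unpack could be tightened to _, inputs = get_torch_module_and_inputs(model_path).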
graph_net/torch/post_extract_process_count_kernels.py
Lines changed: 34 additions & 3 deletions

@@ -12,6 +12,31 @@
 )


+class TorchNNModuleFullyFusibleDecorator:
+    def __init__(self, config):
+        self.config = config
+
+    def __call__(self, module):
+        return TorchNNModuleFullyFusiblePredicator(module)
+
+
+class TorchNNModuleFullyFusiblePredicator(torch.nn.Module):
+    def __init__(self, module):
+        self.module = module
+
+    def forward(self, *inputs):
+        try:
+            compiled_model = torch.compile(self.module)
+        except Exception:
+            raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
+        ret_tensors, compiled_num_of_kernels = count_kernels(compiled_model, inputs)
+        if compiled_num_of_kernels == 1:
+            raise GraphFusibilityStatus(GraphFusibility.kFullyFusible)
+        else:
+            raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
+        return ret_tensors
+
+
 class ThrowExitStatusIfGraphFullyFusible:
     def __init__(self, config):
         self.config = config

@@ -45,7 +70,7 @@ def __call__(self, model_path=None):
             compiled_model = torch.compile(model)
         except Exception:
             raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
-        compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
+        _, compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
         if compiled_num_of_kernels == 1:
             raise GraphFusibilityStatus(GraphFusibility.kFullyFusible)
         else:

@@ -103,11 +128,17 @@ def count_kernels(model, sample_inputs) -> int:
         record_shapes=True,
     ) as prof:
         with record_function("model_inference"):
-            _ = model(**sample_inputs)
+            if isinstance(sample_inputs, dict):
+                ret_tensors = model(**sample_inputs)
+            elif isinstance(sample_inputs, (list, tuple)):
+                ret_tensors = model(*sample_inputs)
+            else:
+                raise NotImplementedError(f"{type(sample_inputs)=}")
+
     events = prof.key_averages()

     total_count = 0
     for e in events:
         if e.key == "cuLaunchKernel" or e.key == "cudaLaunchKernel":
             total_count += e.count
-    return total_count
+    return ret_tensors, total_count

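Review notes on the new classes: TorchNNModuleFullyFusiblePredicator subclasses torch.nn.Module but never calls super().__init__(), so the self.module assignment will fail (nn.Module.__setattr__ requires initialization first); the return ret_tensors at the end of forward is unreachable because both branches raise; and count_kernels still carries a stale -> int annotation although it now returns a (ret_tensors, total_count) pair. A hedged sketch of the predicator with those issues addressed (count_kernels, GraphFusibilityStatus, and GraphFusibility are the names from this commit, assumed importable here):

    import torch


    class TorchNNModuleFullyFusiblePredicator(torch.nn.Module):
        def __init__(self, module):
            super().__init__()  # required before registering self.module
            self.module = module

        def forward(self, *inputs):
            try:
                compiled_model = torch.compile(self.module)
            except Exception:
                raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
            # count_kernels now returns (ret_tensors, total_count)
            _, num_kernels = count_kernels(compiled_model, inputs)
            status = (
                GraphFusibility.kFullyFusible
                if num_kernels == 1
                else GraphFusibility.kNotFullyFusible
            )
            raise GraphFusibilityStatus(status)  # the verdict always leaves via the exception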