Skip to content

Commit 9b6af4c

Browse files
committed
add sample pass cumsum_num_kernels_generator
1 parent d0d9958 commit 9b6af4c

File tree

3 files changed

+117
-0
lines changed

3 files changed

+117
-0
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#!/bin/bash
# Run the CumSumNumKernelsGenerator sample pass over the models listed in
# cumsum_num_kernels_sample_list.txt, writing one cumsum_num_kernels.json
# per model under /tmp/cumsum_num_kernels_workspace.

# Derive the repository root from wherever the graph_net package is installed
# (two dirname calls: graph_net/__init__.py -> graph_net -> repo root).
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")

# The handler config is a JSON document passed as a single base64-encoded
# CLI argument. NOTE(review): `base64 -w 0` (no line wrapping) is a GNU
# coreutils flag — confirm if this script must run on macOS/BSD.
# Shell variables inside the unquoted heredoc are expanded before encoding.
python3 -m graph_net.model_path_handler \
    --model-path-list "$GRAPH_NET_ROOT/graph_net/test/dev_model_list/cumsum_num_kernels_sample_list.txt" \
    --handler-config $(base64 -w 0 <<EOF
{
    "handler_path": "$GRAPH_NET_ROOT/graph_net/torch/sample_passes/cumsum_num_kernels_generator.py",
    "handler_class_name": "CumSumNumKernelsGenerator",
    "handler_config": {
        "output_json_file_name": "cumsum_num_kernels.json",
        "resume": false,
        "model_path_prefix": "$GRAPH_NET_ROOT",
        "output_dir": "/tmp/cumsum_num_kernels_workspace"
    }
}
EOF
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
samples/timm/resnet18
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from graph_net.sample_pass.sample_pass import SamplePass
2+
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
3+
from graph_net.torch.fx_graph_cache_util import (
4+
parse_immutable_model_path_into_sole_graph_module,
5+
)
6+
from graph_net.torch.decompose_util import convert_to_submodules_graph
7+
from graph_net.torch.count_kernels_util import count_kernels
8+
from graph_net.torch.fx_graph_module_util import (
9+
get_fx_graph_num_ops,
10+
get_torch_module_and_inputs,
11+
)
12+
from pathlib import Path
13+
import json
14+
import torch
15+
16+
17+
class CumSumNumKernelsGenerator(SamplePass, ResumableSamplePassMixin):
    """Sample pass that records, for each model, the number of GPU kernels
    launched by every cumulative op-prefix of its FX graph.

    Results are written as ``<output_dir>/<rel_model_path>/<output_json_file_name>``
    via :class:`CumsumNumKernelsAnalyzer`.
    """

    def __init__(self, config):
        super().__init__(config)

    def declare_config(
        self,
        model_path_prefix: str,
        output_dir: str,
        resume: bool = False,
        limits_handled_models: int = None,
        output_json_file_name: str = "cumsum_num_kernels.json",
    ):
        # Declarative schema only — values are read back from self.config
        # at runtime (presumably validated by SamplePass; confirm in base class).
        pass

    def __call__(self, rel_model_path: str):
        # Delegate to the mixin so already-handled models are skipped.
        self.resumable_handle_sample(rel_model_path)

    def sample_handled(self, rel_model_path: str) -> bool:
        # A model counts as handled once its output JSON exists in the workspace.
        file_name = self.config["output_json_file_name"]
        return self.naive_sample_handled(rel_model_path, search_file_name=file_name)

    def resume(self, rel_model_path: str):
        """Analyze one model and write its kernel-count JSON to the workspace."""
        model_path = Path(self.config["model_path_prefix"]) / rel_model_path
        analyzer = CumsumNumKernelsAnalyzer(model_path)
        cumsum_num_kernels = analyzer.analyze()
        cumsum_num_kernels_json = json.dumps(cumsum_num_kernels, indent=4)
        output_dir_path = Path(self.config["output_dir"]) / rel_model_path
        # Fix: Path.write_text raises FileNotFoundError when the per-model
        # output directory does not exist yet (e.g. a fresh workspace), so
        # create it first.
        output_dir_path.mkdir(parents=True, exist_ok=True)
        (output_dir_path / self.config["output_json_file_name"]).write_text(
            cumsum_num_kernels_json
        )
47+
48+
49+
class CumsumNumKernelsAnalyzer:
    """Measure kernel counts for every cumulative op-prefix of a model.

    For each prefix ``[0, k)`` of the model's FX graph (k = 1..num_ops) the
    prefix is split out as a submodule, compiled with ``torch.compile``, and
    the number of launched kernels is counted.
    """

    def __init__(self, model_path: Path):
        self.model_path = model_path

    def analyze(self):
        """Return ``{"range_and_num_kernels": [((start, end), n), ...]}``."""
        range_and_num_kernels = []
        for start, end, num_kernels in self._get_cumsum_num_kernels():
            range_and_num_kernels.append(((start, end), num_kernels))
        return {"range_and_num_kernels": range_and_num_kernels}

    def _get_cumsum_num_kernels(self):
        """Yield ``(start, end, num_kernels)`` for each cumulative prefix."""
        path_str = str(self.model_path)
        torch_module, inputs = get_torch_module_and_inputs(
            path_str, use_dummy_inputs=False
        )
        graph_module = parse_immutable_model_path_into_sole_graph_module(path_str)
        for start, end in self._get_ranges(graph_module):
            # _get_ranges only produces cumulative prefixes anchored at op 0.
            assert start == 0
            num_kernels = self._get_num_kernels_if_submodule_compiled(
                graph_module=graph_module,
                nn_module=torch_module,
                inputs=inputs,
                submodule_start=start,
                submodule_end=end,
            )
            print(f"subgraph_range=[{start}, {end})\t{num_kernels=}")
            yield start, end, num_kernels

    def _get_num_kernels_if_submodule_compiled(
        self, graph_module, nn_module, inputs, submodule_start, submodule_end
    ):
        """Compile the ``[submodule_start, submodule_end)`` slice of the graph
        and return how many kernels running it launches."""
        # Drop cached allocations so earlier iterations don't skew this run.
        torch.cuda.empty_cache()
        compiled_gm: torch.fx.GraphModule = convert_to_submodules_graph(
            graph_module,
            submodule_hook=lambda submodule, seq_no: torch.compile(submodule),
            split_positions=[submodule_start, submodule_end],
            subgraph_ranges=[(submodule_start, submodule_end)],
            group_head_and_tail=False,
            chain_style=False,
        )
        _, num_kernels = count_kernels(compiled_gm, inputs)
        return num_kernels

    def _get_ranges(self, gm):
        """Yield ``(0, k)`` for every cumulative op count k = 1..num_ops."""
        total_ops = get_fx_graph_num_ops(gm)
        for cum_num_ops in range(1, total_ops + 1):
            yield 0, cum_num_ops

0 commit comments

Comments
 (0)