
Commit 822fbf8

CountNumKernelsNNModule
1 parent 9b6af4c commit 822fbf8

File tree

graph_net/optional.py
graph_net/torch/count_kernels_util.py
graph_net/torch/sample_passes/cumsum_num_kernels_generator.py

3 files changed: +56 −7 lines changed

graph_net/optional.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+from typing import TypeVar, Generic, Union
+
+T = TypeVar("T")
+
+
+class Optional(Generic[T]):
+    def __init__(self, value: Union[T, None]):
+        self._value = value
+
+    def reset(self, that):
+        assert isinstance(that, Optional)
+        self._value = that._value
+
+    def is_some(self) -> bool:
+        return self._value is not None
+
+    def unwrap(self) -> T:
+        """Returns the value or raises an error if None."""
+        if self._value is None:
+            raise ValueError("Tried to unwrap a None value!")
+        return self._value
+
+    def unwrap_or(self, default: T) -> T:
+        """Returns the value or a default if None."""
+        return self._value if self._value is not None else default
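
A minimal usage sketch of this Optional wrapper, for context on the hunks below that consume it (the variable names and values here are illustrative, not part of the commit):

    from graph_net.optional import Optional

    maybe_count = Optional(None)
    assert not maybe_count.is_some()
    assert maybe_count.unwrap_or(0) == 0       # falls back to the default while empty

    maybe_count.reset(Optional(42))            # copy the value held by another Optional
    assert maybe_count.is_some()
    assert maybe_count.unwrap() == 42          # unwrap() raises ValueError only on None

Because reset() mutates the wrapper in place, an Optional created by a caller can be filled in later by code the caller hands it to, which is exactly how the commit threads a kernel count out of a forward pass.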

graph_net/torch/count_kernels_util.py

Lines changed: 16 additions & 0 deletions
@@ -2,6 +2,7 @@
 from graph_net.torch import utils
 import importlib.util
 import torch
+from graph_net.optional import Optional
 import sys
 from typing import Type
 from torch.profiler import profile, record_function, ProfilerActivity
@@ -20,6 +21,21 @@ def __call__(self, module, sub_module_idx):
         return TorchNNModuleFullyFusiblePredicator(module)
 
 
+class CountNumKernelsNNModule(torch.nn.Module):
+    def __init__(self, module, mut_opt_num_kernels: Optional):
+        super().__init__()
+        self.module = module
+        self.compiled_module = torch.compile(self.module)
+        self.mut_opt_num_kernels = mut_opt_num_kernels
+
+    def forward(self, *inputs):
+        ret_tensors, compiled_num_of_kernels = count_kernels(
+            self.compiled_module, inputs
+        )
+        self.mut_opt_num_kernels.reset(Optional(compiled_num_of_kernels))
+        return ret_tensors
+
+
 class TorchNNModuleFullyFusiblePredicator(torch.nn.Module):
     def __init__(self, module):
         super().__init__()
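
A sketch of how the new wrapper appears intended to be used, based on the hunks above: wrap a module, run a forward pass, then read the kernel count out of the shared Optional. The toy Linear module, tensor shapes, and CUDA device are assumptions for illustration, not part of the commit:

    import torch
    from graph_net.optional import Optional
    from graph_net.torch.count_kernels_util import CountNumKernelsNNModule

    # Hypothetical toy module; any torch.nn.Module is wrapped the same way.
    toy = torch.nn.Linear(16, 16).cuda()
    mut_opt_num_kernels = Optional(None)

    wrapped = CountNumKernelsNNModule(toy, mut_opt_num_kernels)
    out = wrapped(torch.randn(4, 16, device="cuda"))  # forward() profiles the compiled module

    # After the forward pass the kernel count has been published through the Optional.
    print(mut_opt_num_kernels.unwrap())

The point of the side channel is that the wrapper can stand in anywhere a plain submodule is expected (for example as a submodule_hook result below) while still reporting its kernel count to the caller.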

graph_net/torch/sample_passes/cumsum_num_kernels_generator.py

Lines changed: 15 additions & 7 deletions
@@ -1,10 +1,11 @@
 from graph_net.sample_pass.sample_pass import SamplePass
 from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
+from graph_net.optional import Optional
 from graph_net.torch.fx_graph_cache_util import (
     parse_immutable_model_path_into_sole_graph_module,
 )
 from graph_net.torch.decompose_util import convert_to_submodules_graph
-from graph_net.torch.count_kernels_util import count_kernels
+from graph_net.torch.count_kernels_util import CountNumKernelsNNModule
 from graph_net.torch.fx_graph_module_util import (
     get_fx_graph_num_ops,
     get_torch_module_and_inputs,
@@ -41,6 +42,7 @@ def resume(self, rel_model_path: str):
         cumsum_num_kernels = analyzer.analyze()
         cumsum_num_kernels_json = json.dumps(cumsum_num_kernels, indent=4)
         output_dir_path = Path(self.config["output_dir"]) / rel_model_path
+        output_dir_path.mkdir(parents=True, exist_ok=True)
         (output_dir_path / self.config["output_json_file_name"]).write_text(
             cumsum_num_kernels_json
         )
@@ -53,9 +55,9 @@ def __init__(self, model_path: Path):
     def analyze(self):
         triples = list(self._get_cumsum_num_kernels())
         data = {
-            "range_and_num_kernels": [
-                ((start, end), num_kernels) for start, end, num_kernels in triples
-            ],
+            "num_kernels": [num_kernels for start, end, num_kernels in triples],
+            "starts": [start for start, end, num_kernels in triples],
+            "ends": [end for start, end, num_kernels in triples],
         }
         return data
 
@@ -79,16 +81,22 @@ def _get_num_kernels_if_submodule_compiled(
         self, graph_module, nn_module, inputs, submodule_start, submodule_end
     ):
         torch.cuda.empty_cache()
+        mut_opt_num_kernels = Optional(None)
+
+        def compile_and_count_num_kernels(m, seq_no):
+            return CountNumKernelsNNModule(m, mut_opt_num_kernels)
+
         rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph(
             graph_module,
-            submodule_hook=lambda m, seq_no: torch.compile(m),
+            submodule_hook=compile_and_count_num_kernels,
             split_positions=[submodule_start, submodule_end],
             subgraph_ranges=[(submodule_start, submodule_end)],
             group_head_and_tail=False,
             chain_style=False,
         )
-        _, num_kernels = count_kernels(rewrited_gm, inputs)
-        return num_kernels
+        rewrited_gm(*inputs)
+        assert mut_opt_num_kernels.is_some()
+        return mut_opt_num_kernels.unwrap()
 
     def _get_ranges(self, gm):
         num_ops = get_fx_graph_num_ops(gm)
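
For illustration, the restructured analyze() output now stores three parallel lists instead of nested (range, count) pairs. With made-up numbers, the dictionary that gets serialized to JSON looks like:

    # Hypothetical result of analyze() for three submodule ranges (values invented):
    cumsum_num_kernels = {
        "num_kernels": [3, 5, 2],
        "starts": [0, 4, 9],
        "ends": [4, 9, 12],
    }
    # Entry i pairs the op range starts[i]..ends[i] with num_kernels[i]; the previous
    # layout kept the same information as [((start, end), num_kernels), ...] under
    # the single key "range_and_num_kernels".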
