Commit ca860b3

modify the way of counting kernels used

1 parent fe89add

File tree

4 files changed: +17 -31 lines

graph_net/test/naive_decomposer_and_post_extract_process_test.sh

Lines changed: 1 addition & 5 deletions

@@ -24,11 +24,7 @@ decorator_config_json_str=$(cat <<EOF
     "group_head_and_tail": true,
     "filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
     "filter_config": {},
-    "post_extract_process_path":"$GRAPH_NET_ROOT/torch/post_extract_process.py",
-    "post_extract_process_config": {
-      "decorator_path": "$GRAPH_NET_ROOT/torch/shape_prop.py",
-      "decorator_class_name": "ShapePropagate"
-    }
+    "post_extract_process_path":"$GRAPH_NET_ROOT/torch/post_extract_process.py"
   }
 }
}
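The test now configures the post-extract step by path alone and drops the decorator sub-config. A hedged Python illustration of what a consumer of this config sees after the change; the paths are placeholders and json.loads stands in for however the heredoc string is actually parsed.

import json

# Hypothetical illustration of consuming a decorator-config string like the
# one in this test; only the keys visible in the hunk are shown.
config_str = """
{
    "group_head_and_tail": true,
    "filter_path": "/path/to/naive_subgraph_filter.py",
    "filter_config": {},
    "post_extract_process_path": "/path/to/post_extract_process.py"
}
"""

config = json.loads(config_str)
assert config["post_extract_process_path"].endswith("post_extract_process.py")
print(config.get("post_extract_process_config"))  # None after this change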

graph_net/torch/naive_graph_decomposer.py

Lines changed: 3 additions & 3 deletions

@@ -92,7 +92,7 @@ def forward(self, *args):
         if not self.extracted:
             if self.need_extract(self.submodule, args):
                 self.builtin_extractor(self.submodule, args)
-                self.get_post_extract_process()
+                self._post_extract_process()
             self.extracted = True
         return self.submodule(*args)

@@ -101,7 +101,7 @@ def need_extract(self, gm, sample_inputs):
             return True
         return self.filter(gm, sample_inputs)

-    def get_post_extract_process(self):
+    def _post_extract_process(self):
         model_path = os.path.join(
             self.parent_graph_extractor.config["output_dir"], self.modelname
         )

@@ -114,7 +114,7 @@ def make_filter(self, config):
         return module.GraphFilter(config["filter_config"])

     def make_post_extract_process(self, config):
-        if config["filter_path"] is None:
+        if config["post_extract_process_path"] is None:
             return None
         module = imp_util.load_module(config["post_extract_process_path"])
         return module.PostExtractProcess(config["post_extract_process_config"])
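The corrected guard now keys on the same setting it loads. A minimal standalone sketch of this plug-in pattern, assuming imp_util.load_module behaves like the standard importlib file loader (its implementation is not part of this commit):

import importlib.util


def load_module_from_path(path):
    # Assumed stand-in for graph_net's imp_util.load_module.
    spec = importlib.util.spec_from_file_location("plugin_module", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def make_post_extract_process(config):
    # Guard on the key actually used below; the pre-fix code tested
    # "filter_path", so a filter-only config would wrongly reach the loader.
    if config["post_extract_process_path"] is None:
        return None
    module = load_module_from_path(config["post_extract_process_path"])
    return module.PostExtractProcess(config["post_extract_process_config"])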

graph_net/torch/naive_subgraph_filter.py

Lines changed: 0 additions & 1 deletion

@@ -3,5 +3,4 @@ def __init__(self, config):
         self.config = config

     def __call__(self, gm, sample_inputs):
-        # print(f"GraphFilter\n{gm.code}")
         return True
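With the debug print gone, the filter is a pure pass-through. The whole file presumably reduces to the sketch below; the class header sits outside the diff context and is inferred from make_filter in naive_graph_decomposer.py, which instantiates module.GraphFilter.

# Inferred full content of graph_net/torch/naive_subgraph_filter.py after
# this commit; the first two lines are assumptions, not shown in the diff.
class GraphFilter:
    def __init__(self, config):
        self.config = config

    def __call__(self, gm, sample_inputs):
        # Accept every subgraph; a non-trivial filter would inspect gm here.
        return True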

graph_net/torch/post_extract_process.py

Lines changed: 13 additions & 22 deletions

@@ -1,17 +1,8 @@
 from graph_net.torch import utils
-import argparse
 import importlib.util
-import inspect
 import shutil
 import torch
-import logging
-from pathlib import Path
-from typing import Type, Any
-import sys
-import json
-import base64
-from contextlib import contextmanager
-
+from typing import Type
 from torch.profiler import profile, record_function, ProfilerActivity

@@ -34,7 +25,9 @@ def __call__(self, model_path=None):
         params = inputs_params["weight_info"]
         state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}

-        compiled_num_of_kernels = compile_and_count_kernels(model, state_dict)
+        model(**state_dict)
+        compiled_model = torch.compile(model)
+        compiled_num_of_kernels = count_kernels(model, state_dict)
         if compiled_num_of_kernels == 1:
             print(model_path, "can be fully integrated")
             return True

@@ -52,12 +45,12 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Modul
     return model_class


-def compile_and_count_kernels(gm, sample_inputs) -> int:
+def count_kernels(model, sample_inputs) -> int:
     """
     Count the number of CUDA kernel launches performed during a model's forward pass.

     Args:
-        gm(graph models)
+        model(graph models)
         sample_inputs(tensors)

     Returns:

@@ -68,21 +61,19 @@ def compile_and_count_kernels(gm, sample_inputs) -> int:
         - Identifies the event with key = 'cudaLaunchKernel', which corresponds
           to the number of CUDA kernel launches.
     """
-    gm.eval()
+    model.eval()
     # Use PyTorch Profiler
-    compiled_gm = torch.compile(gm)
-    _ = compiled_gm(**sample_inputs)

     with profile(
         activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
         record_shapes=True,
     ) as prof:
         with record_function("model_inference"):
-            output = compiled_gm(**sample_inputs)
+            output = model(**sample_inputs)
     events = prof.key_averages()
-    if_compile_work = any(e.key == "TorchDynamo Cache Lookup" for e in events)
-    if not if_compile_work:
-        return -1
+
+    total_count = 0
     for e in events:
-        if e.key == "cuLaunchKernel":
-            return e.count
+        if e.key == "cuLaunchKernel" or e.key == "cudalaunchKernel":
+            total_count += e.count
+    return total_count
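The commit replaces the early bail-out on a missing "TorchDynamo Cache Lookup" event with a warm-up call, a compile step, and a sum over kernel-launch events. A self-contained sketch of the counting idea, assuming a CUDA device; the toy model and input names are illustrative. Note the profiler reports runtime-API launches under the key "cudaLaunchKernel" (capital L) and driver-API launches, such as Triton kernels emitted by torch.compile, under "cuLaunchKernel", and event-key matching is case-sensitive.

import torch
from torch.profiler import ProfilerActivity, profile, record_function

# Event keys the profiler emits for kernel launches: CUDA runtime API
# ("cudaLaunchKernel") and CUDA driver API ("cuLaunchKernel").
LAUNCH_KEYS = ("cudaLaunchKernel", "cuLaunchKernel")


def count_kernels(model, sample_inputs) -> int:
    """Count CUDA kernel launches during one forward pass."""
    model.eval()
    with profile(
        activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
        record_shapes=True,
    ) as prof:
        with record_function("model_inference"):
            model(**sample_inputs)
    return sum(e.count for e in prof.key_averages() if e.key in LAUNCH_KEYS)


if __name__ == "__main__" and torch.cuda.is_available():
    net = torch.nn.Linear(8, 8).cuda()
    inputs = {"input": torch.randn(4, 8, device="cuda")}
    net(**inputs)  # warm-up so lazy initialization does not inflate the count
    compiled = torch.compile(net)
    compiled(**inputs)  # trigger compilation outside the profiled region
    print(count_kernels(compiled, inputs))

Summing across both keys (rather than returning the count of the first matching event) matters because a compiled graph can mix driver-API and runtime-API launches in a single forward pass.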
