Skip to content

Commit 8abe554

Browse files
committed
Remove workaround for upstream profiler
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent fe45283 commit 8abe554

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=N
149149

150150

151151
def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean",
152-
device="xpu", sync_submitting=True, kernel_name=None):
152+
device="xpu", sync_submitting=True, kernel_name=None): # pylint: disable=unused-argument
153153
"""
154154
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
155155
the 20-th and 80-th performance percentile.
@@ -168,7 +168,7 @@ def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None
168168

169169
assert return_mode in ["min", "max", "mean", "median"]
170170
import torch
171-
from torch.profiler import profile, ProfilerActivity
171+
from torch.profiler import profile, ProfilerActivity, record_function
172172

173173
fn()
174174
synchronize()
@@ -210,22 +210,24 @@ def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None
210210
if sync_submitting:
211211
synchronize()
212212
# record time of `fn`
213-
fn()
213+
with record_function("__profile_kernel_of_func"):
214+
fn()
214215
# Record clocks
215216
synchronize()
216217

217-
function_events = prof.events()
218+
profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), prof.events())
219+
functions = list(profiling_func_filter)
218220

219-
functions = []
220-
if isinstance(kernel_name, str):
221-
kernel_name = [kernel_name]
222-
for ker_name in kernel_name:
223-
functions.extend(list(filter(lambda x: x.name.startswith(ker_name), function_events))) # pylint: disable=cell-var-from-loop
224-
# profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)
221+
def extract_kernels(funcs):
222+
kernels = []
223+
kernels += list(itertools.chain.from_iterable(map(lambda func: extract_kernels(func.cpu_children), funcs)))
224+
kernels += list(itertools.chain.from_iterable([func.kernels for func in funcs]))
225+
return kernels
225226

226-
assert len(functions) == n_repeat, f"the profiling number not match, {len(functions)}"
227+
kernels = [extract_kernels(func.cpu_children) for func in functions]
228+
assert len(kernels) == n_repeat, "the profiling number not match"
227229
# Convert the times to milliseconds.
228-
times = torch.tensor([f.self_device_time_total * 1e-3 for f in functions], dtype=torch.float)
230+
times = torch.tensor([sum([k.duration for k in ks]) * 1e-3 for ks in kernels], dtype=torch.float)
229231
return _summarize_statistics(times, quantiles, return_mode)
230232

231233

third_party/intel/backend/driver.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,26 @@ class CompilationHelper:
7171
def __init__(self):
7272
self._library_dir = None
7373
self._include_dir = None
74-
self.libraries = ['ze_loader', 'sycl']
74+
self.libraries = ['ze_loader', 'sycl', 'torch']
7575

7676
@cached_property
7777
def _compute_compilation_options_lazy(self):
78+
import torch
7879
ze_root = os.getenv("ZE_PATH", default="/usr/local")
7980
include_dir = [os.path.join(ze_root, "include")]
8081

8182
include_dir, library_dir = find_sycl(include_dir)
8283

8384
dirname = os.path.dirname(os.path.realpath(__file__))
8485
include_dir += [os.path.join(dirname, "include")]
86+
include_dir += [
87+
os.path.join(torch.utils.cmake_prefix_path, "../../include"),
88+
os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include"),
89+
]
8590
library_dir += [os.path.join(dirname, "lib")]
91+
library_dir += [
92+
os.path.join(torch.utils.cmake_prefix_path, "../../lib"),
93+
]
8694

8795
self._library_dir = library_dir
8896
self._include_dir = include_dir
@@ -218,6 +226,7 @@ def format_of(ty):
218226
#include <iomanip>
219227
#include <level_zero/ze_api.h>
220228
#include <sycl/sycl.hpp>
229+
#include <ATen/record_function.h>
221230
222231
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
223232
#include <Python.h>
@@ -310,6 +319,7 @@ def format_of(ty):
310319
static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
311320
312321
std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
322+
RECORD_FUNCTION("XPU Triton kernel: " + kernel_name, {{}});
313323
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
314324
uint32_t num_params = sizeof(params)/sizeof(params[0]);
315325
uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();

0 commit comments

Comments
 (0)