Skip to content

Commit 700abe3

Browse files
authored
Fix upstream profiler for several kernels (#2498)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent f4fdd8f commit 700abe3

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,16 +213,18 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no
213213

214214
function_events = prof.events()
215215

216-
functions = []
216+
all_functions = []
217217
if isinstance(kernel_name, str):
218218
kernel_name = [kernel_name]
219219
for ker_name in kernel_name:
220-
functions.extend(list(filter(lambda x: x.name.startswith(ker_name), function_events))) # pylint: disable=cell-var-from-loop
220+
functions = list(filter(lambda x: x.name.startswith(ker_name), function_events)) # pylint: disable=cell-var-from-loop
221+
assert len(functions) == n_repeat, f"the profiling number for kernel: '{ker_name}' not match, {len(functions)}"
222+
all_functions.append(functions)
221223
# profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)
222224

223-
assert len(functions) == n_repeat, f"the profiling number not match, {len(functions)}"
224225
# Make the time to the milliseconds.
225-
times = torch.tensor([f.self_device_time_total * 1e-3 for f in functions], dtype=torch.float)
226+
times = torch.tensor([sum(map(lambda elem: elem.self_device_time_total, f)) * 1e-3 for f in zip(*all_functions)],
227+
dtype=torch.float)
226228
return _summarize_statistics(times, quantiles, return_mode)
227229

228230

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,8 @@ def benchmark(M, N, K, provider):
293293
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
294294

295295
# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
296-
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
297-
xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
298-
kernel_name='gpu::xetla::kernel::gemm_universal_t<dispatch_stream_k')
296+
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
297+
quantiles=quantiles, kernel_name='stream_k_gemm_run')
299298
else:
300299
raise NotImplementedError(f'Unsupported provider {provider}')
301300

0 commit comments

Comments (0)