Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,16 +213,18 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no

function_events = prof.events()

functions = []
all_functions = []
if isinstance(kernel_name, str):
kernel_name = [kernel_name]
for ker_name in kernel_name:
functions.extend(list(filter(lambda x: x.name.startswith(ker_name), function_events))) # pylint: disable=cell-var-from-loop
functions = list(filter(lambda x: x.name.startswith(ker_name), function_events)) # pylint: disable=cell-var-from-loop
assert len(functions) == n_repeat, f"the profiling number for kernel: '{ker_name}' not match, {len(functions)}"
all_functions.append(functions)
# profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)

assert len(functions) == n_repeat, f"the profiling number not match, {len(functions)}"
# Make the time to the milliseconds.
times = torch.tensor([f.self_device_time_total * 1e-3 for f in functions], dtype=torch.float)
times = torch.tensor([sum(map(lambda elem: elem.self_device_time_total, f)) * 1e-3 for f in zip(*all_functions)],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main problem was that the time of several kernels was not summed up. This affects only "gemm streamk" benchmark.

dtype=torch.float)
return _summarize_statistics(times, quantiles, return_mode)


Expand Down
5 changes: 2 additions & 3 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,8 @@ def benchmark(M, N, K, provider):
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
kernel_name='gpu::xetla::kernel::gemm_universal_t<dispatch_stream_k')
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was an incorrect kernel name. The error was not visible because at the time of adding, the benchmark was not running in CI.

_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name='stream_k_gemm_run')
else:
raise NotImplementedError(f'Unsupported provider {provider}')

Expand Down