Skip to content

Commit 6018c7b

Browse files
authored
[Benchmark] Run xetla streamk gemm in benchmark (#2438)
1 parent 35130dc commit 6018c7b

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import triton.language as tl
1111

1212
import triton_kernels_benchmark as benchmark_suit
13+
import xetla_kernel
1314

1415
if benchmark_suit.USE_IPEX_OPTION:
1516
import intel_extension_for_pytorch # type: ignore # noqa: F401
@@ -253,9 +254,9 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
253254
line_arg='provider',
254255
# argument name whose value corresponds to a different line in the plot
255256
# possible values for `line_arg``
256-
line_vals=['triton'],
257+
line_vals=['triton', 'xetla'],
257258
# label name for the lines
258-
line_names=['Triton'],
259+
line_names=['Triton', 'XeTLA'],
259260
# line styles
260261
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
261262
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
@@ -281,6 +282,20 @@ def benchmark(M, N, K, provider):
281282
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
282283
quantiles=quantiles,
283284
kernel_name=['first_wave', 'full_tiles'])
285+
elif provider == 'xetla':
286+
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
287+
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
288+
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
289+
290+
name = f'gemm_streamk_shape_{M}_{K}_{N}'
291+
func = getattr(xetla_kernel, name)
292+
xetla_fn = lambda: func(a, b, c, acc, cnt)
293+
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
294+
295+
# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
296+
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
297+
xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
298+
kernel_name='gpu::xetla::kernel::gemm_universal_t<dispatch_stream_k')
284299
else:
285300
raise NotImplementedError(f'Unsupported provider {provider}')
286301

benchmarks/xetla_kernel/python_main.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,10 @@ PYBIND11_MODULE(xetla_kernel, m) {
280280
&bf16_gemm<Test_4096x8x128x16384_row_row>, "bf16_gemm (XeTLA)");
281281
m.def("gemm_shape_4096_8_16384_128",
282282
&bf16_gemm<Test_4096x8x16384x128_row_row>, "bf16_gemm (XeTLA)");
283-
// flash_attn_fwd
283+
// gemm stream k
284+
m.def("gemm_streamk_shape_3072_4096_3072", &bf16_stream_k_gemm,
285+
"bf16_gemm_streamk (XeTLA)");
286+
// flash_attn
284287
m.def("flash_attn_causal_false", &flash_attn<false, false, false>,
285288
"flash attn fwd (XeTLA)");
286289
m.def("flash_attn_causal_true", &flash_attn<false, true, false>,

benchmarks/xetla_kernel/stream_k_gemm/stream_k_gemm.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,6 @@ sycl::event stream_k_gemm_run(void *_A, void *_B, void *_C, void *_Acc,
3636
using data_type_c = float;
3737
using data_type_acc = float;
3838

39-
auto context = queue.get_info<sycl::info::queue::context>();
40-
auto device = queue.get_info<sycl::info::queue::device>();
41-
4239
data_type_a *A = static_cast<data_type_a *>(_A);
4340
data_type_b *B = static_cast<data_type_b *>(_B);
4441
data_type_c *C = static_cast<data_type_c *>(_C);
@@ -52,7 +49,7 @@ sycl::event stream_k_gemm_run(void *_A, void *_B, void *_C, void *_Acc,
5249
constexpr uint32_t sg_tile_k = 32;
5350

5451
// StreamK parameters - xecores available for stream_k dispatch
55-
uint32_t avail_xecores = 32;
52+
uint32_t avail_xecores = 64;
5653

5754
// Org the compute shape for sub-matrix
5855
using tile_shape =

0 commit comments

Comments
 (0)