Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def forward(ctx, a, b, c, acc_dtype=None):
[512, 32768, 8192],
[1024, 28672, 8192],
[3072, 4096, 3072],
[4096, 4096, 4096],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding a new combination breaks the CI and seems a bit out of topic for this pull request. Maybe we should move this change to a separate pull request?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. Split that change in a separate PR please @LiyangLingIntel

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other than that LGTM.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I disagree, as this PR supposes to fix the 4k functional error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added 4k shape to XeTLA splitk list, benchmark CI works now.

],
line_arg='provider',
# argument name whose value corresponds to a different line in the plot
Expand All @@ -152,17 +153,17 @@ def benchmark(M, N, K, provider):
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
quantiles=quantiles)
elif provider == 'triton':
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name='_kernel')
elif provider == 'xetla':
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)

name = f'gemm_splitk_shape_{M}_{K}_{N}'
func = getattr(xetla_kernel, name)
Expand Down
8 changes: 4 additions & 4 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,17 +275,17 @@ def benchmark(M, N, K, provider):
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
quantiles=quantiles)
elif provider == 'triton':
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles,
kernel_name=['first_wave', 'full_tiles'])
elif provider == 'xetla':
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)

name = f'gemm_streamk_shape_{M}_{K}_{N}'
func = getattr(xetla_kernel, name)
Expand Down
Loading