Skip to content

Commit fa6cc70

Browse files
committed
fix splitk assertion issue
1 parent 532728c commit fa6cc70

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def forward(ctx, a, b, c, acc_dtype=None):
128128
[512, 32768, 8192],
129129
[1024, 28672, 8192],
130130
[3072, 4096, 3072],
131+
[4096, 4096, 4096],
131132
],
132133
line_arg='provider',
133134
# argument name whose value corresponds to a different line in the plot
@@ -152,17 +153,17 @@ def benchmark(M, N, K, provider):
152153
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
153154
quantiles=quantiles)
154155
elif provider == 'triton':
155-
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
156+
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
156157
triton_fn = lambda: matmul(a, b, c)
157158
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
158159
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
159160
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
160161
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
161162
quantiles=quantiles, kernel_name='_kernel')
162163
elif provider == 'xetla':
163-
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
164-
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
165-
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
164+
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
165+
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
166+
cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)
166167

167168
name = f'gemm_splitk_shape_{M}_{K}_{N}'
168169
func = getattr(xetla_kernel, name)

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,17 +275,17 @@ def benchmark(M, N, K, provider):
275275
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
276276
quantiles=quantiles)
277277
elif provider == 'triton':
278-
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
278+
c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
279279
triton_fn = lambda: matmul(a, b, c)
280280
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
281281
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
282282
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
283283
quantiles=quantiles,
284284
kernel_name=['first_wave', 'full_tiles'])
285285
elif provider == 'xetla':
286-
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
287-
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
288-
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
286+
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
287+
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
288+
cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)
289289

290290
name = f'gemm_streamk_shape_{M}_{K}_{N}'
291291
func = getattr(xetla_kernel, name)

0 commit comments

Comments
 (0)