diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
index 2ad8636414..7944d8d301 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -227,28 +227,28 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
-        x_names=['B', 'M', 'K', 'N'],
+        x_names=['B', 'M', 'N', 'K'],
         # different possible values for `x_name`
         x_vals=[[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] +  #
         [  #
-            [1, 1, 5120, 13824],  #
-            [1, 4, 4096, 12288],  #
+            [1, 1, 13824, 5120],  #
+            [1, 4, 12288, 4096],  #
             [1, 512, 8192, 8192],  #
             [1, 512, 8192, 32768],  #
             [1, 512, 32768, 8192],  #
-            [1, 1024, 16384, 8192],  #
-            [1, 1024, 28672, 8192],  #
-            [1, 3072, 4096, 3072],  # FIXME: Remove this case when gemm_streamk_benchmark can get better performance
-            [1, 4096, 16384, 8192],  #
-            [1, 8192, 16384, 1024],  #
-            [1, 8192, 16384, 4096],  #
+            [1, 1024, 8192, 16384],  #
+            [1, 1024, 8192, 28672],  #
+            [1, 3072, 3072, 4096],  # FIXME: Remove this case when gemm_streamk_benchmark can get better performance
+            [1, 4096, 8192, 16384],  #
+            [1, 8192, 1024, 16384],  #
+            [1, 8192, 4096, 16384],  #
             [1, 16384, 1024, 8192],  #
             [1, 16384, 4096, 8192],  #
             [1, 16384, 8192, 1024],  #
             [1, 16384, 8192, 4096],  #
             [4, 32768, 128, 4096],  #
             [4, 32768, 4096, 128],  #
-            [32, 4096, 4096, 128],  #
+            [32, 4096, 128, 4096],  #
             [4096, 8, 128, 16384],  #
             [4096, 8, 16384, 128]
         ],
@@ -268,6 +268,7 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
 def benchmark(B, M, N, K, provider):
     a_shape, b_shape = get_shapes(B, M, N, K, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
 
+    torch.manual_seed(0)
     a = torch.rand(a_shape, device='xpu', dtype=torch.bfloat16)
     b = torch.rand(b_shape, device='xpu', dtype=torch.bfloat16)
 
@@ -291,10 +292,10 @@ def benchmark(B, M, N, K, provider):
     elif provider == 'triton':
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
-            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+            c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
         else:
             assert len(a.shape) == 2, 'Expecting shape of length 2'
-            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+            c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
         torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
@@ -304,17 +305,17 @@ def benchmark(B, M, N, K, provider):
                                                  kernel_name='matmul_kernel_with_block_pointers')
     elif provider == 'xetla':
         if B == 1:
-            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
-            acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
-            cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
+            c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+            acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+            cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)
         else:
-            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
-            acc = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
-            cnt = torch.empty((B, M, N), device='xpu', dtype=torch.int32)
+            c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
+            acc = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
+            cnt = torch.zeros((B, M, N), device='xpu', dtype=torch.int32)
         name = f'gemm_shape_{B}_{M}_{K}_{N}'
         # FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
        # better performance.
-        if (B, M, N, K) == (1, 3072, 4096, 3072):
+        if (B, M, N, K) == (1, 3072, 3072, 4096):
            name = 'gemm_streamk_shape_3072_4096_3072'
         func = getattr(xetla_kernel, name)
         xetla_fn = lambda: func(a, b, c, acc, cnt)
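
For context, the diff follows a common pattern for making benchmark validation deterministic: seed the RNG before generating inputs so every provider sees the same matrices, and zero-initialize output buffers instead of using torch.empty so stale device memory cannot leak into the accuracy check. Below is a minimal standalone sketch of that pattern; make_deterministic_buffers is a hypothetical helper (not part of the benchmark), and the 'xpu' default assumes an Intel XPU build of PyTorch, as in the file above.

    import torch

    def make_deterministic_buffers(B, M, N, K, device='xpu', dtype=torch.bfloat16):
        # Seed before generating inputs so repeated runs (and different
        # providers) benchmark identical matrices, mirroring the
        # torch.manual_seed(0) call added in the diff.
        torch.manual_seed(0)
        a = torch.rand((B, M, K) if B > 1 else (M, K), device=device, dtype=dtype)
        b = torch.rand((B, K, N) if B > 1 else (K, N), device=device, dtype=dtype)
        # Zero-init the output rather than torch.empty: a kernel that
        # skips elements or fails silently could otherwise pass the
        # accuracy check when uninitialized memory happens to hold
        # plausible values.
        c = torch.zeros((B, M, N) if B > 1 else (M, N), device=device, dtype=torch.float32)
        return a, b, c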