From 9bdca906cb4c301a533e90a7124dbc5abb498d89 Mon Sep 17 00:00:00 2001 From: ESI-SYD Date: Thu, 7 Nov 2024 04:06:31 +0000 Subject: [PATCH] improve --- benchmarks/triton_kernels_benchmark/gemm_benchmark.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py index da41f1e447..9941b0c5f0 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py @@ -129,8 +129,8 @@ def matmul_kernel_with_block_pointers_batched( stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): - bid = tl.program_id(axis=0) - pid = tl.program_id(axis=1) + bid = tl.program_id(axis=1) + pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_in_group = GROUP_SIZE_M * num_pid_n @@ -186,8 +186,8 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False): B = a.shape[0] # 1D launch kernel where each block gets its own program. grid = lambda META: ( - B, triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + B, ) matmul_kernel_with_block_pointers_batched[grid]( a, b, c, #