Skip to content

Commit 0a5752e

Browse files
committed
Revert "Improve GEMM performance of shape 4096x8x128x16384 (#2646)"
This reverts commit ca95a70.
1 parent e8b34a0 commit 0a5752e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -129,8 +129,8 @@ def matmul_kernel_with_block_pointers_batched(
129 129
stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr,
130 130
# Meta-parameters
131 131
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
132-
bid = tl.program_id(axis=1)
133-
pid = tl.program_id(axis=0)
132+
bid = tl.program_id(axis=0)
133+
pid = tl.program_id(axis=1)
134 134
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
135 135
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
136 136
num_pid_in_group = GROUP_SIZE_M * num_pid_n
@@ -186,8 +186,8 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False):
186 186
B = a.shape[0]
187 187
# 1D launch kernel where each block gets its own program.
188 188
grid = lambda META: (
189-
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
190 189
B,
190+
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
191 191
)
192 192
matmul_kernel_with_block_pointers_batched[grid](
193 193
a, b, c, #

0 commit comments

Comments
 (0)