Skip to content

Commit 9bdca90

Browse files
committed
improve
1 parent 6da6e84 commit 9bdca90

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ def matmul_kernel_with_block_pointers_batched(
129129
stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr,
130130
# Meta-parameters
131131
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
132-
bid = tl.program_id(axis=0)
133-
pid = tl.program_id(axis=1)
132+
bid = tl.program_id(axis=1)
133+
pid = tl.program_id(axis=0)
134134
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
135135
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
136136
num_pid_in_group = GROUP_SIZE_M * num_pid_n
@@ -186,8 +186,8 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False):
186186
B = a.shape[0]
187187
# 1D launch kernel where each block gets its own program.
188188
grid = lambda META: (
189-
B,
190189
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
190+
B,
191191
)
192192
matmul_kernel_with_block_pointers_batched[grid](
193193
a, b, c, #

0 commit comments

Comments
 (0)