
Commit ca95a70

Improve GEMM performance of shape 4096x8x128x16384 (#2646)
This change (a `grid` order adjustment to improve cache hits) originates from #2600. It applies to batched GEMM only and reaches ~99% of XeTLA performance for `4096x8x128x16384`. ![image](https://github.com/user-attachments/assets/ef7e9750-b3f7-4adc-aa66-5be704383e40)
1 parent 85682e4 commit ca95a70

File tree

1 file changed (+3, −3 lines)


benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -129,8 +129,8 @@ def matmul_kernel_with_block_pointers_batched(
         stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr,
         # Meta-parameters
         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
-    bid = tl.program_id(axis=0)
-    pid = tl.program_id(axis=1)
+    bid = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
@@ -186,8 +186,8 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False):
     B = a.shape[0]
     # 1D launch kernel where each block gets its own program.
     grid = lambda META: (
-        B,
         triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
+        B,
     )
     matmul_kernel_with_block_pointers_batched[grid](
         a, b, c,  #
```
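As a rough illustration of why the swap improves cache hits: assuming the flattened launch enumerates grid axis 0 fastest (as in CUDA-style grids; actual hardware scheduling order may vary), moving the tile index to axis 0 makes consecutive program IDs walk through tiles of the same batch rather than hopping across batches. The sketch below uses hypothetical tiny sizes (not from the commit) and prints the `(bid, pid)` pairs seen by consecutive programs before and after the change:

```python
# Hypothetical sketch: map a flattened program ID f onto (bid, pid) under
# the assumption that grid axis 0 varies fastest in the launch order.
B, num_tiles = 2, 4  # tiny made-up grid for illustration

# Before: grid = (B, num_tiles); bid = program_id(0), pid = program_id(1).
before = [(f % B, f // B) for f in range(B * num_tiles)]
print(before)  # [(0, 0), (1, 0), (0, 1), ...] -> neighboring programs hop between batches

# After: grid = (num_tiles, B); pid = program_id(0), bid = program_id(1).
after = [(f // num_tiles, f % num_tiles) for f in range(B * num_tiles)]
print(after)   # [(0, 0), (0, 1), (0, 2), ...] -> neighboring programs stay in one batch
```

With the new order, programs that run close together touch adjacent tiles of one batch's A and B tensors, so blocks brought into cache by one program are more likely to still be resident for its neighbors.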
