
Commit 05a7055

Change to use tensor descriptor for more benchmarks (#5219)
This PR modernizes the Triton kernel implementations by replacing block pointers with tensor descriptors across four GEMM benchmark files. The change aligns the benchmarks with newer Triton APIs and improves code readability.

Closes #4318, #4320, #4317, #4319

Co-authored-by: He, Dan H <[email protected]>
1 parent 52491a4 commit 05a7055
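The core of the change is mechanical: `tl.make_block_ptr` with per-block `offsets`, `tl.advance` between K-steps, and `boundary_check` loads become a single offset-free `tl.make_tensor_descriptor` plus loads and stores at explicit offsets. A condensed sketch of the new K-loop idiom, distilled from the diffs below (the kernel name and the no-op body are illustrative only, not part of the commit):

```python
import triton
import triton.language as tl


@triton.jit
def k_loop_sketch(a_ptr, M, K, stride_am, stride_ak,  #
                  BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
    pid_m = tl.program_id(0)
    # One descriptor describes the whole tensor; unlike make_block_ptr there are
    # no per-block offsets at creation, no order=(1, 0), and no boundary_check.
    a_desc = tl.make_tensor_descriptor(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K))
    off_k = 0
    for _ in range(0, K, BLOCK_SIZE_K):
        a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])  # was: tl.load(a_block_ptr, boundary_check=(0, 1))
        off_k += BLOCK_SIZE_K  # was: a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
```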

4 files changed, +105 -119 lines

benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py

Lines changed: 33 additions & 40 deletions
@@ -55,7 +55,7 @@ def suffix():
     key=['M', 'N', 'K'],
 )
 @triton.jit
-def matmul_kernel_with_block_pointers(
+def matmul_kernel_with_tensor_descriptors(
         # Pointers to matrices
         a_ptr, b_ptr, c_ptr, d_ptr,
         # Matrix dimensions
@@ -78,31 +78,27 @@ def matmul_kernel_with_block_pointers(
     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
     pid_n = (pid % num_pid_in_group) // group_size_m

-    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
-                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
-                                    order=(1, 0))
-    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
-                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
-                                    order=(1, 0))
+    a_desc = tl.make_tensor_descriptor(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
+                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K))
+    b_desc = tl.make_tensor_descriptor(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
+                                       block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N))

     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
+    off_k = 0
     for _ in range(0, K, BLOCK_SIZE_K):
-        a = tl.load(a_block_ptr, boundary_check=(0, 1))
-        b = tl.load(b_block_ptr, boundary_check=(0, 1))
+        a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
+        b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
         accumulator += tl.dot(a, b)
-        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
-        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
+        off_k += BLOCK_SIZE_K

-    d_block_ptr = tl.make_block_ptr(base=d_ptr, shape=(M, N), strides=(stride_dm, stride_dn),
-                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
-                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
-    d = tl.load(d_block_ptr, boundary_check=(0, 1))
+    d_desc = tl.make_tensor_descriptor(base=d_ptr, shape=(M, N), strides=(stride_dm, stride_dn),
+                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N))
+    d = d_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N])
     c = accumulator + d

-    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
-                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
-                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
-    tl.store(c_block_ptr, c, boundary_check=(0, 1))
+    c_desc = tl.make_tensor_descriptor(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
+                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N))
+    c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c)


 # pylint: disable=unused-argument
@@ -130,7 +126,7 @@ def matmul_kernel_with_block_pointers(
     key=['M', 'N', 'K'],
 )
 @triton.jit
-def matmul_kernel_with_block_pointers_batched(
+def matmul_kernel_with_tensor_descriptors_batched(
         # Pointers to matrices
         a_ptr, b_ptr, c_ptr, d_ptr,
         # Matrix dimensions
@@ -157,33 +153,30 @@ def matmul_kernel_with_block_pointers_batched(
     offset_a = bid.to(tl.int64) * stride_az
     offset_b = bid.to(tl.int64) * stride_bz

-    a_block_ptr = tl.make_block_ptr(base=a_ptr + offset_a, shape=(M, K), strides=(stride_am, stride_ak),
-                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
-                                    order=(1, 0))
-    b_block_ptr = tl.make_block_ptr(base=b_ptr + offset_b, shape=(K, N), strides=(stride_bk, stride_bn),
-                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
-                                    order=(1, 0))
+    a_desc = tl.make_tensor_descriptor(base=a_ptr + offset_a, shape=(M, K), strides=(stride_am, stride_ak),
+                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K))
+    b_desc = tl.make_tensor_descriptor(base=b_ptr + offset_b, shape=(K, N), strides=(stride_bk, stride_bn),
+                                       block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N))

     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
+    off_k = 0
     for _ in range(0, K, BLOCK_SIZE_K):
-        a = tl.load(a_block_ptr, boundary_check=(0, 1))
-        b = tl.load(b_block_ptr, boundary_check=(0, 1))
+        a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
+        b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
         accumulator += tl.dot(a, b)
-        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
-        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
+        off_k += BLOCK_SIZE_K

     offset_d = bid.to(tl.int64) * stride_dz
-    d_block_ptr = tl.make_block_ptr(base=d_ptr + offset_d, shape=(M, N), strides=(stride_dm, stride_dn),
-                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
-                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
-    d = tl.load(d_block_ptr, boundary_check=(0, 1))
+    d_desc = tl.make_tensor_descriptor(base=d_ptr + offset_d, shape=(M, N), strides=(stride_dm, stride_dn),
+                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N))
+    d = d_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N])
     c = accumulator + d

     offset_c = bid.to(tl.int64) * stride_cz
-    c_block_ptr = tl.make_block_ptr(base=c_ptr + offset_c, shape=(M, N), strides=(stride_cm, stride_cn),
-                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
-                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
-    tl.store(c_block_ptr, c, boundary_check=(0, 1))
+    c_desc = tl.make_tensor_descriptor(base=c_ptr + offset_c, shape=(M, N), strides=(stride_cm, stride_cn),
+                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N))
+
+    c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c)


 # We can now create a convenience wrapper function that only takes two input tensors,
@@ -202,7 +195,7 @@ def matmul(a, b, d, c):
             triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
             B,
         )
-        matmul_kernel_with_block_pointers_batched[grid](
+        matmul_kernel_with_tensor_descriptors_batched[grid](
             a, b, c, d,  #
             B, M, N, K,  #
             a.stride(0), a.stride(1), a.stride(2),  #
@@ -217,7 +210,7 @@ def matmul(a, b, d, c):
         M, K = a.shape
         K, N = b.shape
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
-        matmul_kernel_with_block_pointers[grid](
+        matmul_kernel_with_tensor_descriptors[grid](
            a, b, c, d,  #
            M, N, K,  #
            a.stride(0), a.stride(1),  #
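The host-side wrappers keep their signatures; only the kernels they dispatch to are renamed, so existing callers are unaffected. A hypothetical smoke test of this file's `matmul(a, b, d, c)` wrapper; the import path, `xpu` device, dtypes, shapes, and tolerances are all illustrative assumptions, not taken from the benchmark's own configs:

```python
import torch

from triton_kernels_benchmark.gemm_postop_addmatrix_benchmark import matmul  # assumed module path

M, K, N = 1024, 512, 256  # illustrative sizes
a = torch.randn(M, K, device='xpu', dtype=torch.bfloat16)
b = torch.randn(K, N, device='xpu', dtype=torch.bfloat16)
d = torch.randn(M, N, device='xpu', dtype=torch.bfloat16)  # the post-op addend
c = torch.empty(M, N, device='xpu', dtype=torch.float32)   # output buffer

matmul(a, b, d, c)  # c = a @ b + d via matmul_kernel_with_tensor_descriptors
torch.testing.assert_close(c, torch.matmul(a, b).float() + d.float(), atol=1e-2, rtol=1e-2)
```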

benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py

Lines changed: 27 additions & 32 deletions
@@ -53,7 +53,7 @@ def gelu(x):
     key=['M', 'N', 'K'],
 )
 @triton.jit
-def matmul_kernel_with_block_pointers(
+def matmul_kernel_with_tensor_descriptors(
         # Pointers to matrices
         a_ptr, b_ptr, c_ptr,
         # Matrix dimensions
@@ -74,26 +74,23 @@ def matmul_kernel_with_block_pointers(
     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
     pid_n = (pid % num_pid_in_group) // group_size_m

-    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
-                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
-                                    order=(1, 0))
-    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
-                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
-                                    order=(1, 0))
+    a_desc = tl.make_tensor_descriptor(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
+                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K))
+    b_desc = tl.make_tensor_descriptor(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
+                                       block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N))

     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    off_k = 0
     for _ in range(0, K, BLOCK_SIZE_K):
-        a = tl.load(a_block_ptr, boundary_check=(0, 1))
-        b = tl.load(b_block_ptr, boundary_check=(0, 1))
+        a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
+        b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
         accumulator += tl.dot(a, b)
-        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
-        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
+        off_k += BLOCK_SIZE_K
     c = gelu(accumulator)

-    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
-                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
-                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
-    tl.store(c_block_ptr, c, boundary_check=(0, 1))
+    c_desc = tl.make_tensor_descriptor(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
+                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N))
+    c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c)


 # pylint: disable=unused-argument
@@ -121,7 +118,7 @@ def matmul_kernel_with_block_pointers(
     key=['M', 'N', 'K'],
 )
 @triton.jit
-def matmul_kernel_with_block_pointers_batched(
+def matmul_kernel_with_tensor_descriptors_batched(
         # Pointers to matrices
         a_ptr, b_ptr, c_ptr,
         # Matrix dimensions
@@ -146,27 +143,25 @@ def matmul_kernel_with_block_pointers_batched(
     offset_a = bid.to(tl.int64) * stride_az
     offset_b = bid.to(tl.int64) * stride_bz

-    a_block_ptr = tl.make_block_ptr(base=a_ptr + offset_a, shape=(M, K), strides=(stride_am, stride_ak),
-                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
-                                    order=(1, 0))
-    b_block_ptr = tl.make_block_ptr(base=b_ptr + offset_b, shape=(K, N), strides=(stride_bk, stride_bn),
-                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
-                                    order=(1, 0))
+    a_desc = tl.make_tensor_descriptor(base=a_ptr + offset_a, shape=(M, K), strides=(stride_am, stride_ak),
+                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K))
+    b_desc = tl.make_tensor_descriptor(base=b_ptr + offset_b, shape=(K, N), strides=(stride_bk, stride_bn),
+                                       block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N))

     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    off_k = 0
     for _ in range(0, K, BLOCK_SIZE_K):
-        a = tl.load(a_block_ptr, boundary_check=(0, 1))
-        b = tl.load(b_block_ptr, boundary_check=(0, 1))
+        a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
+        b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
         accumulator += tl.dot(a, b)
-        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
-        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
+        off_k += BLOCK_SIZE_K
     c = gelu(accumulator)

     offset_c = bid.to(tl.int64) * stride_cz
-    c_block_ptr = tl.make_block_ptr(base=c_ptr + offset_c, shape=(M, N), strides=(stride_cm, stride_cn),
-                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
-                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
-    tl.store(c_block_ptr, c, boundary_check=(0, 1))
+    c_desc = tl.make_tensor_descriptor(base=c_ptr + offset_c, shape=(M, N), strides=(stride_cm, stride_cn),
+                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N))
+
+    c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c)


 # We can now create a convenience wrapper function that only takes two input tensors,
@@ -185,7 +180,7 @@ def matmul(a, b, c):
             triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
             B,
         )
-        matmul_kernel_with_block_pointers_batched[grid](
+        matmul_kernel_with_tensor_descriptors_batched[grid](
             a, b, c,  #
             B, M, N, K,  #
             a.stride(0), a.stride(1), a.stride(2),  #
@@ -198,7 +193,7 @@ def matmul(a, b, c):
         M, K = a.shape
         K, N = b.shape
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
-        matmul_kernel_with_block_pointers[grid](
+        matmul_kernel_with_tensor_descriptors[grid](
            a, b, c,  #
            M, N, K,  #
            a.stride(0), a.stride(1),  #
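`gelu(x)` itself is untouched and sits outside the changed hunks, so its body does not appear in this diff. For orientation, the tanh-approximation GELU commonly used in Triton GEMM benchmarks looks like the following sketch (the usual formulation, not necessarily this file's verbatim code):

```python
import triton
import triton.language as tl


@triton.jit
def tanh(x):
    # tanh(x) = 2 * sigmoid(2x) - 1
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def gelu(x):
    # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1 + tanh(0.7978845608028654 * (x + 0.044715 * x * x * x)))
```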

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 19 additions & 14 deletions
@@ -1,3 +1,9 @@
+"""
+Split-K GEMM with Tensor Descriptors
+====================================
+Split-K is an approach that parallelizes the reduction dimension K to improve GPU utilization.
+This script implements a Split-K GEMM with tensor descriptors.
+"""
 import torch
 import triton
 import triton.language as tl
@@ -34,27 +40,26 @@ def _kernel(A, B, C, #
     pid_m = group_id * GROUP_M + (pid % group_size)
     pid_n = (pid % width) // (group_size)

-    a_block_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak),
-                                    offsets=(pid_m * BLOCK_M, pid_z * BLOCK_K), block_shape=(BLOCK_M, BLOCK_K),
-                                    order=(1, 0))
-    b_block_ptr = tl.make_block_ptr(base=B, shape=(K, N), strides=(stride_bk, stride_bn),
-                                    offsets=(pid_z * BLOCK_K, pid_n * BLOCK_N), block_shape=(BLOCK_K, BLOCK_N),
-                                    order=(1, 0))
+    # Create tensor descriptors
+    a_desc = tl.make_tensor_descriptor(base=A, shape=(M, K), strides=(stride_am, stride_ak),
+                                       block_shape=(BLOCK_M, BLOCK_K))
+    b_desc = tl.make_tensor_descriptor(base=B, shape=(K, N), strides=(stride_bk, stride_bn),
+                                       block_shape=(BLOCK_K, BLOCK_N))

     acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
+    off_k = pid_z * BLOCK_K
     for _ in range(0, K, BLOCK_K * SPLIT_K):
-        a = tl.load(a_block_ptr)
-        b = tl.load(b_block_ptr)
+        a = a_desc.load([pid_m * BLOCK_M, off_k])
+        b = b_desc.load([off_k, pid_n * BLOCK_N])
         acc += tl.dot(a, b, out_dtype=acc_dtype)
-        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K * SPLIT_K))
-        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K * SPLIT_K, 0))
+        off_k += BLOCK_K * SPLIT_K
     acc = acc.to(C.dtype.element_ty)
+
     # handles write-back with reduction-splitting
     if SPLIT_K == 1:
-        c_block_ptr = tl.make_block_ptr(base=C, shape=(M, N), strides=(stride_cm, stride_cn),
-                                        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N), block_shape=(BLOCK_M, BLOCK_N),
-                                        order=(1, 0))
-        tl.store(c_block_ptr, acc, boundary_check=(0, 1))
+        c_desc = tl.make_tensor_descriptor(base=C, shape=(M, N), strides=(stride_cm, stride_cn),
+                                           block_shape=(BLOCK_M, BLOCK_N))
+        c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc)
     else:
         # rematerialize rm and rn to save registers
         rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
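The rewritten loop preserves the split-K schedule exactly: program `pid_z` starts at `off_k = pid_z * BLOCK_K` and steps by `BLOCK_K * SPLIT_K`, so with, say, K = 4096, BLOCK_K = 256, and SPLIT_K = 4, each of the four K-slices visits 4 of the 16 K-blocks and produces a partial result tile. The diff truncates the `SPLIT_K > 1` write-back; in the conventional split-K idiom the partial tiles are merged with atomic adds, roughly as below (a sketch of the usual pattern, not this file's verbatim code):

```python
        # else branch (sketch): each pid_z atomically accumulates its partial
        # tile into C, which the host must zero-initialize beforehand.
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        c_ptrs = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        tl.atomic_add(c_ptrs, acc, mask=mask)
```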
