Commit 67f901d

Use tensor descriptor in gemm_preop_exp_benchmark (#5175)

Fixes #4321. The benchmark's GEMM kernels are ported from tl.make_block_ptr / tl.advance block pointers to tl.make_tensor_descriptor, whose loads and stores take absolute tile offsets instead of advancing a pointer each iteration.

Co-authored-by: He Dan <[email protected]>

1 parent 922ff0f · commit 67f901d

1 file changed: +26 −32 lines

benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py

@@ -35,7 +35,7 @@
     key=['M', 'N', 'K'],
 )
 @triton.jit
-def matmul_kernel_with_block_pointers(
+def matmul_kernel_with_tensor_descriptors(
     # Pointers to matrices
     a_ptr, b_ptr, c_ptr,
     # Matrix dimensions
@@ -56,29 +56,26 @@ def matmul_kernel_with_block_pointers(
     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
     pid_n = (pid % num_pid_in_group) // group_size_m
 
-    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
-                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
-                                    order=(1, 0))
-    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
-                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
-                                    order=(1, 0))
+    a_desc = tl.make_tensor_descriptor(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
+                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K))
+    b_desc = tl.make_tensor_descriptor(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
+                                       block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N))
 
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    off_k = 0
     for _ in range(0, K, BLOCK_SIZE_K):
-        a = tl.load(a_block_ptr, boundary_check=(0, 1))
+        a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
         a = a.to(tl.float32)
         a = tl.math.exp(a)
         a = a.to(tl.bfloat16)
-        b = tl.load(b_block_ptr, boundary_check=(0, 1))
+        b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
         accumulator += tl.dot(a, b)
-        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
-        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
+        off_k += BLOCK_SIZE_K
     c = accumulator.to(tl.float32)
 
-    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
-                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
-                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
-    tl.store(c_block_ptr, c, boundary_check=(0, 1))
+    c_desc = tl.make_tensor_descriptor(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
+                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N))
+    c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c)
 
 
 # pylint: disable=unused-argument
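
The new-side pattern above is easy to lift out of the diff for reuse. Below is a minimal, self-contained sketch of the same block-pointer-to-tensor-descriptor migration — a hypothetical tiled-copy kernel, not code from this benchmark — assuming a Triton build where tl.make_tensor_descriptor can be called inside a kernel, as this file does (on some backends, e.g. NVIDIA TMA, the host must also install a workspace allocator via triton.set_allocator first):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_tile_kernel(src_ptr, dst_ptr, M, N, stride_m, stride_n,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Build each descriptor once; loads and stores then take absolute element
    # offsets, replacing the tl.make_block_ptr / tl.advance bookkeeping.
    src = tl.make_tensor_descriptor(base=src_ptr, shape=(M, N),
                                    strides=(stride_m, stride_n),
                                    block_shape=(BLOCK_M, BLOCK_N))
    dst = tl.make_tensor_descriptor(base=dst_ptr, shape=(M, N),
                                    strides=(stride_m, stride_n),
                                    block_shape=(BLOCK_M, BLOCK_N))
    # Out-of-bounds elements are masked by the descriptor itself, which is
    # why the diff drops the explicit boundary_check=(0, 1) arguments.
    tile = src.load([pid_m * BLOCK_M, pid_n * BLOCK_N])
    dst.store([pid_m * BLOCK_M, pid_n * BLOCK_N], tile)


def copy(src: torch.Tensor) -> torch.Tensor:
    # Assumes a contiguous 2D tensor: tensor descriptors require the
    # innermost stride to be 1.
    dst = torch.empty_like(src)
    M, N = src.shape
    grid = (triton.cdiv(M, 64), triton.cdiv(N, 64))
    copy_tile_kernel[grid](src, dst, M, N, src.stride(0), src.stride(1),
                           BLOCK_M=64, BLOCK_N=64)
    return dst
```

One design consequence visible in the hunk above: because descriptor loads take absolute offsets, the K loop now tracks a single scalar off_k instead of mutating two block pointers on every iteration.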
@@ -106,7 +103,7 @@ def matmul_kernel_with_block_pointers(
     key=['M', 'N', 'K'],
 )
 @triton.jit
-def matmul_kernel_with_block_pointers_batched(
+def matmul_kernel_with_tensor_descriptors_batched(
     # Pointers to matrices
     a_ptr, b_ptr, c_ptr,
     # Matrix dimensions
@@ -131,30 +128,27 @@ def matmul_kernel_with_block_pointers_batched(
     offset_a = bid.to(tl.int64) * stride_az
     offset_b = bid.to(tl.int64) * stride_bz
 
-    a_block_ptr = tl.make_block_ptr(base=a_ptr + offset_a, shape=(M, K), strides=(stride_am, stride_ak),
-                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
-                                    order=(1, 0))
-    b_block_ptr = tl.make_block_ptr(base=b_ptr + offset_b, shape=(K, N), strides=(stride_bk, stride_bn),
-                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
-                                    order=(1, 0))
+    a_desc = tl.make_tensor_descriptor(base=a_ptr + offset_a, shape=(M, K), strides=(stride_am, stride_ak),
+                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K))
+    b_desc = tl.make_tensor_descriptor(base=b_ptr + offset_b, shape=(K, N), strides=(stride_bk, stride_bn),
+                                       block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N))
 
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    off_k = 0
     for _ in range(0, K, BLOCK_SIZE_K):
-        a = tl.load(a_block_ptr, boundary_check=(0, 1))
+        a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
         a = a.to(tl.float32)
         a = tl.math.exp(a)
         a = a.to(tl.bfloat16)
-        b = tl.load(b_block_ptr, boundary_check=(0, 1))
+        b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
         accumulator += tl.dot(a, b)
-        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
-        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
+        off_k += BLOCK_SIZE_K
     c = accumulator.to(tl.float32)
 
     offset_c = bid.to(tl.int64) * stride_cz
-    c_block_ptr = tl.make_block_ptr(base=c_ptr + offset_c, shape=(M, N), strides=(stride_cm, stride_cn),
-                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
-                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
-    tl.store(c_block_ptr, c, boundary_check=(0, 1))
+    c_desc = tl.make_tensor_descriptor(base=c_ptr + offset_c, shape=(M, N), strides=(stride_cm, stride_cn),
+                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N))
+    c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c)
 
 
 # We can now create a convenience wrapper function that only takes two input tensors,
@@ -173,7 +167,7 @@ def matmul(a, b, c):
             triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
             B,
         )
-        matmul_kernel_with_block_pointers_batched[grid](
+        matmul_kernel_with_tensor_descriptors_batched[grid](
             a, b, c,  #
             B, M, N, K,  #
             a.stride(0), a.stride(1), a.stride(2),  #
@@ -186,7 +180,7 @@ def matmul(a, b, c):
         M, K = a.shape
         K, N = b.shape
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
-        matmul_kernel_with_block_pointers[grid](
+        matmul_kernel_with_tensor_descriptors[grid](
            a, b, c,  #
            M, N, K,  #
            a.stride(0), a.stride(1),  #
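
For context, here is a hypothetical invocation of the updated wrapper; the shapes, dtypes, and device string below are illustrative assumptions, not taken from the benchmark harness:

```python
import torch

# Assumed setup: bfloat16 inputs (the kernel upcasts A to float32, applies
# exp, and downcasts to bfloat16 before the dot) and a float32 output buffer.
a = torch.randn(512, 256, device='xpu', dtype=torch.bfloat16)
b = torch.randn(256, 384, device='xpu', dtype=torch.bfloat16)
c = torch.empty(512, 384, device='xpu', dtype=torch.float32)
matmul(a, b, c)  # 2D inputs dispatch matmul_kernel_with_tensor_descriptors

# 3D inputs take the batched path instead:
ab = torch.randn(4, 512, 256, device='xpu', dtype=torch.bfloat16)
bb = torch.randn(4, 256, 384, device='xpu', dtype=torch.bfloat16)
cb = torch.empty(4, 512, 384, device='xpu', dtype=torch.float32)
matmul(ab, bb, cb)  # dispatches matmul_kernel_with_tensor_descriptors_batched
```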
