-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathblackwell_gemm_clc.py
More file actions
306 lines (260 loc) · 13.5 KB
/
blackwell_gemm_clc.py
File metadata and controls
306 lines (260 loc) · 13.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
# TLX GEMM kernel optimized for Blackwell Warp Specialization
import torch
import triton
import triton.language as tl
import triton.language.extra.tlx as tlx
from triton.tools.tensor_descriptor import TensorDescriptor
# Active torch device as seen by the Triton runtime driver. Not referenced in
# this chunk — presumably consumed by benchmark/test code elsewhere in the file.
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def get_cuda_autotune_config():
    """Enumerate the autotune search space for the Blackwell CLC GEMM kernel.

    Produces one ``triton.Config`` per combination of tile sizes and
    SMEM/TMEM pipeline depths; each config carries the pre-hook that pushes
    the chosen tile shapes into the TMA tensor descriptors before launch.
    """
    configs = []
    for block_m in [128]:
        for block_n in [128, 256]:
            for block_k in [64, 128]:
                for smem_bufs in [2, 3, 4]:
                    for tmem_bufs in [2, 3]:
                        for subtile in [True]:
                            meta = {
                                "BLOCK_SIZE_M": block_m,
                                "BLOCK_SIZE_N": block_n,
                                "BLOCK_SIZE_K": block_k,
                                "GROUP_SIZE_M": 8,
                                "NUM_SMEM_BUFFERS": smem_bufs,
                                "NUM_TMEM_BUFFERS": tmem_bufs,
                                "EPILOGUE_SUBTILE": subtile,
                            }
                            configs.append(
                                triton.Config(
                                    meta,
                                    num_warps=4,
                                    num_stages=1,
                                    pre_hook=matmul_tma_set_block_size_hook,
                                ))
    return configs
def matmul_tma_set_block_size_hook(nargs):
    """Autotune pre-hook: propagate the selected tile sizes into the TMA
    tensor descriptors (`a_desc`, `b_desc`, `c_desc`) before kernel launch.
    """
    m = nargs["BLOCK_SIZE_M"]
    n = nargs["BLOCK_SIZE_N"]
    k = nargs["BLOCK_SIZE_K"]
    nargs["a_desc"].block_shape = [m, k]
    nargs["b_desc"].block_shape = [k, n]
    # With the subtiled epilogue, C is stored in two half-width pieces, so its
    # descriptor only covers half of the N extent.
    half_n = nargs.get("EPILOGUE_SUBTILE", False)
    nargs["c_desc"].block_shape = [m, n // 2] if half_n else [m, n]
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M):
    """Map a linear tile id to a (pid_m, pid_n) pair using grouped ordering
    along M, which improves L2 reuse by keeping GROUP_SIZE_M consecutive
    row-blocks together. The last group may be shorter than GROUP_SIZE_M.
    """
    grp = tile_id // num_pid_in_group
    row_base = grp * GROUP_SIZE_M
    rows_here = min(num_pid_m - row_base, GROUP_SIZE_M)
    pid_m = row_base + (tile_id % rows_here)
    pid_n = (tile_id % num_pid_in_group) // rows_here
    return pid_m, pid_n
@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel_tma_ws_blackwell_clc(a_desc, b_desc, c_desc, M, N, K, BLOCK_SIZE_M: tl.constexpr,
                                       BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
                                       GROUP_SIZE_M: tl.constexpr,  #
                                       NUM_SMEM_BUFFERS: tl.constexpr,  #
                                       NUM_TMEM_BUFFERS: tl.constexpr,  #
                                       NUM_SMS: tl.constexpr,  #
                                       NUM_CLC_STAGES: tl.constexpr,  #
                                       EPILOGUE_SUBTILE: tl.constexpr,  #
                                       USE_WARP_BARRIER: tl.constexpr = False,  #
                                       ):
    """Warp-specialized GEMM (C = A @ B) for Blackwell with CLC work stealing.

    Three cooperating async regions form a per-CTA pipeline:
      * default region  -- epilogue: waits for finished accumulators in TMEM,
        stores them to C, and also acts as the CLC *producer* that requests
        the next tile for the whole CTA;
      * 1-warp region   -- MMA consumer: issues async dots from the SMEM
        operand buffers into the TMEM accumulators;
      * 1-warp region   -- TMA producer: async-loads A/B tiles into SMEM.

    The SMEM operand ring is NUM_SMEM_BUFFERS deep (load <-> MMA) and the
    TMEM accumulator ring is NUM_TMEM_BUFFERS deep (MMA <-> epilogue); each
    ring is guarded by paired full/empty barriers whose phase bit flips every
    time the ring wraps. Every region keeps pulling tile ids from the CLC
    context until it receives -1 (no more work).
    """
    # allocate NUM_SMEM_BUFFERS buffers for the A and B operand tiles
    buffers_A = tlx.local_alloc((BLOCK_SIZE_M, BLOCK_SIZE_K), tlx.dtype_of(a_desc), NUM_SMEM_BUFFERS)
    buffers_B = tlx.local_alloc((BLOCK_SIZE_K, BLOCK_SIZE_N), tlx.dtype_of(b_desc), NUM_SMEM_BUFFERS)
    # use multiple TMEM buffers to overlap MMA and epilogue
    tmem_buffers = tlx.local_alloc((BLOCK_SIZE_M, BLOCK_SIZE_N), tl.float32, NUM_TMEM_BUFFERS, tlx.storage_kind.tmem)
    # allocate barriers: "empty" = consumer is done with a buffer, "full" = producer filled it
    smem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1)
    smem_full_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1)
    if USE_WARP_BARRIER:
        # NOTE(review): warp-scoped variant — num_warps=1 matches the single MMA
        # warp signalling "full", num_warps=4 matches the 4-warp default region
        # signalling "empty"; confirm against tlx.alloc_warp_barrier semantics.
        tmem_full_bars = tlx.alloc_warp_barrier(num_barriers=NUM_TMEM_BUFFERS, num_warps=1)
        tmem_empty_bars = tlx.alloc_warp_barrier(num_barriers=NUM_TMEM_BUFFERS, num_warps=4)
    else:
        tmem_full_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1)
        tmem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1)
    # one CLC context shared by the three regions that each consume tile ids
    clc_context = tlx.clc_create_context(num_consumers=3)
    with tlx.async_tasks():
        with tlx.async_task("default"):  # epilogue consumer
            # common code duplicated for each region to avoid SMEM overhead
            start_pid = tl.program_id(axis=0)
            num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
            num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
            num_pid_in_group = GROUP_SIZE_M * num_pid_n
            k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
            # end of common code
            tmem_read_phase = 0
            cur_tmem_buf = 0
            tile_id = start_pid  # first tile is this CTA's own program id
            clc_phase_producer = 1
            clc_phase_consumer = 0
            while tile_id != -1:
                # Debug prints
                # if tlx.thread_id(axis=0) == 0:
                #     tl.device_print("Default WG Processing CtaID", tile_id)
                # producer: kick off the CLC request for the *next* tile before
                # doing this tile's stores, so the steal overlaps the epilogue
                tlx.clc_producer(clc_context, clc_phase_producer)
                clc_phase_producer ^= 1
                pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M)
                offs_am = pid_m * BLOCK_SIZE_M
                offs_bn = pid_n * BLOCK_SIZE_N
                # wait for the MMA consumer to finish accumulating this buffer
                tlx.barrier_wait(tmem_full_bars[cur_tmem_buf], tmem_read_phase)
                # flip phase at the end of a round of using TMEM barriers
                tmem_read_phase = tmem_read_phase ^ (cur_tmem_buf == NUM_TMEM_BUFFERS - 1)
                # load the result from TMEM to registers
                acc_tmem = tmem_buffers[cur_tmem_buf]
                if EPILOGUE_SUBTILE:
                    # We load/store the result half by half to reduce SMEM pressure
                    acc_tmem_subslice1 = tlx.subslice(acc_tmem, 0, BLOCK_SIZE_N // 2)
                    result = tlx.local_load(acc_tmem_subslice1)
                    c = result.to(tlx.dtype_of(c_desc))
                    c_desc.store([offs_am, offs_bn], c)
                    acc_tmem_subslice2 = tlx.subslice(acc_tmem, BLOCK_SIZE_N // 2, BLOCK_SIZE_N // 2)
                    result = tlx.local_load(acc_tmem_subslice2)
                    c = result.to(tlx.dtype_of(c_desc))
                    c_desc.store([offs_am, offs_bn + BLOCK_SIZE_N // 2], c)
                else:
                    result = tlx.local_load(acc_tmem)
                    c = result.to(tlx.dtype_of(c_desc))
                    c_desc.store([offs_am, offs_bn], c)
                # done storing this buffer, signal MMA consumer to resume writing to it
                tlx.barrier_arrive(tmem_empty_bars[cur_tmem_buf], 1)
                cur_tmem_buf = (cur_tmem_buf + 1) % NUM_TMEM_BUFFERS
                tile_id = tlx.clc_consumer(clc_context, clc_phase_consumer)
                clc_phase_consumer ^= 1
                # Debug-only: verifying that CLC steals workloads successfully
                # if tlx.thread_id(axis=0) == 0:
                #     tl.device_print("Extracted CtaID", tile_id)
        with tlx.async_task(num_warps=1, num_regs=232):  # MMA consumer
            # common code duplicated for each region to avoid SMEM overhead
            start_pid = tl.program_id(axis=0)
            num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
            num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
            num_pid_in_group = GROUP_SIZE_M * num_pid_n
            k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
            # end of common code
            dot_phase = 0  # the current phase of dot op
            # sync between epilogue consumer and MMA consumer; starts at 1 so the
            # first wait on each untouched "empty" barrier passes immediately —
            # TODO confirm against tlx barrier phase-initialization semantics
            tmem_write_phase = 1
            cur_tmem_buf = 0
            processed_k_iters = 0
            tile_id = start_pid
            clc_phase_consumer = 0
            while tile_id != -1:
                pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M)
                offs_am = pid_m * BLOCK_SIZE_M
                offs_bn = pid_n * BLOCK_SIZE_N
                # wait epilogue consumer to be done with the buffer before reusing it
                tlx.barrier_wait(tmem_empty_bars[cur_tmem_buf], tmem_write_phase)
                # flip phase at the end of a round of using TMEM barriers
                tmem_write_phase = tmem_write_phase ^ (cur_tmem_buf == NUM_TMEM_BUFFERS - 1)
                # now iterate along K to compute result for the block
                for k in range(0, k_tiles):
                    # processed_k_iters + k means we use the immediate next buffer slot of tile_id x when we start tile_id x+1
                    buf = (processed_k_iters + k) % NUM_SMEM_BUFFERS
                    # wait for current phase(round) of load for this buf
                    tlx.barrier_wait(smem_full_bars[buf], dot_phase)
                    # buffer is now ready with loaded data, tlx.async_dot will signal `mBarrier` when done
                    # use_acc=k > 0: first K-step overwrites the accumulator, later steps add into it
                    tlx.async_dot(buffers_A[buf], buffers_B[buf], tmem_buffers[cur_tmem_buf], use_acc=k > 0,
                                  mBarriers=[smem_empty_bars[buf]], out_dtype=tl.float32)
                    # flip phase at the end of a round
                    dot_phase = dot_phase ^ (buf == NUM_SMEM_BUFFERS - 1)
                # wait for last mma to complete
                last_buf = (processed_k_iters + k_tiles - 1) % NUM_SMEM_BUFFERS
                # in case phase was flipped, we should use the phase value when dot op was issued
                last_dot_phase = dot_phase ^ (last_buf == NUM_SMEM_BUFFERS - 1)
                tlx.barrier_wait(smem_empty_bars[last_buf], last_dot_phase)
                # done filling this buffer, signal epilogue consumer
                tlx.barrier_arrive(tmem_full_bars[cur_tmem_buf], 1)
                # possibly enter next iteration (next tile) without waiting for epilogue
                cur_tmem_buf = (cur_tmem_buf + 1) % NUM_TMEM_BUFFERS
                processed_k_iters += k_tiles
                tile_id = tlx.clc_consumer(clc_context, clc_phase_consumer)
                clc_phase_consumer ^= 1
        with tlx.async_task(num_warps=1, num_regs=232):  # producer, TMA load
            # common code duplicated for each region to avoid SMEM overhead
            start_pid = tl.program_id(axis=0)
            num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
            num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
            num_pid_in_group = GROUP_SIZE_M * num_pid_n
            k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
            # end of common code
            load_phase = 0  # the current phase of TMA load
            # we virtually "flatten" the two layer loop as if we're performing tma loads on
            # one big list of data
            processed_k_iters = 0
            tile_id = start_pid
            clc_phase_consumer = 0
            while tile_id != -1:
                pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M)
                offs_am = pid_m * BLOCK_SIZE_M
                offs_bn = pid_n * BLOCK_SIZE_N
                for k in range(0, k_tiles):
                    # processed_k_iters + k means we use the immediate next buffer slot of tile_id x when we start tile_id x+1
                    buf = (processed_k_iters + k) % NUM_SMEM_BUFFERS
                    # wait for previous phase(round) of dot for this buf
                    tlx.barrier_wait(smem_empty_bars[buf], load_phase ^ 1)
                    # buffer is now ready to be used again
                    offs_k = k * BLOCK_SIZE_K
                    # expected byte count assumes 2-byte elements (A tile + B tile)
                    tlx.barrier_expect_bytes(smem_full_bars[buf],
                                             2 * (BLOCK_SIZE_M + BLOCK_SIZE_N) * BLOCK_SIZE_K)  # float16
                    tlx.async_descriptor_load(a_desc, buffers_A[buf], [offs_am, offs_k], smem_full_bars[buf])
                    tlx.async_descriptor_load(b_desc, buffers_B[buf], [offs_k, offs_bn], smem_full_bars[buf])
                    # flip phase at the end of a round
                    load_phase = load_phase ^ (buf == NUM_SMEM_BUFFERS - 1)
                processed_k_iters += k_tiles
                tile_id = tlx.clc_consumer(clc_context, clc_phase_consumer)
                clc_phase_consumer ^= 1
def matmul(a, b, config=None, use_warp_barrier=False):
    """Matrix multiplication C = A @ B using the TLX Blackwell CLC GEMM kernel.

    Args:
        a: left operand of shape (M, K); must be contiguous and use a 2-byte
            dtype (e.g. fp16/bf16) because the kernel's TMA byte accounting
            assumes 2 bytes per element.
        b: right operand of shape (K, N); must have the same dtype as `a`.
        config: optional dict of kernel meta-parameters (BLOCK_SIZE_M/N/K,
            GROUP_SIZE_M, NUM_SMEM_BUFFERS, NUM_TMEM_BUFFERS,
            EPILOGUE_SUBTILE). When provided, autotuning is bypassed and the
            kernel is launched directly with these values.
        use_warp_barrier: forwarded to the kernel's USE_WARP_BARRIER switch
            (warp-scoped TMEM barriers instead of plain mbarriers).

    Returns:
        The (M, N) product tensor, same dtype and device as `a`.
    """
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert a.dtype == b.dtype, "Input dtypes must match"
    # The kernel's barrier_expect_bytes hard-codes 2 bytes per element; any
    # other element size would silently desynchronize the TMA barriers.
    assert a.element_size() == 2, "Only 2-byte dtypes (e.g. fp16/bf16) are supported"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # A dummy block value that will be overwritten when we have the real block size
    dummy_block = [1, 1]
    a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
    b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
    c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    if config is not None:
        # Manual launch: install the real block shapes ourselves, mirroring
        # what matmul_tma_set_block_size_hook does for the autotuned path.
        a_desc.block_shape = [config["BLOCK_SIZE_M"], config["BLOCK_SIZE_K"]]
        b_desc.block_shape = [config["BLOCK_SIZE_K"], config["BLOCK_SIZE_N"]]
        if config.get("EPILOGUE_SUBTILE", False):
            c_desc.block_shape = [config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"] // 2]
        else:
            c_desc.block_shape = [config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"]]
        grid = (triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(N, config["BLOCK_SIZE_N"]), )
        # `.fn` bypasses the autotuner so the caller-supplied config is used as-is.
        matmul_kernel_tma_ws_blackwell_clc.fn[grid](
            a_desc,
            b_desc,
            c_desc,
            M,
            N,
            K,
            NUM_SMS=NUM_SMS,
            NUM_CLC_STAGES=1,
            USE_WARP_BARRIER=use_warp_barrier,
            **config,
        )
    else:
        # Autotuned launch: the grid depends on the block sizes the tuner picks.
        grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
        matmul_kernel_tma_ws_blackwell_clc[grid](
            a_desc,
            b_desc,
            c_desc,
            M,
            N,
            K,
            NUM_SMS=NUM_SMS,
            NUM_CLC_STAGES=1,
            USE_WARP_BARRIER=use_warp_barrier,
        )
    return c
def matmul_warp_barrier(a, b, config=None):
    """Convenience wrapper: run `matmul` with warp-level TMEM barriers enabled."""
    result = matmul(a, b, config=config, use_warp_barrier=True)
    return result