|
| 1 | +# TLX GEMM kernel optimized for Blackwell Warp Specialization |
| 2 | +import torch |
| 3 | + |
| 4 | +import triton |
| 5 | +import triton.language as tl |
| 6 | +import triton.language.extra.tlx as tlx |
| 7 | +from triton.tools.tensor_descriptor import TensorDescriptor |
| 8 | + |
| 9 | + |
| 10 | +def get_cuda_autotune_config(): |
| 11 | + return [ |
| 12 | + triton.Config( |
| 13 | + { |
| 14 | + "BLOCK_SIZE_M": BM, |
| 15 | + "BLOCK_SIZE_N": BN, |
| 16 | + "BLOCK_SIZE_K": BK, |
| 17 | + "GROUP_SIZE_M": 8, |
| 18 | + "NUM_SMEM_BUFFERS": s, |
| 19 | + "NUM_TMEM_BUFFERS": t, |
| 20 | + "EPILOGUE_SUBTILE": subtile, |
| 21 | + }, |
| 22 | + num_warps=4, |
| 23 | + num_stages=1, |
| 24 | + pre_hook=matmul_tma_set_block_size_hook, |
| 25 | + ) |
| 26 | + for BM in [128] |
| 27 | + for BN in [128, 256] |
| 28 | + for BK in [64, 128] |
| 29 | + for s in [2, 3, 4] |
| 30 | + for t in [2, 3] |
| 31 | + for subtile in [True] |
| 32 | + ] |
| 33 | + |
| 34 | + |
| 35 | +def matmul_tma_set_block_size_hook(nargs): |
| 36 | + BLOCK_M = nargs["BLOCK_SIZE_M"] |
| 37 | + BLOCK_N = nargs["BLOCK_SIZE_N"] |
| 38 | + BLOCK_K = nargs["BLOCK_SIZE_K"] |
| 39 | + nargs["a_desc"].block_shape = [BLOCK_M, BLOCK_K] |
| 40 | + nargs["b_desc"].block_shape = [BLOCK_K, BLOCK_N] |
| 41 | + EPILOGUE_SUBTILE = nargs.get("EPILOGUE_SUBTILE", False) |
| 42 | + if EPILOGUE_SUBTILE: |
| 43 | + nargs["c_desc"].block_shape = [BLOCK_M, BLOCK_N // 2] |
| 44 | + else: |
| 45 | + nargs["c_desc"].block_shape = [BLOCK_M, BLOCK_N] |
| 46 | + |
| 47 | + |
| 48 | +@triton.jit |
| 49 | +def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M): |
| 50 | + group_id = tile_id // num_pid_in_group |
| 51 | + first_pid_m = group_id * GROUP_SIZE_M |
| 52 | + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) |
| 53 | + pid_m = first_pid_m + (tile_id % group_size_m) |
| 54 | + pid_n = (tile_id % num_pid_in_group) // group_size_m |
| 55 | + return pid_m, pid_n |
| 56 | + |
| 57 | + |
| 58 | +@triton.autotune( |
| 59 | + configs=get_cuda_autotune_config(), |
| 60 | + key=["M", "N", "K"], |
| 61 | +) |
| 62 | +@triton.jit |
| 63 | +def matmul_kernel_tma_ws_blackwell( |
| 64 | + a_desc, |
| 65 | + b_desc, |
| 66 | + c_desc, |
| 67 | + M, |
| 68 | + N, |
| 69 | + K, |
| 70 | + BLOCK_SIZE_M: tl.constexpr, |
| 71 | + BLOCK_SIZE_N: tl.constexpr, |
| 72 | + BLOCK_SIZE_K: tl.constexpr, # |
| 73 | + GROUP_SIZE_M: tl.constexpr, # |
| 74 | + NUM_SMEM_BUFFERS: tl.constexpr, # |
| 75 | + NUM_TMEM_BUFFERS: tl.constexpr, # |
| 76 | + NUM_SMS: tl.constexpr, # |
| 77 | + EPILOGUE_SUBTILE: tl.constexpr, # |
| 78 | +): |
| 79 | + # allocate NUM_SMEM_BUFFERS buffers |
| 80 | + buffers_A = tlx.local_alloc( |
| 81 | + (BLOCK_SIZE_M, BLOCK_SIZE_K), tl.float16, NUM_SMEM_BUFFERS |
| 82 | + ) |
| 83 | + buffers_B = tlx.local_alloc( |
| 84 | + (BLOCK_SIZE_K, BLOCK_SIZE_N), tl.float16, NUM_SMEM_BUFFERS |
| 85 | + ) |
| 86 | + # use multiple TMEM buffers to overlap MMA and epilogue |
| 87 | + tmem_buffers = tlx.local_alloc( |
| 88 | + (BLOCK_SIZE_M, BLOCK_SIZE_N), |
| 89 | + tl.float32, |
| 90 | + NUM_TMEM_BUFFERS, |
| 91 | + tlx.storage_kind.tmem, |
| 92 | + ) |
| 93 | + |
| 94 | + # allocate barriers |
| 95 | + smem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1) |
| 96 | + smem_full_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1) |
| 97 | + tmem_full_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1) |
| 98 | + tmem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1) |
| 99 | + |
| 100 | + with tlx.async_tasks(): |
| 101 | + with tlx.async_task("default"): # producer, TMA load |
| 102 | + # common code duplicated for each region to avoid SMEM overhead |
| 103 | + start_pid = tl.program_id(axis=0) |
| 104 | + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) |
| 105 | + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) |
| 106 | + num_pid_in_group = GROUP_SIZE_M * num_pid_n |
| 107 | + num_tiles = num_pid_m * num_pid_n |
| 108 | + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) |
| 109 | + # end of common code |
| 110 | + |
| 111 | + load_phase = 0 # the current phase of TMA load |
| 112 | + # we virtually "flatten" the two layer loop as if we're performing tma loads on |
| 113 | + # one big list of data |
| 114 | + processed_k_iters = 0 |
| 115 | + for tile_id in range(start_pid, num_tiles, NUM_SMS): |
| 116 | + pid_m, pid_n = _compute_pid( |
| 117 | + tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M |
| 118 | + ) |
| 119 | + offs_am = pid_m * BLOCK_SIZE_M |
| 120 | + offs_bn = pid_n * BLOCK_SIZE_N |
| 121 | + |
| 122 | + for k in range(0, k_tiles): |
| 123 | + # processed_k_iters + k means we use the immediate next buffer slot of tile_id x when we start tile_id x+1 |
| 124 | + buf = (processed_k_iters + k) % NUM_SMEM_BUFFERS |
| 125 | + # wait for previous phase(round) of dot for this buf |
| 126 | + tlx.barrier_wait(smem_empty_bars[buf], load_phase ^ 1) |
| 127 | + # buffer is now ready to be used again |
| 128 | + offs_k = k * BLOCK_SIZE_K |
| 129 | + tlx.barrier_expect_bytes( |
| 130 | + smem_full_bars[buf], |
| 131 | + 2 * (BLOCK_SIZE_M + BLOCK_SIZE_N) * BLOCK_SIZE_K, |
| 132 | + ) # float16 |
| 133 | + tlx.async_descriptor_load( |
| 134 | + a_desc, buffers_A[buf], [offs_am, offs_k], smem_full_bars[buf] |
| 135 | + ) |
| 136 | + tlx.async_descriptor_load( |
| 137 | + b_desc, buffers_B[buf], [offs_k, offs_bn], smem_full_bars[buf] |
| 138 | + ) |
| 139 | + # flip phase at the end of a round |
| 140 | + load_phase = load_phase ^ (buf == NUM_SMEM_BUFFERS - 1) |
| 141 | + processed_k_iters += k_tiles |
| 142 | + with tlx.async_task(num_warps=1, num_regs=232): # MMA consumer |
| 143 | + # common code duplicated for each region to avoid SMEM overhead |
| 144 | + start_pid = tl.program_id(axis=0) |
| 145 | + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) |
| 146 | + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) |
| 147 | + num_pid_in_group = GROUP_SIZE_M * num_pid_n |
| 148 | + num_tiles = num_pid_m * num_pid_n |
| 149 | + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) |
| 150 | + # end of common code |
| 151 | + |
| 152 | + dot_phase = 0 # the current phase of dot op |
| 153 | + tmem_write_phase = 1 # sync between epilogue consumer and MMA consumer |
| 154 | + cur_tmem_buf = 0 |
| 155 | + |
| 156 | + processed_k_iters = 0 |
| 157 | + for tile_id in range(start_pid, num_tiles, NUM_SMS): |
| 158 | + pid_m, pid_n = _compute_pid( |
| 159 | + tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M |
| 160 | + ) |
| 161 | + offs_am = pid_m * BLOCK_SIZE_M |
| 162 | + offs_bn = pid_n * BLOCK_SIZE_N |
| 163 | + |
| 164 | + # wait epilogue consumer to be done with the buffer before reusing it |
| 165 | + tlx.barrier_wait(tmem_empty_bars[cur_tmem_buf], tmem_write_phase) |
| 166 | + # flip phase at the end of a round of using TMEM barriers |
| 167 | + tmem_write_phase = tmem_write_phase ^ ( |
| 168 | + cur_tmem_buf == NUM_TMEM_BUFFERS - 1 |
| 169 | + ) |
| 170 | + |
| 171 | + # now iterate along K to compute result for the block |
| 172 | + for k in range(0, k_tiles): |
| 173 | + # processed_k_iters + k means we use the immediate next buffer slot of tile_id x when we start tile_id x+1 |
| 174 | + buf = (processed_k_iters + k) % NUM_SMEM_BUFFERS |
| 175 | + # wait for current phase(round) of load for this buf |
| 176 | + tlx.barrier_wait(smem_full_bars[buf], dot_phase) |
| 177 | + # buffer is now ready with loaded data, tlx.async_dot will signal `mBarrier` when done |
| 178 | + tlx.async_dot( |
| 179 | + buffers_A[buf], |
| 180 | + buffers_B[buf], |
| 181 | + tmem_buffers[cur_tmem_buf], |
| 182 | + use_acc=k > 0, |
| 183 | + mBarriers=[smem_empty_bars[buf]], |
| 184 | + out_dtype=tl.float32, |
| 185 | + ) |
| 186 | + # flip phase at the end of a round |
| 187 | + dot_phase = dot_phase ^ (buf == NUM_SMEM_BUFFERS - 1) |
| 188 | + |
| 189 | + # wait for last mma to complete |
| 190 | + last_buf = (processed_k_iters + k_tiles - 1) % NUM_SMEM_BUFFERS |
| 191 | + # in case phase was flipped, we should use the phase value when dot op was issued |
| 192 | + last_dot_phase = dot_phase ^ (last_buf == NUM_SMEM_BUFFERS - 1) |
| 193 | + tlx.barrier_wait(smem_empty_bars[last_buf], last_dot_phase) |
| 194 | + |
| 195 | + # done filling this buffer, signal epilogue consumer |
| 196 | + tlx.barrier_arrive(tmem_full_bars[cur_tmem_buf], 1) |
| 197 | + |
| 198 | + # possibly enter next iteration (next tile) without waiting for epilogue |
| 199 | + cur_tmem_buf = (cur_tmem_buf + 1) % NUM_TMEM_BUFFERS |
| 200 | + processed_k_iters += k_tiles |
| 201 | + |
| 202 | + with tlx.async_task(num_warps=4, num_regs=232): # epilogue consumer |
| 203 | + # common code duplicated for each region to avoid SMEM overhead |
| 204 | + start_pid = tl.program_id(axis=0) |
| 205 | + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) |
| 206 | + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) |
| 207 | + num_pid_in_group = GROUP_SIZE_M * num_pid_n |
| 208 | + num_tiles = num_pid_m * num_pid_n |
| 209 | + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) |
| 210 | + # end of common code |
| 211 | + |
| 212 | + tmem_read_phase = 0 |
| 213 | + cur_tmem_buf = 0 |
| 214 | + |
| 215 | + for tile_id in range(start_pid, num_tiles, NUM_SMS): |
| 216 | + pid_m, pid_n = _compute_pid( |
| 217 | + tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M |
| 218 | + ) |
| 219 | + offs_am = pid_m * BLOCK_SIZE_M |
| 220 | + offs_bn = pid_n * BLOCK_SIZE_N |
| 221 | + |
| 222 | + tlx.barrier_wait(tmem_full_bars[cur_tmem_buf], tmem_read_phase) |
| 223 | + # flip phase at the end of a round of using TMEM barriers |
| 224 | + tmem_read_phase = tmem_read_phase ^ ( |
| 225 | + cur_tmem_buf == NUM_TMEM_BUFFERS - 1 |
| 226 | + ) |
| 227 | + |
| 228 | + # load the result from TMEM to registers |
| 229 | + acc_tmem = tmem_buffers[cur_tmem_buf] |
| 230 | + |
| 231 | + if EPILOGUE_SUBTILE: |
| 232 | + # We load/store the result half by half to reduce SMEM pressure |
| 233 | + acc_tmem_subslice1 = tlx.subslice(acc_tmem, 0, BLOCK_SIZE_N // 2) |
| 234 | + result = tlx.local_load(acc_tmem_subslice1) |
| 235 | + c = result.to(tl.float16) |
| 236 | + c_desc.store([offs_am, offs_bn], c) |
| 237 | + |
| 238 | + acc_tmem_subslice2 = tlx.subslice( |
| 239 | + acc_tmem, BLOCK_SIZE_N // 2, BLOCK_SIZE_N // 2 |
| 240 | + ) |
| 241 | + result = tlx.local_load(acc_tmem_subslice2) |
| 242 | + c = result.to(tl.float16) |
| 243 | + c_desc.store([offs_am, offs_bn + BLOCK_SIZE_N // 2], c) |
| 244 | + else: |
| 245 | + result = tlx.local_load(acc_tmem) |
| 246 | + c = result.to(tl.float16) |
| 247 | + c_desc.store([offs_am, offs_bn], c) |
| 248 | + |
| 249 | + # done storing this buffer, signal MMA consumer to resume writing to it |
| 250 | + tlx.barrier_arrive(tmem_empty_bars[cur_tmem_buf], 1) |
| 251 | + |
| 252 | + cur_tmem_buf = (cur_tmem_buf + 1) % NUM_TMEM_BUFFERS |
| 253 | + |
| 254 | + |
| 255 | +def tlx_matmul(a, b): |
| 256 | + # Check constraints. |
| 257 | + assert a.shape[1] == b.shape[0], "Incompatible dimensions" |
| 258 | + assert a.is_contiguous(), "Matrix A must be contiguous" |
| 259 | + M, K = a.shape |
| 260 | + K, N = b.shape |
| 261 | + # Allocates output. |
| 262 | + c = torch.empty((M, N), device=a.device, dtype=torch.float16) |
| 263 | + |
| 264 | + # A dummy block value that will be overwritten when we have the real block size |
| 265 | + dummy_block = [1, 1] |
| 266 | + a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block) |
| 267 | + b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block) |
| 268 | + c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block) |
| 269 | + |
| 270 | + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count |
| 271 | + |
| 272 | + # Persistent kernel to have thread block resident in SM as long as possible |
| 273 | + grid = lambda META: ( |
| 274 | + min( |
| 275 | + NUM_SMS, |
| 276 | + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), |
| 277 | + ), |
| 278 | + ) |
| 279 | + matmul_kernel_tma_ws_blackwell[grid]( |
| 280 | + a_desc, |
| 281 | + b_desc, |
| 282 | + c_desc, # |
| 283 | + M, |
| 284 | + N, |
| 285 | + K, # |
| 286 | + NUM_SMS=NUM_SMS, # |
| 287 | + ) |
| 288 | + return c |
0 commit comments