import torch

import triton
import triton.language as tl
import triton.language.extra.tlx as tlx
from triton.tools.tensor_descriptor import TensorDescriptor

DEVICE = triton.runtime.driver.active.get_active_torch_device()


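# Autotuner pre-hook: the kernel is launched with TMA tensor descriptors whose
# block shapes start out as placeholders, so this hook patches them to match the
# tile sizes chosen by the selected config before the kernel runs.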
def _host_descriptor_pre_hook(nargs):
    BLOCK_M = nargs["BLOCK_M"]
    BLOCK_N = nargs["BLOCK_N"]
    HEAD_DIM = nargs["HEAD_DIM"]
    if not isinstance(nargs["desc_q"], TensorDescriptor):
        return
    NUM_MMA_GROUPS = nargs["NUM_MMA_GROUPS"]
    BLOCK_M_SPLIT = BLOCK_M // NUM_MMA_GROUPS
    nargs["desc_q"].block_shape = [BLOCK_M_SPLIT, HEAD_DIM]
    if nargs["FP8_OUTPUT"]:
        nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N]
    else:
        nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
    nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
    nargs["desc_o"].block_shape = [BLOCK_M_SPLIT, HEAD_DIM]


configs = [
    triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'NUM_BUFFERS': 2, 'NUM_MMA_WARPS': 8, 'NUM_MMA_GROUPS': 2},
                  num_stages=0, num_warps=4, pre_hook=_host_descriptor_pre_hook),
]


@triton.autotune(configs=configs, key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"])
@triton.jit
def _attn_fwd_ws_pipelined_pingpong(sm_scale, M,  #
                                    Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX,  #
                                    HEAD_DIM: tl.constexpr,  #
                                    BLOCK_M: tl.constexpr,  #
                                    BLOCK_N: tl.constexpr,  #
                                    FP8_OUTPUT: tl.constexpr,  #
                                    NUM_BUFFERS: tl.constexpr,  #
                                    NUM_MMA_WARPS: tl.constexpr,  #
                                    NUM_MMA_GROUPS: tl.constexpr,  #
                                    ):
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    BLOCK_M_SPLIT: tl.constexpr = BLOCK_M // NUM_MMA_GROUPS

    # allocate buffers
    q_tiles = tlx.local_alloc((BLOCK_M_SPLIT, HEAD_DIM), tlx.dtype_of(desc_q), NUM_MMA_GROUPS)
    k_tiles = tlx.local_alloc((BLOCK_N, HEAD_DIM), tlx.dtype_of(desc_k), NUM_BUFFERS)
    v_tiles = tlx.local_alloc((BLOCK_N, HEAD_DIM), tlx.dtype_of(desc_v), NUM_BUFFERS)

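    # K/V double buffers are guarded by full/empty mbarrier pairs: the producer's
    # TMA load signals the "full" barrier once a tile has landed, and every
    # consumer group signals the "empty" barrier (arrive_count=NUM_MMA_GROUPS)
    # when it is done with the tile. Q only needs "full" barriers because it
    # stays resident in shared memory for the whole kernel.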
    # allocate barriers
    q_fulls = tlx.alloc_barriers(num_barriers=NUM_MMA_GROUPS, arrive_count=1)
    k_empties = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=NUM_MMA_GROUPS)
    k_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=1)
    v_empties = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=NUM_MMA_GROUPS)
    v_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=1)

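    # Warp specialization: the "default" task acts as the producer and only issues
    # async TMA loads, while NUM_MMA_GROUPS replicated consumer tasks (each with
    # NUM_MMA_WARPS // NUM_MMA_GROUPS warps) own a BLOCK_M_SPLIT-row slice of the
    # query block and run the MMA + softmax pipeline on it.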
    with tlx.async_tasks():
        # producer group
        with tlx.async_task("default"):
            # initialize offsets
            start_m = tl.program_id(0)
            off_hz = tl.program_id(1)
            off_z = off_hz // H
            off_h = off_hz % H
            offset_y = off_z * (N_CTX * H) + off_h * N_CTX
            qo_offset_y = offset_y + start_m * BLOCK_M
            lo, hi = 0, N_CTX
            kv_offset_y = offset_y + lo

            # load q: it will stay in SRAM throughout
            for cid in tl.range(0, NUM_MMA_GROUPS, loop_unroll_factor=NUM_MMA_GROUPS):
                q_full = tlx.local_view(q_fulls, cid)
                tlx.barrier_expect_bytes(q_full, 2 * BLOCK_M_SPLIT * HEAD_DIM)  # float16
                q_tile = tlx.local_view(q_tiles, cid)
                qo_offset_y_split = qo_offset_y + cid * BLOCK_M_SPLIT
                tlx.async_descriptor_load(desc_q, q_tile, [qo_offset_y_split, 0], q_full)

            # loop over loading k, v
            kv_phase = 0
            acc_cnt = 0
            for _ in tl.range(lo, hi, BLOCK_N):
                buf_id = acc_cnt % NUM_BUFFERS
                # buffers in a row share the same phase
                kv_phase = kv_phase ^ (buf_id == 0)

                # wait for the K buffer to be released by the consumer
                k_empty = tlx.local_view(k_empties, buf_id)
                tlx.barrier_wait(k_empty, kv_phase)
                # load K
                k_full = tlx.local_view(k_fulls, buf_id)
                k_tile = tlx.local_view(k_tiles, buf_id)
                tlx.barrier_expect_bytes(k_full, 2 * BLOCK_N * HEAD_DIM)  # float16
                tlx.async_descriptor_load(desc_k, k_tile, [kv_offset_y, 0], k_full)

                # wait for the V buffer to be released by the consumer
                v_empty = tlx.local_view(v_empties, buf_id)
                tlx.barrier_wait(v_empty, kv_phase)
                # load V
                v_full = tlx.local_view(v_fulls, buf_id)
                v_tile = tlx.local_view(v_tiles, buf_id)
                tlx.barrier_expect_bytes(v_full, 2 * BLOCK_N * HEAD_DIM)  # float16
                tlx.async_descriptor_load(desc_v, v_tile, [kv_offset_y, 0], v_full)

                kv_offset_y += BLOCK_N
                acc_cnt += 1

        # consumer group
        with tlx.async_task(num_warps=NUM_MMA_WARPS // NUM_MMA_GROUPS, registers=232, replicate=NUM_MMA_GROUPS):
            # initialize the running softmax statistics m_i (row max) and l_i (row sum)
            m_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) - float("inf")
            l_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) + 1.0
            acc = tl.zeros([BLOCK_M_SPLIT, HEAD_DIM], dtype=tl.float32)

            # softmax scale in base-2 units
            qk_scale = sm_scale
            qk_scale *= 1.44269504  # 1/log(2)

            # wait for the Q buffer to be populated by the producer
            cid: tl.constexpr = tlx.async_task_replica_id()
            q_full = tlx.local_view(q_fulls, cid)
            tlx.barrier_wait(q_full, 0)
            q_tile = tlx.local_view(q_tiles, cid)

            lo, hi = 0, N_CTX
            k_phase = 0
            v_phase = 1
            k_buf_id = 0
            v_buf_id = 0

            # wait for the K[0] buffer to be populated by the producer
            k_full = tlx.local_view(k_fulls, k_buf_id)
            tlx.barrier_wait(k_full, k_phase)
            k_tile = tlx.local_view(k_tiles, k_buf_id)

            # -- compute qk[0] ----
            k_tile = tlx.local_trans(k_tile)

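            # Ping-pong scheduling via named barriers 9 and 10 (256 threads =
            # NUM_MMA_WARPS * 32 with the config above): Consumer 0 issues its first
            # QK^T MMA before Consumer 1, which staggers the two consumer groups so
            # one group's MMA can overlap the other's softmax work in the main loop.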
            if cid == 0:
                # Consumer 0 waits for Consumer 1 to reach the synchronization point at barrier 9.
                tlx.named_barrier_wait(9, 256)
            else:
                # Consumer 1 signals its arrival at barrier 9.
                tlx.named_barrier_arrive(9, 256)
                # Then it waits at barrier 10 until Consumer 0 has issued its async_dot.
                tlx.named_barrier_wait(10, 256)

            qk = tlx.async_dot(q_tile, k_tile)

            if cid == 0:
                # After issuing its async_dot, Consumer 0 signals barrier 10 to unblock Consumer 1.
                tlx.named_barrier_arrive(10, 256)

            # wait for the qk MMA to complete
            qk = tlx.async_dot_wait(0, qk)
            # release the K buffer
            k_empty = tlx.local_view(k_empties, k_buf_id)
            tlx.barrier_arrive(k_empty, 1)

            # -- compute m_i and l_i ----
            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
            qk = qk * qk_scale - m_ij[:, None]
            p = tl.math.exp2(qk)
            # -- compute correction factor
            alpha = tl.math.exp2(m_i - m_ij)
            # -- update output accumulator[0] --
            acc = acc * alpha[:, None]
            l_ij = tl.sum(p, 1)
            l_i = l_i * alpha + l_ij
            m_i = m_ij
            acc_cnt = 1

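            # Software pipelining: each iteration issues the QK^T MMA for the current
            # K tile and, while it is in flight, the PV MMA that consumes the V tile
            # from the previous iteration, so the two async dots overlap.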
            # loop over k, v and update accumulator
            for _ in tl.range(lo + BLOCK_N, hi, BLOCK_N):
                k_buf_id = acc_cnt % NUM_BUFFERS
                # buffers in a row share the same phase
                k_phase = k_phase ^ (k_buf_id == 0)

                # wait for the K buffer to be populated by the producer
                k_full = tlx.local_view(k_fulls, k_buf_id)
                tlx.barrier_wait(k_full, k_phase)
                k_tile = tlx.local_view(k_tiles, k_buf_id)

                # compute qk for the current iteration
                k_tile = tlx.local_trans(k_tile)
                qk = tlx.async_dot(q_tile, k_tile)

                # compute pv from the previous iteration
                # wait for the previous V buffer to be populated by the producer
                v_buf_id = (acc_cnt - 1) % NUM_BUFFERS
                v_phase = v_phase ^ (v_buf_id == 0)
                v_full = tlx.local_view(v_fulls, v_buf_id)
                tlx.barrier_wait(v_full, v_phase)
                v_tile = tlx.local_view(v_tiles, v_buf_id)
                # prepare p and v for the dot
                p = p.to(tlx.dtype_of(desc_k))
                acc = tlx.async_dot(p, v_tile, acc)

                # wait for the current qk MMA to complete
                qk = tlx.async_dot_wait(1, qk)
                # release the K buffer
                k_empty = tlx.local_view(k_empties, k_buf_id)
                tlx.barrier_arrive(k_empty, 1)

                # -- compute m_i and l_i ----
                m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
                qk = qk * qk_scale - m_ij[:, None]
                p = tl.math.exp2(qk)
                # -- compute correction factor
                alpha = tl.math.exp2(m_i - m_ij)
                l_ij = tl.sum(p, 1)
                # update m_i and l_i
                l_i = l_i * alpha + l_ij
                m_i = m_ij

                # -- update output accumulator --
                # wait for the previous pv MMA to complete
                acc = tlx.async_dot_wait(0, acc)
                # release the V buffer
                v_empty = tlx.local_view(v_empties, v_buf_id)
                tlx.barrier_arrive(v_empty, 1)
                acc = acc * alpha[:, None]
                acc_cnt += 1

            # compute pv from the last iteration
            # wait for the V buffer to be populated by the producer
            v_buf_id = (acc_cnt - 1) % NUM_BUFFERS
            v_phase = v_phase ^ (v_buf_id == 0)
            v_full = tlx.local_view(v_fulls, v_buf_id)
            tlx.barrier_wait(v_full, v_phase)
            v_tile = tlx.local_view(v_tiles, v_buf_id)
            # prepare p and v for the dot
            p = p.to(tlx.dtype_of(desc_k))
            acc = tlx.async_dot(p, v_tile, acc)
            # wait for the last pv MMA to complete
            acc = tlx.async_dot_wait(0, acc)
            # release the V buffer
            v_empty = tlx.local_view(v_empties, v_buf_id)
            tlx.barrier_arrive(v_empty, 1)

            # epilogue
            start_m = tl.program_id(0)
            off_hz = tl.program_id(1)
            off_z = off_hz // H
            off_h = off_hz % H
            offset_y = off_z * (N_CTX * H) + off_h * N_CTX
            qo_offset_y = offset_y + start_m * BLOCK_M
            qo_offset_y_split = qo_offset_y + cid * BLOCK_M_SPLIT
            m_i += tl.math.log2(l_i)
            acc = acc / l_i[:, None]
            offs_m = start_m * BLOCK_M + cid * BLOCK_M_SPLIT + tl.arange(0, BLOCK_M_SPLIT)
            m_ptrs = M + off_hz * N_CTX + offs_m
            tl.store(m_ptrs, m_i)
            desc_o.store([qo_offset_y_split, 0], acc.to(tlx.dtype_of(desc_o)))


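# Host-side wrapper. The TensorDescriptors below are created with a dummy [1, 1]
# block shape; _host_descriptor_pre_hook patches in the real block shapes once the
# autotuner has picked BLOCK_M, BLOCK_N, and NUM_MMA_GROUPS.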
class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, sm_scale):
        # shape constraints
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        # when v is in float8_e5m2 it is transposed.
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}
        o = torch.empty_like(q)
        extra_kern_args = {}

        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        # Note that on Hopper we cannot perform an FP8 dot with a non-transposed second tensor
        y_dim = q.shape[0] * q.shape[1] * q.shape[2]

        dummy_block = [1, 1]
        desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
        if q.dtype == torch.float8_e5m2:
            desc_v = TensorDescriptor(v, shape=[HEAD_DIM_K, y_dim], strides=[q.shape[2], 1], block_shape=dummy_block)
        else:
            desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
        desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
        desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)

        def alloc_fn(size: int, align: int, _):
            return torch.empty(size, dtype=torch.int8, device="cuda")

        triton.set_allocator(alloc_fn)

        def grid(META):
            return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)

        ctx.grid = grid
        _attn_fwd_ws_pipelined_pingpong[grid](
            sm_scale, M,  #
            q.shape[0], q.shape[1],  #
            desc_q, desc_k, desc_v, desc_o,  #
            N_CTX=q.shape[2],  #
            HEAD_DIM=HEAD_DIM_K,  #
            FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
            **extra_kern_args)

        ctx.save_for_backward(q, k, v, o, M)
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        return o


attention = _attention.apply
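

# Minimal usage sketch (not part of the original kernel). The shapes, dtype, and
# tolerances below are assumptions: the kernel expects q, k, v laid out as
# contiguous (Z, H, N_CTX, HEAD_DIM) float16 tensors on a Hopper-class GPU, with
# N_CTX a multiple of BLOCK_N.
if __name__ == "__main__":
    Z, H, N_CTX, HEAD_DIM = 1, 2, 1024, 128
    q = torch.randn((Z, H, N_CTX, HEAD_DIM), dtype=torch.float16, device=DEVICE)
    k = torch.randn((Z, H, N_CTX, HEAD_DIM), dtype=torch.float16, device=DEVICE)
    v = torch.randn((Z, H, N_CTX, HEAD_DIM), dtype=torch.float16, device=DEVICE)
    sm_scale = HEAD_DIM ** -0.5
    out = attention(q, k, v, sm_scale)
    # sanity check against the PyTorch reference implementation
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale)
    torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)
    print("attention output matches the PyTorch reference")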