From 6d5389bd4b2a23114d87933b779bda4eb8f064d6 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Wed, 30 Jul 2025 16:19:09 -0700
Subject: [PATCH] Add TLX attention (WS pipelined pingpong hopper)

---
 .../operators/flash_attention/operator.py     |  22 +
 .../tlx_attn_ws_pipelined_pingpong_hopper.py  | 383 ++++++++++++++++++
 2 files changed, 405 insertions(+)
 create mode 100644 tritonbench/operators/flash_attention/tlx_attn_ws_pipelined_pingpong_hopper.py

diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index 64171ee1..c6e9dcdb 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -136,6 +136,18 @@ except (ImportError, IOError, AttributeError, TypeError):
     HAS_XFORMERS = False
 
+# [Optional] TLX backend
+try:
+    import triton.language.extra.tlx as tlx
+
+    from .tlx_attn_ws_pipelined_pingpong_hopper import (
+        attention as tlx_attn_ws_pipelined_pingpong_hopper,
+    )
+
+    HAS_TLX = True
+except (ImportError, IOError, AttributeError):
+    HAS_TLX = False
+
 from typing import Any, Generator, List
 
 from tritonbench.utils.input import input_filter
 
@@ -299,6 +311,16 @@ def triton_tutorial_flash_v2_tma(
             q, k, v, self.causal, self.sm_scale, "tma"
         )
 
+    @register_benchmark(enabled=HAS_TLX)
+    def tlx_attn_ws_pipelined_pingpong_hopper(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+    ) -> Callable:
+        # TLX flash attention with Hopper optimizations
+        return lambda: tlx_attn_ws_pipelined_pingpong_hopper(q, k, v, self.sm_scale)
+
     def xformers_preprocess(
         self,
         q: torch.Tensor,
diff --git a/tritonbench/operators/flash_attention/tlx_attn_ws_pipelined_pingpong_hopper.py b/tritonbench/operators/flash_attention/tlx_attn_ws_pipelined_pingpong_hopper.py
new file mode 100644
index 00000000..9129e323
--- /dev/null
+++ b/tritonbench/operators/flash_attention/tlx_attn_ws_pipelined_pingpong_hopper.py
@@ -0,0 +1,383 @@
+# Copied from https://github.com/facebookexperimental/triton/blob/tlx/third_party/tlx/tutorials/flash-attention-WS-pipelined-pingpong-hopper.py
+
+import torch
+
+import triton
+import triton.language as tl
+import triton.language.extra.tlx as tlx
+from triton.tools.tensor_descriptor import TensorDescriptor
+
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
+
+def _host_descriptor_pre_hook(nargs):
+    BLOCK_M = nargs["BLOCK_M"]
+    BLOCK_N = nargs["BLOCK_N"]
+    HEAD_DIM = nargs["HEAD_DIM"]
+    if not isinstance(nargs["desc_q"], TensorDescriptor):
+        return
+    HEAD_DIM = nargs["HEAD_DIM"]
+    NUM_MMA_GROUPS = nargs["NUM_MMA_GROUPS"]
+    BLOCK_M_SPLIT = BLOCK_M // NUM_MMA_GROUPS
+    nargs["desc_q"].block_shape = [BLOCK_M_SPLIT, HEAD_DIM]
+    if nargs["FP8_OUTPUT"]:
+        nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N]
+    else:
+        nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
+    nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
+    nargs["desc_o"].block_shape = [BLOCK_M_SPLIT, HEAD_DIM]
+
+
+configs = [
+    triton.Config(
+        {
+            "BLOCK_M": 128,
+            "BLOCK_N": 128,
+            "NUM_BUFFERS": 2,
+            "NUM_MMA_WARPS": 8,
+            "NUM_MMA_GROUPS": 2,
+        },
+        num_stages=0,
+        num_warps=4,
+        pre_hook=_host_descriptor_pre_hook,
+    ),
+]
+
+
+@triton.autotune(configs=configs, key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"])
+@triton.jit
+def _attn_fwd_ws_pipelined_pingpong(
+    sm_scale,
+    M,  #
+    Z,
+    H,
+    desc_q,
+    desc_k,
+    desc_v,
+    desc_o,
+    N_CTX,  #
+    HEAD_DIM: tl.constexpr,  #
+    BLOCK_M: tl.constexpr,  #
+    BLOCK_N: tl.constexpr,  #
+    FP8_OUTPUT: tl.constexpr,  #
+    NUM_BUFFERS: tl.constexpr,  #
+    NUM_MMA_WARPS: tl.constexpr,  #
+    NUM_MMA_GROUPS: tl.constexpr,  #
+):
+    tl.static_assert(BLOCK_N <= HEAD_DIM)
+    BLOCK_M_SPLIT: tl.constexpr = BLOCK_M // NUM_MMA_GROUPS
+
+    # allocate buffers
+    q_tiles = tlx.local_alloc(
+        (BLOCK_M_SPLIT, HEAD_DIM), tlx.dtype_of(desc_q), NUM_MMA_GROUPS
+    )
+    k_tiles = tlx.local_alloc((BLOCK_N, HEAD_DIM), tlx.dtype_of(desc_k), NUM_BUFFERS)
+    v_tiles = tlx.local_alloc((BLOCK_N, HEAD_DIM), tlx.dtype_of(desc_v), NUM_BUFFERS)
+
+    # allocate barriers
+    q_fulls = tlx.alloc_barriers(num_barriers=NUM_MMA_GROUPS, arrive_count=1)
+    k_empties = tlx.alloc_barriers(
+        num_barriers=NUM_BUFFERS, arrive_count=NUM_MMA_GROUPS
+    )
+    k_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=1)
+    v_empties = tlx.alloc_barriers(
+        num_barriers=NUM_BUFFERS, arrive_count=NUM_MMA_GROUPS
+    )
+    v_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=1)
+
+    with tlx.async_tasks():
+        # producer group
+        with tlx.async_task("default"):
+            # initialize offsets
+            start_m = tl.program_id(0)
+            off_hz = tl.program_id(1)
+            off_z = off_hz // H
+            off_h = off_hz % H
+            offset_y = off_z * (N_CTX * H) + off_h * N_CTX
+            qo_offset_y = offset_y + start_m * BLOCK_M
+            lo, hi = 0, N_CTX
+            kv_offset_y = offset_y + lo
+
+            # load q: it will stay in SRAM throughout
+            for cid in tl.range(0, NUM_MMA_GROUPS, loop_unroll_factor=NUM_MMA_GROUPS):
+                q_full = tlx.local_view(q_fulls, cid)
+                tlx.barrier_expect_bytes(
+                    q_full, 2 * BLOCK_M_SPLIT * HEAD_DIM
+                )  # float16
+                q_tile = tlx.local_view(q_tiles, cid)
+                qo_offset_y_split = qo_offset_y + cid * BLOCK_M_SPLIT
+                tlx.async_descriptor_load(
+                    desc_q, q_tile, [qo_offset_y_split, 0], q_full
+                )
+
+            # loop over loading k, v
+            kv_phase = 0
+            acc_cnt = 0
+            for _ in tl.range(lo, hi, BLOCK_N):
+                buf_id = acc_cnt % NUM_BUFFERS
+                # buffers in a row share the same phase
+                kv_phase = kv_phase ^ (buf_id == 0)
+
+                # wait for the K buffer to be released by the consumer
+                k_empty = tlx.local_view(k_empties, buf_id)
+                tlx.barrier_wait(k_empty, kv_phase)
+                # load K
+                k_full = tlx.local_view(k_fulls, buf_id)
+                k_tile = tlx.local_view(k_tiles, buf_id)
+                tlx.barrier_expect_bytes(k_full, 2 * BLOCK_N * HEAD_DIM)  # float16
+                tlx.async_descriptor_load(desc_k, k_tile, [kv_offset_y, 0], k_full)
+
+                # wait for the V buffer to be released by the consumer
+                v_empty = tlx.local_view(v_empties, buf_id)
+                tlx.barrier_wait(v_empty, kv_phase)
+                # load V
+                v_full = tlx.local_view(v_fulls, buf_id)
+                v_tile = tlx.local_view(v_tiles, buf_id)
+                tlx.barrier_expect_bytes(v_full, 2 * BLOCK_N * HEAD_DIM)  # float16
+                tlx.async_descriptor_load(desc_v, v_tile, [kv_offset_y, 0], v_full)
+
+                kv_offset_y += BLOCK_N
+                acc_cnt += 1
+
+        # consumer group
+        with tlx.async_task(
+            num_warps=NUM_MMA_WARPS // NUM_MMA_GROUPS,
+            registers=232,
+            replicate=NUM_MMA_GROUPS,
+        ):
+            # initialize pointer to m and l
+            m_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) - float("inf")
+            l_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) + 1.0
+            acc = tl.zeros([BLOCK_M_SPLIT, HEAD_DIM], dtype=tl.float32)
+
+            # load scales
+            qk_scale = sm_scale
+            qk_scale *= 1.44269504  # 1/log(2)
+
+            # wait for the Q buffer to be populated by the producer
+            cid: tl.constexpr = tlx.async_task_replica_id()
+            q_full = tlx.local_view(q_fulls, cid)
+            tlx.barrier_wait(q_full, 0)
+            q_tile = tlx.local_view(q_tiles, cid)
+
+            lo, hi = 0, N_CTX
+            k_phase = 0
+            v_phase = 1
+            k_buf_id = 0
+            v_buf_id = 0
+
+            # wait for the K[0] buffer to be populated by the producer
+            k_full = tlx.local_view(k_fulls, k_buf_id)
+            tlx.barrier_wait(k_full, k_phase)
+            k_tile = tlx.local_view(k_tiles, k_buf_id)
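+
+            # Ping-pong between the two replicated MMA consumer groups: named
+            # barriers 9 and 10 (arrive count 256 = NUM_MMA_WARPS * 32 threads)
+            # stagger the groups' async_dot issue order, so one group occupies
+            # the tensor cores while the other runs softmax on the CUDA cores.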
+
+            # -- compute qk[0] ----
+            k_tile = tlx.local_trans(k_tile)
+
+            if cid == 0:
+                # Consumer 0 waits for Consumer 1 to reach synchronization point at barrier 9.
+                tlx.named_barrier_wait(9, 256)
+            else:
+                # Consumer 1 signals its arrival at barrier 9.
+                tlx.named_barrier_arrive(9, 256)
+                # Then waits at barrier 10 until Consumer 0 finishes issuing its async_dot.
+                tlx.named_barrier_wait(10, 256)
+
+            qk = tlx.async_dot(q_tile, k_tile)
+
+            if cid == 0:
+                # After issuing async_dot, Consumer 0 signals barrier 10 to unblock Consumer 1.
+                tlx.named_barrier_arrive(10, 256)
+
+            # wait for the MMA to complete
+            qk = tlx.async_dot_wait(0, qk)
+            # release the K buffer
+            k_empty = tlx.local_view(k_empties, k_buf_id)
+            tlx.barrier_arrive(k_empty, 1)
+
+            # -- compute m_i and l_i ----
+            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
+            qk = qk * qk_scale - m_ij[:, None]
+            p = tl.math.exp2(qk)
+            # -- compute correction factor
+            alpha = tl.math.exp2(m_i - m_ij)
+            # -- update output accumulator[0] --
+            acc = acc * alpha[:, None]
+            l_ij = tl.sum(p, 1)
+            l_i = l_i * alpha + l_ij
+            m_i = m_ij
+            acc_cnt = 1
+
+            # loop over k, v and update accumulator
+            for _ in tl.range(lo + BLOCK_N, hi, BLOCK_N):
+                k_buf_id = acc_cnt % NUM_BUFFERS
+                # buffers in a row share the same phase
+                k_phase = k_phase ^ (k_buf_id == 0)
+
+                # wait for the K buffer to be populated by the producer
+                k_full = tlx.local_view(k_fulls, k_buf_id)
+                tlx.barrier_wait(k_full, k_phase)
+                k_tile = tlx.local_view(k_tiles, k_buf_id)
+
+                # compute qk for the current iteration
+                k_tile = tlx.local_trans(k_tile)
+                qk = tlx.async_dot(q_tile, k_tile)
+
+                # compute pv from the previous iteration
+                # wait for the previous V buffer to be populated by the producer
+                v_buf_id = (acc_cnt - 1) % NUM_BUFFERS
+                v_phase = v_phase ^ (v_buf_id == 0)
+                v_full = tlx.local_view(v_fulls, v_buf_id)
+                tlx.barrier_wait(v_full, v_phase)
+                v_tile = tlx.local_view(v_tiles, v_buf_id)
+                # prepare p and v for the dot
+                p = p.to(tlx.dtype_of(desc_k))
+                acc = tlx.async_dot(p, v_tile, acc)
+
+                # wait for the current qk MMA to complete
+                qk = tlx.async_dot_wait(1, qk)
+                # release the K buffer
+                k_empty = tlx.local_view(k_empties, k_buf_id)
+                tlx.barrier_arrive(k_empty, 1)
+
+                # -- compute m_i and l_i ----
+                m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
+                qk = qk * qk_scale - m_ij[:, None]
+                p = tl.math.exp2(qk)
+                # -- compute correction factor
+                alpha = tl.math.exp2(m_i - m_ij)
+                l_ij = tl.sum(p, 1)
+                # update m_i and l_i
+                l_i = l_i * alpha + l_ij
+                m_i = m_ij
+
+                # -- update output accumulator --
+                # wait for the previous pv MMA to complete
+                acc = tlx.async_dot_wait(0, acc)
+                # release the V buffer
+                v_empty = tlx.local_view(v_empties, v_buf_id)
+                tlx.barrier_arrive(v_empty, 1)
+                acc = acc * alpha[:, None]
+                acc_cnt += 1
+
+            # compute pv from the last iteration
+            # wait for the V buffer to be populated by the producer
+            v_buf_id = (acc_cnt - 1) % NUM_BUFFERS
+            v_phase = v_phase ^ (v_buf_id == 0)
+            v_full = tlx.local_view(v_fulls, v_buf_id)
+            tlx.barrier_wait(v_full, v_phase)
+            v_tile = tlx.local_view(v_tiles, v_buf_id)
+            # prepare p and v for the dot
+            p = p.to(tlx.dtype_of(desc_k))
+            acc = tlx.async_dot(p, v_tile, acc)
+            # wait for the MMA to complete
+            acc = tlx.async_dot_wait(0, acc)
+            # release the V buffer
+            v_empty = tlx.local_view(v_empties, v_buf_id)
+            tlx.barrier_arrive(v_empty, 1)
+
+            # epilogue
+            start_m = tl.program_id(0)
+            off_hz = tl.program_id(1)
+            off_z = off_hz // H
+            off_h = off_hz % H
+            offset_y = off_z * (N_CTX * H) + off_h * N_CTX
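+            # Each replicated consumer group recomputes its offsets and writes
+            # back its own BLOCK_M_SPLIT rows: the accumulator is normalized
+            # by l_i, and m_i + log2(l_i) (the base-2 log-sum-exp) is stored
+            # in M for reuse by a backward pass.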
+            qo_offset_y = offset_y + start_m * BLOCK_M
+            qo_offset_y_split = qo_offset_y + cid * BLOCK_M_SPLIT
+            m_i += tl.math.log2(l_i)
+            acc = acc / l_i[:, None]
+            offs_m = (
+                start_m * BLOCK_M + cid * BLOCK_M_SPLIT + tl.arange(0, BLOCK_M_SPLIT)
+            )
+            m_ptrs = M + off_hz * N_CTX + offs_m
+            tl.store(m_ptrs, m_i)
+            desc_o.store([qo_offset_y_split, 0], acc.to(tlx.dtype_of(desc_o)))
+
+
+class _attention(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, k, v, sm_scale):
+        # shape constraints
+        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
+        # when v is in float8_e5m2 it is transposed.
+        HEAD_DIM_V = v.shape[-1]
+        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
+        assert HEAD_DIM_K in {16, 32, 64, 128, 256}
+        o = torch.empty_like(q)
+        extra_kern_args = {}
+
+        M = torch.empty(
+            (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
+        )
+        # Note that on Hopper we cannot perform an FP8 dot with a non-transposed second tensor
+        y_dim = q.shape[0] * q.shape[1] * q.shape[2]
+
+        dummy_block = [1, 1]
+        desc_q = TensorDescriptor(
+            q,
+            shape=[y_dim, HEAD_DIM_K],
+            strides=[HEAD_DIM_K, 1],
+            block_shape=dummy_block,
+        )
+        if q.dtype == torch.float8_e5m2:
+            desc_v = TensorDescriptor(
+                v,
+                shape=[HEAD_DIM_K, y_dim],
+                strides=[q.shape[2], 1],
+                block_shape=dummy_block,
+            )
+        else:
+            desc_v = TensorDescriptor(
+                v,
+                shape=[y_dim, HEAD_DIM_K],
+                strides=[HEAD_DIM_K, 1],
+                block_shape=dummy_block,
+            )
+        desc_k = TensorDescriptor(
+            k,
+            shape=[y_dim, HEAD_DIM_K],
+            strides=[HEAD_DIM_K, 1],
+            block_shape=dummy_block,
+        )
+        desc_o = TensorDescriptor(
+            o,
+            shape=[y_dim, HEAD_DIM_K],
+            strides=[HEAD_DIM_K, 1],
+            block_shape=dummy_block,
+        )
+
+        def alloc_fn(size: int, align: int, _):
+            return torch.empty(size, dtype=torch.int8, device="cuda")
+
+        triton.set_allocator(alloc_fn)
+
+        def grid(META):
+            return (
+                triton.cdiv(q.shape[2], META["BLOCK_M"]),
+                q.shape[0] * q.shape[1],
+                1,
+            )
+
+        ctx.grid = grid
+        _attn_fwd_ws_pipelined_pingpong[grid](
+            sm_scale,
+            M,  #
+            q.shape[0],
+            q.shape[1],  #
+            desc_q,
+            desc_k,
+            desc_v,
+            desc_o,  #
+            N_CTX=q.shape[2],  #
+            HEAD_DIM=HEAD_DIM_K,  #
+            FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
+            **extra_kern_args,
+        )
+
+        ctx.save_for_backward(q, k, v, o, M)
+        ctx.sm_scale = sm_scale
+        ctx.HEAD_DIM = HEAD_DIM_K
+        return o
+
+
+attention = _attention.apply
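
Usage sketch (not part of the patch): a minimal numerical check of the new
kernel against PyTorch's scaled_dot_product_attention, assuming a Hopper GPU
and a TLX-enabled Triton build. The shapes, scale, and tolerance below are
illustrative, not taken from the patch; HEAD_DIM=128 is chosen because the
single autotune config requires BLOCK_N (128) <= HEAD_DIM, and N_CTX is a
multiple of BLOCK_N.

    import math

    import torch

    from tritonbench.operators.flash_attention.tlx_attn_ws_pipelined_pingpong_hopper import (
        attention,
    )

    # Illustrative sizes: (Z, H, N_CTX, HEAD_DIM)
    q = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    sm_scale = 1.0 / math.sqrt(q.shape[-1])

    out = attention(q, k, v, sm_scale)  # non-causal, matching the kernel
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale)
    torch.testing.assert_close(out, ref, atol=2e-2, rtol=0)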