diff --git a/tritonbench/kernels/blackwell_triton_fused_attention.py b/tritonbench/kernels/blackwell_triton_fused_attention.py index 1e8c4fe5..f5e401ed 100644 --- a/tritonbench/kernels/blackwell_triton_fused_attention.py +++ b/tritonbench/kernels/blackwell_triton_fused_attention.py @@ -17,6 +17,8 @@ import triton.language as tl from triton.tools.tensor_descriptor import TensorDescriptor +from .attention_utils import WITH_MAXNREG + from .blackwell_attention_utils import ( is_blackwell, is_cuda, @@ -27,6 +29,22 @@ ) +# Check if Triton version supports minRegAutoWS and maxRegAutoWS +# These parameters are only available in triton/tree/ws-3.5 +def _supports_reg_auto_ws(): + """Check if the current Triton version supports minRegAutoWS/maxRegAutoWS""" + try: + # Try to create a Config with minRegAutoWS to test support + test_config = triton.Config({}, minRegAutoWS=24, maxRegAutoWS=152) + return True + except (TypeError, AttributeError): + # Parameter not supported in this Triton version + return False + + +HAS_REG_AUTO_WS = _supports_reg_auto_ws() + + @triton.jit def _attn_fwd_subtile( q, @@ -35,13 +53,16 @@ def _attn_fwd_subtile( start_n, offs_n, qk_scale, - l_i, + l_i0, + l_i1, # used when FADD2_REDUCE is true m_i, acc, v, dtype: tl.constexpr, STAGE: tl.constexpr, SUBTILING: tl.constexpr, + VECT_MUL: tl.constexpr, + FADD2_REDUCE: tl.constexpr, ): qk = tl.dot(q, k) if STAGE == 2: @@ -51,11 +72,15 @@ def _attn_fwd_subtile( qk -= m_ij[:, None] else: m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - qk = qk * qk_scale - m_ij[:, None] + if VECT_MUL & 2: + qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None]) + else: + qk = qk * qk_scale - m_ij[:, None] p = tl.math.exp2(qk) # -- compute correction factor alpha = tl.math.exp2(m_i - m_ij) - l_ij = tl.sum(p, 1) + if not FADD2_REDUCE: + l_ij = tl.sum(p, 1) # -- update output accumulator -- BM: tl.constexpr = acc.shape[0] @@ -63,28 +88,42 @@ def _attn_fwd_subtile( if SUBTILING: acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() - acc0 = acc0 * alpha[:, None] - acc1 = acc1 * alpha[:, None] + if VECT_MUL & 1: + acc0 = _mul_f32x2(acc0, alpha[:, None]) + acc1 = _mul_f32x2(acc1, alpha[:, None]) + else: + acc0 = acc0 * alpha[:, None] + acc1 = acc1 * alpha[:, None] acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN]) else: acc = acc * alpha[:, None] + PM: tl.constexpr = p.shape[0] + PN: tl.constexpr = p.shape[1] + if FADD2_REDUCE: + p0, p1 = p.reshape([PM, 2, PN // 2]).permute(0, 2, 1).split() + l_ij0, l_ij1 = tl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2) + l_i0 = l_i0 * alpha + l_ij0 + l_i1 = l_i1 * alpha + l_ij1 + # prepare p and v for the dot p = p.to(dtype) # note that this non transposed v for FP8 is only supported on Blackwell acc = tl.dot(p, v, acc) # update m_i and l_i # place this at the end of the loop to reduce register pressure - l_i = l_i * alpha + l_ij + if not FADD2_REDUCE: + l_i0 = l_i0 * alpha + l_ij m_i = m_ij - return l_i, m_i, acc + return l_i0, l_i1, m_i, acc @triton.jit def _attn_fwd_inner_oss_dp( acc0, l_i0, + l_i0_1, m_i0, q0, desc_k, @@ -102,6 +141,8 @@ def _attn_fwd_inner_oss_dp( N_CTX: tl.constexpr, warp_specialize: tl.constexpr, SUBTILING: tl.constexpr, + VECT_MUL: tl.constexpr, + FADD2_REDUCE: tl.constexpr, ): # range of values handled by this stage if STAGE == 1: @@ -123,7 +164,7 @@ def _attn_fwd_inner_oss_dp( k = desc_k.load([offsetkv_y, 0]).T v = desc_v.load([offsetkv_y, 0]) - l_i0, m_i0, acc0 = _attn_fwd_subtile( + l_i0, l_i0_1, m_i0, acc0 = _attn_fwd_subtile( q0, k, offs_m0, @@ -131,17 +172,20 @@ def _attn_fwd_inner_oss_dp( offs_n, qk_scale, l_i0, + l_i0_1, m_i0, acc0, v, dtype, STAGE, SUBTILING, + VECT_MUL, + FADD2_REDUCE, ) offsetkv_y += BLOCK_N - return acc0, l_i0, m_i0 + return acc0, l_i0, l_i0_1, m_i0 def _host_descriptor_pre_hook(nargs): @@ -167,31 +211,68 @@ def _host_descriptor_pre_hook(nargs): NUM_STAGES_OPTIONS = [3] if is_tile_enabled(): + # Helper to build config with optional minRegAutoWS/maxRegAutoWS + def make_tile_config(BM, BN, occ, subtile, vectmul, add2reduce): + config_kwargs = { + "BLOCK_M": BM, + "BLOCK_N": BN, + "occupancy": occ, + "SUBTILING": subtile, + "VECT_MUL": vectmul, + "FADD2_REDUCE": add2reduce, + } + extra_kwargs = {"pre_hook": _host_descriptor_pre_hook} + + # Only add minRegAutoWS/maxRegAutoWS if supported (triton/tree/ws-3.5) + if HAS_REG_AUTO_WS: + extra_kwargs["minRegAutoWS"] = 24 + extra_kwargs["maxRegAutoWS"] = 152 + + return triton.Config(config_kwargs, **extra_kwargs) + configs = [ - triton.Config( - {"BLOCK_M": BM, "BLOCK_N": BN, "SUBTILING": subtile, "occupancy": occ}, - pre_hook=_host_descriptor_pre_hook, - ) + make_tile_config(BM, BN, occ, subtile, vectmul, add2reduce) for BM in [64, 128, 256] for BN in [64, 128] for occ in [1, 2] - for subtile in [False] + for subtile in [True] + for vectmul in [0] + for add2reduce in [False] ] else: + # Helper to build config with optional minRegAutoWS/maxRegAutoWS + def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg): + config_kwargs = { + "BLOCK_M": BM, + "BLOCK_N": BN, + "SUBTILING": subtile, + "VECT_MUL": vectmul, + "FADD2_REDUCE": add2reduce, + } + extra_kwargs = { + "num_stages": s, + "num_warps": w, + "pre_hook": _host_descriptor_pre_hook, + } + + # Only add minRegAutoWS/maxRegAutoWS if supported (triton/tree/ws-3.5) + if HAS_REG_AUTO_WS: + extra_kwargs["minRegAutoWS"] = 24 + extra_kwargs["maxRegAutoWS"] = maxreg + extra_kwargs["data_partition_factor"] = 2 + + return triton.Config(config_kwargs, **extra_kwargs) + configs = [ - triton.Config( - {"BLOCK_M": BM, "BLOCK_N": BN, "SUBTILING": subtile}, - num_stages=s, - num_warps=w, - data_partition_factor=2, - pre_hook=_host_descriptor_pre_hook, - # ir_override=f"/home/mren/OpenSource/tritonbench/override/_attn_fwd_persist.ttgir" - ) + make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg) for BM in [256] - for BN in [128] + for BN in [64, 128] for s in NUM_STAGES_OPTIONS for w in [4] - for subtile in [False] + for subtile in [True] + for vectmul in [1] + for add2reduce in [False] + for maxreg in [152, 192] ] @@ -221,6 +302,67 @@ def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape): return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape) +@triton.jit +def _mul_f32x2(a, b): + return tl.inline_asm_elementwise( + """ + { + .reg .b64 ra, rb, rc; + mov.b64 ra, { $2, $3 }; + mov.b64 rb, { $4, $5 }; + mul.f32x2 rc, ra, rb; + mov.b64 { $0, $1 }, rc; + } + """, + "=r,=r,r,r,r,r", + [a, b], + dtype=tl.float32, + is_pure=True, + pack=2, + ) + + +@triton.jit +def _fma_f32x2(a, b, c): + return tl.inline_asm_elementwise( + """ + { + .reg .b64 ra, rb, rc, rd; + mov.b64 ra, { $2, $3 }; + mov.b64 rb, { $4, $5 }; + mov.b64 rc, { $6, $7 }; + fma.rn.f32x2 rd, ra, rb, rc; + mov.b64 { $0, $1 }, rd; + } + """, + "=r,=r,r,r,r,r,r,r", + [a, b, c], + dtype=tl.float32, + is_pure=True, + pack=2, + ) + + +@triton.jit +def _reduce_fadd2(p0a, p1a, p0b, p1b): + return tl.inline_asm_elementwise( + """ + { + .reg .b64 rc, ra, rb; + mov.b64 ra, { $2, $4 }; + mov.b64 rb, { $3, $5 }; + add.f32x2 rc, ra, rb; + mov.b64 { $0, $1 }, rc; + } + """, + "=r,=r,r,r,r,r", + [p0a, p0b, p1a, p1b], + dtype=[tl.float32, tl.float32], + is_pure=True, + pack=1, + ) + + @triton.jit def _attn_fwd_tma_dp( sm_scale, @@ -233,7 +375,7 @@ def _attn_fwd_tma_dp( desc_o, pid, off_hz, - N_CTX, # + N_CTX: tl.constexpr, # HEAD_DIM: tl.constexpr, # BLOCK_M: tl.constexpr, # BLOCK_N: tl.constexpr, # @@ -242,8 +384,9 @@ def _attn_fwd_tma_dp( warp_specialize: tl.constexpr, # dtype: tl.constexpr, SUBTILING: tl.constexpr, + VECT_MUL: tl.constexpr, + FADD2_REDUCE: tl.constexpr, ): - tl.static_assert(BLOCK_N <= HEAD_DIM) start_m = pid # tl.program_id(0) # off_hz = tl.program_id(1) off_z = off_hz // H @@ -256,7 +399,7 @@ def _attn_fwd_tma_dp( offs_n = tl.arange(0, BLOCK_N) m_i0 = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i0 = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + l_i0_0 = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 acc0 = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) qk_scale = sm_scale @@ -264,10 +407,16 @@ def _attn_fwd_tma_dp( q0 = desc_q.load([qo_offset_y, 0]) + if FADD2_REDUCE: + l_i0_1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + else: + l_i0_1 = 0 + if STAGE & 1: - acc0, l_i0, m_i0 = _attn_fwd_inner_oss_dp( + acc0, l_i0_0, l_i0_1, m_i0 = _attn_fwd_inner_oss_dp( acc0, - l_i0, + l_i0_0, + l_i0_1, m_i0, q0, desc_k, @@ -285,11 +434,14 @@ def _attn_fwd_tma_dp( N_CTX, # warp_specialize, SUBTILING, + VECT_MUL, + FADD2_REDUCE, ) if STAGE & 2: - acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp( + acc0, l_i0_0, l_i0_1, m_i0 = _attn_fwd_inner_oss_dp( acc0, - l_i0, + l_i0_0, + l_i0_1, m_i0, q0, desc_k, @@ -307,8 +459,15 @@ def _attn_fwd_tma_dp( N_CTX, # warp_specialize, SUBTILING, + VECT_MUL, + FADD2_REDUCE, ) + if FADD2_REDUCE: + l_i0 = l_i0_0 + l_i0_1 + else: + l_i0 = l_i0_0 + m_i0 += tl.math.log2(l_i0) acc0 = acc0 / l_i0[:, None] m_ptrs0 = M + off_hz * N_CTX + offs_m0 @@ -331,7 +490,7 @@ def _attn_fwd( desc_k, desc_v, desc_o, - N_CTX, # + N_CTX: tl.constexpr, # HEAD_DIM: tl.constexpr, # BLOCK_M: tl.constexpr, # BLOCK_N: tl.constexpr, # @@ -340,6 +499,8 @@ def _attn_fwd( warp_specialize: tl.constexpr, # dtype: tl.constexpr, SUBTILING: tl.constexpr, + VECT_MUL: tl.constexpr, + FADD2_REDUCE: tl.constexpr, ): pid = tl.program_id(0) off_hz = tl.program_id(1) @@ -363,6 +524,8 @@ def _attn_fwd( warp_specialize, dtype, SUBTILING, + VECT_MUL, + FADD2_REDUCE, ) @@ -381,7 +544,7 @@ def _attn_fwd_persist( desc_k, desc_v, desc_o, - N_CTX, # + N_CTX: tl.constexpr, # HEAD_DIM: tl.constexpr, # BLOCK_M: tl.constexpr, # BLOCK_N: tl.constexpr, # @@ -391,6 +554,8 @@ def _attn_fwd_persist( OUTER_LOOP: tl.constexpr, dtype: tl.constexpr, SUBTILING: tl.constexpr, + VECT_MUL: tl.constexpr, + FADD2_REDUCE: tl.constexpr, ): n_tile_num = tl.cdiv(N_CTX, BLOCK_M) prog_id = tl.program_id(0) @@ -452,6 +617,8 @@ def _attn_fwd_persist( warp_specialize and not OUTER_LOOP, dtype, SUBTILING, + VECT_MUL, + FADD2_REDUCE, ) tile_idx += num_progs @@ -516,13 +683,13 @@ def grid_debug(META): ctx.grid = grid persistent = baseVariant == "persistent" or baseVariant == "ws_persistent" - if is_blackwell() and warp_specialize: + if WITH_MAXNREG and is_blackwell() and warp_specialize: if HEAD_DIM_K == 128 and ( q.dtype == torch.float16 or q.dtype == torch.bfloat16 ): extra_kern_args["maxnreg"] = 128 else: - extra_kern_args["maxnreg"] = 80 + extra_kern_args["maxnreg"] = 128 if persistent: _attn_fwd_persist[grid_persist]( sm_scale, diff --git a/tritonbench/kernels/blackwell_triton_fused_attention_dp.py b/tritonbench/kernels/blackwell_triton_fused_attention_dp.py index d7f1708b..d4989088 100644 --- a/tritonbench/kernels/blackwell_triton_fused_attention_dp.py +++ b/tritonbench/kernels/blackwell_triton_fused_attention_dp.py @@ -72,7 +72,7 @@ def _attn_fwd_subtile( qk -= m_ij[:, None] else: m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - if VECT_MUL: + if VECT_MUL & 2: qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None]) else: qk = qk * qk_scale - m_ij[:, None] @@ -88,7 +88,7 @@ def _attn_fwd_subtile( if SUBTILING: acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() - if VECT_MUL: + if VECT_MUL & 1: acc0 = _mul_f32x2(acc0, alpha[:, None]) acc1 = _mul_f32x2(acc1, alpha[:, None]) else: @@ -262,12 +262,12 @@ def make_tile_config(BM, BN, occ, subtile, vectmul, add2reduce): for BN in [64, 128] for occ in [1, 2] for subtile in [True] - for vectmul in [False] + for vectmul in [0] for add2reduce in [False] ] else: # Helper to build config with optional minRegAutoWS/maxRegAutoWS - def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce): + def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg): config_kwargs = { "BLOCK_M": BM, "BLOCK_N": BN, @@ -284,19 +284,20 @@ def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce): # Only add minRegAutoWS/maxRegAutoWS if supported (triton/tree/ws-3.5) if HAS_REG_AUTO_WS: extra_kwargs["minRegAutoWS"] = 24 - extra_kwargs["maxRegAutoWS"] = 152 + extra_kwargs["maxRegAutoWS"] = maxreg return triton.Config(config_kwargs, **extra_kwargs) configs = [ - make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce) + make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg) for BM in [256] for BN in [64, 128] for s in NUM_STAGES_OPTIONS for w in [4] for subtile in [True] - for vectmul in [False] + for vectmul in [1] for add2reduce in [False] + for maxreg in [152, 192] ]