diff --git a/tritonbench/kernels/gluon_attention_persistent_forward.py b/tritonbench/kernels/gluon_attention_persistent_forward.py index a03de449..a6476b22 100644 --- a/tritonbench/kernels/gluon_attention_persistent_forward.py +++ b/tritonbench/kernels/gluon_attention_persistent_forward.py @@ -8,6 +8,7 @@ from triton.experimental.gluon import language as gl from triton.experimental.gluon.language.nvidia.blackwell import ( allocate_tensor_memory, + float2, get_tmem_32x32b_reg_layout, mbarrier, tcgen05_commit, @@ -243,9 +244,7 @@ class AttentionConfig: alpha_2d_layout: gl.constexpr num_kv_buffers: gl.constexpr - use_fadd2_reduce: gl.constexpr use_exp2_turnstile: gl.constexpr - use_ffma2_scale_rowmax: gl.constexpr def __init__( self, @@ -290,13 +289,13 @@ def __init__( qk_instr_shape = get_mma_instr_shape(self.qk_shape, gl.float32) o_instr_shape = get_mma_instr_shape(self.o_shape, gl.float32) self.qk_tmem_layout = gl.constexpr( - TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), unpacked=True) + TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1) ) self.o_tmem_layout = gl.constexpr( - TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), unpacked=True) + TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), col_stride=1) ) self.p_tmem_layout = gl.constexpr( - TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), unpacked=False) + TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1) ) self.qk_layout = gl.constexpr( @@ -321,17 +320,13 @@ def __init__( gl.BlockedLayout([1, 1], [32, 1], [self.num_warps, 1], [0, 1]) ) - is_fp16 = dtype.value in [gl.float16, gl.bfloat16] + is_fp16 = self.dtype.value in [gl.float16, gl.bfloat16] if is_fp16: self.num_kv_buffers = gl.constexpr(3 if HEAD_DIM == 128 else 6) else: self.num_kv_buffers = gl.constexpr(4 if HEAD_DIM == 128 else 8) - self.use_fadd2_reduce = gl.constexpr(HEAD_DIM == 64) self.use_exp2_turnstile = gl.constexpr(HEAD_DIM == 64) - self.use_ffma2_scale_rowmax = gl.constexpr( - HEAD_DIM == 128 or is_fp16 == (STAGE == 3) - ) @gluon.jit def get_program(self, pid_m, pid_n): @@ -421,113 +416,6 @@ def get_loop_bounds(self, STAGE: gl.constexpr): return lo, hi -# ===-----------------------------------------------------------------------===# -# float2 -# ===-----------------------------------------------------------------------===# - - -@gluon.jit -def _add_f32x2(a, b): - return gl.inline_asm_elementwise( - """ - { - .reg .b64 ra, rb, rc; - mov.b64 ra, { $2, $3 }; - mov.b64 rb, { $4, $5 }; - add.f32x2 rc, ra, rb; - mov.b64 { $0, $1 }, rc; - } - """, - "=r,=r,r,r,r,r", - [a, b], - dtype=gl.float32, - is_pure=True, - pack=2, - ) - - -@gluon.jit -def _mul_f32x2(a, b): - return gl.inline_asm_elementwise( - """ - { - .reg .b64 ra, rb, rc; - mov.b64 ra, { $2, $3 }; - mov.b64 rb, { $4, $5 }; - mul.f32x2 rc, ra, rb; - mov.b64 { $0, $1 }, rc; - } - """, - "=r,=r,r,r,r,r", - [a, b], - dtype=gl.float32, - is_pure=True, - pack=2, - ) - - -@gluon.jit -def _fma_f32x2(a, b, c): - return gl.inline_asm_elementwise( - """ - { - .reg .b64 ra, rb, rc, rd; - mov.b64 ra, { $2, $3 }; - mov.b64 rb, { $4, $5 }; - mov.b64 rc, { $6, $7 }; - fma.rn.f32x2 rd, ra, rb, rc; - mov.b64 { $0, $1 }, rd; - } - """, - "=r,=r,r,r,r,r,r,r", - [a, b, c], - dtype=gl.float32, - is_pure=True, - pack=2, - ) - - -@gluon.jit -def _reduce_fadd2(p0a, p1a, p0b, p1b): - return gl.inline_asm_elementwise( - """ - { - .reg .b64 rc, ra, rb; - mov.b64 ra, { $2, $4 }; - mov.b64 rb, { $3, $5 }; - add.f32x2 rc, ra, rb; - mov.b64 { $0, $1 }, rc; - } - """, 
- "=r,=r,r,r,r,r", - [p0a, p0b, p1a, p1b], - dtype=[gl.float32, gl.float32], - is_pure=True, - pack=1, - ) - - -@gluon.jit -def _pairwise_fma_f32x2(a0, b0, c0, a1, b1, c1): - return gl.inline_asm_elementwise( - """ - { - .reg .b64 rd, ra, rb, rc; - mov.b64 ra, { $2, $5 }; - mov.b64 rb, { $3, $6 }; - mov.b64 rc, { $4, $7 }; - fma.rn.f32x2 rd, ra, rb, rc; - mov.b64 { $0, $1 }, rd; - } - """, - "=r,=r,r,r,r,r,r,r", - [a0, b0, c0, a1, b1, c1], - dtype=[gl.float32, gl.float32], - is_pure=True, - pack=1, - ) - - # ===-----------------------------------------------------------------------===# # _gluon_attn # ===-----------------------------------------------------------------------===# @@ -542,7 +430,7 @@ def _borrow_s_as_p(config, s_tmem): @gluon.jit def _borrow_s_as_alpha(config, s_tmem): alpha_tmem = s_tmem.slice(config.BLOCK_N // 2, 1) - alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True) + alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], col_stride=1) return alpha_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], alpha_layout) @@ -550,7 +438,7 @@ def _borrow_s_as_alpha(config, s_tmem): def _borrow_s_for_epilogue(config, s_tmem): m_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 1, 1) l_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 2, 1) - layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True) + layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], col_stride=1) m_i_tmem = m_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout) l_i_tmem = l_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout) return m_i_tmem, l_i_tmem @@ -798,8 +686,7 @@ def _softmax_inner_loop( corr_bar, # offs_m, m_i, - l_i0, - l_i1, + l_i, STAGE: gl.constexpr, ): lo, hi = prog.get_loop_bounds(STAGE) @@ -821,11 +708,10 @@ def _softmax_inner_loop( ) mbarrier.arrive(corr_bar, count=1) - if config.use_ffma2_scale_rowmax: - qk = _fma_f32x2(qk, gl.full_like(qk, config.qk_scale), -m_ij[:, None]) - else: - qk = _mul_f32x2(qk, gl.full_like(qk, config.qk_scale)) - qk = _add_f32x2(qk, -m_ij[:, None]) + rowmax = float2.pack(-m_ij[:, None].broadcast_to(qk.shape), axis=1) + qk = float2.pack(qk, axis=1) + qk = float2.fma(qk, float2.full_like(qk, config.qk_scale), rowmax) + qk = float2.unpack(qk, axis=1) # Force the softmax partitions to take turns in the EX2 section. This # prevents contention for the EX2 unit and improves utilization. @@ -844,24 +730,12 @@ def _softmax_inner_loop( if config.use_exp2_turnstile: mbarrier.arrive(exp_bar, count=1) - if config.use_fadd2_reduce: - p0, p1 = _split_n(p) - l_ij0, l_ij1 = gl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2) - # This is a difference of 1 SASS instruction but it dramatically - # affects instruction scheduling. 
- alpha = gl.convert_layout(alpha, l_i0.type.layout, assert_trivial=True) - if config.dtype == gl.float8e5: - l_i0, l_i1 = _pairwise_fma_f32x2(l_i0, alpha, l_ij0, l_i1, alpha, l_ij1) - else: - l_i0 = l_i0 * alpha + l_ij0 - l_i1 = l_i1 * alpha + l_ij1 - else: - l_ij = gl.sum(p, axis=1) - l_i0 = l_i0 * alpha + l_ij - + l_ij = float2.pack2(*_split_n(p)).sum(axis=1) + alpha = gl.convert_layout(alpha, l_i.value.type.layout, assert_trivial=True) + l_i = float2.fma(l_i, float2.pack2(alpha, alpha), l_ij) m_i = m_ij - return m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile + return m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile @gluon.jit @@ -876,11 +750,7 @@ def _softmax_tile( exp_turnstile, ): qk_slice_dim1: gl.constexpr = gl.SliceLayout(1, config.qk_layout) - sum_layout: gl.constexpr = ( - _get_split_n_layout(config.qk_layout) - if config.use_fadd2_reduce - else config.qk_layout - ) + sum_layout: gl.constexpr = _get_split_n_layout(config.qk_layout) s_consumer = s_chnl.create_consumer() corr_producer = corr_chnl.create_producer() @@ -894,17 +764,12 @@ def _softmax_tile( offs_m += gl.arange(tile_id * config.SPLIT_M, (1 + tile_id) * config.SPLIT_M) m_i = gl.full([config.SPLIT_M], -float("inf"), gl.float32, qk_slice_dim1) - l_i0 = gl.full([config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout)) # Accumulate into 2 row-sums so the reduction can be performed with FADD2. - if config.use_fadd2_reduce: - l_i1 = gl.full( - [config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout) - ) - else: - l_i1 = 0 + l_i = gl.full([config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout)) + l_i = float2.pack2(l_i, l_i) if STAGE & 1: - m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = ( + m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile = ( _softmax_inner_loop( # tile_id, config, @@ -915,13 +780,12 @@ def _softmax_tile( corr_bar, # offs_m, m_i, - l_i0, - l_i1, + l_i, STAGE=4 - STAGE, ) ) if STAGE & 2: - m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = ( + m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile = ( _softmax_inner_loop( # tile_id, config, @@ -932,16 +796,12 @@ def _softmax_tile( corr_bar, # offs_m, m_i, - l_i0, - l_i1, + l_i, STAGE=2, ) ) - - if config.use_fadd2_reduce: - l_i = l_i0 + l_i1 - else: - l_i = l_i0 + l_i0, l_i1 = float2.unpack2(l_i) + l_i = l_i0 + l_i1 s_tmem, s_bar, s_consumer = s_consumer.acquire() m_i_tmem, l_i_tmem = _borrow_s_for_epilogue(config, s_tmem) @@ -1039,11 +899,14 @@ def _attn_fwd_correction_rescale(config, s_tmem, corr_consumer, o_consumer): mbarrier.arrive(corr_bar, count=1) alpha = gl.convert_layout(alpha.reshape([config.SPLIT_M]), alpha_layout) + alpha = float2.pack( + alpha[:, None].broadcast_to(config.o_shape[0], config.SPLIT_D), axis=1 + ) for i in gl.static_range(config.SPLIT_D_FACTOR): o_ref = o_tmem.slice(i * config.SPLIT_D, config.SPLIT_D) - o = o_ref.load(config.o_splitn_layout) - o = _mul_f32x2(o, alpha[:, None]) - o_ref.store(o) + o = float2.pack(o_ref.load(config.o_splitn_layout), axis=1) + o = o * alpha + o_ref.store(float2.unpack(o, axis=1)) mbarrier.arrive(o_bar, count=1) return corr_consumer, o_consumer @@ -1081,12 +944,16 @@ def _attn_fwd_correction_epilogue( ) SPLIT_N: gl.constexpr = o_smem.type.shape[1] // SPLIT_N_FACTOR - scale = 1 / l_i + scale = float2.pack( + (1 / l_i)[:, None].broadcast_to(config.o_shape[0], SPLIT_N), axis=1 + ) for i in gl.static_range(SPLIT_N_FACTOR): o_ref = o_tmem.slice(i * SPLIT_N, SPLIT_N) - o = 
o_ref.load(config.o_splitn_layout) - o = _mul_f32x2(o, scale[:, None]) - o_smem.slice(i * SPLIT_N, SPLIT_N, dim=1).store(o.to(config.dtype)) + o = float2.pack(o_ref.load(config.o_splitn_layout), axis=1) + o = o * scale + o_smem.slice(i * SPLIT_N, SPLIT_N, dim=1).store( + float2.unpack(o, axis=1).to(config.dtype) + ) fence_async_shared() mbarrier.arrive(epi_bar, count=1)
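
For reviewers unfamiliar with the new API: the patch replaces the hand-rolled `add.f32x2` / `mul.f32x2` / `fma.rn.f32x2` inline-PTX helpers with the `float2` packed-pair helpers imported from `triton.experimental.gluon.language.nvidia.blackwell`, and drops the `use_fadd2_reduce` / `use_ffma2_scale_rowmax` switches because the packed path is now taken unconditionally. The sketch below is a minimal illustration of the idiom, using only the calls that appear in this diff (`pack`/`unpack` along the contiguous axis, `pack2`/`unpack2` for a pair of independent accumulators, `fma`, `full_like`, `.sum`); the helper names `_scale_sub_rowmax` and `_paired_rowsum` are hypothetical and exist only for this example.

```python
# Illustrative sketch of the float2 idiom adopted by this patch; helper names
# are hypothetical. It mirrors the calls used in _softmax_inner_loop above.
from triton.experimental import gluon
from triton.experimental.gluon.language.nvidia.blackwell import float2


@gluon.jit
def _scale_sub_rowmax(qk, qk_scale, m_ij):
    # Pack adjacent f32 elements along the contiguous axis into f32x2 pairs,
    # fuse the scale-and-subtract into one packed FMA, then unpack.
    rowmax = float2.pack(-m_ij[:, None].broadcast_to(qk.shape), axis=1)
    qk = float2.pack(qk, axis=1)
    qk = float2.fma(qk, float2.full_like(qk, qk_scale), rowmax)
    return float2.unpack(qk, axis=1)


@gluon.jit
def _paired_rowsum(l_i, alpha, p0, p1):
    # Keep two running row-sums as one packed value so both the reduction and
    # the alpha rescale map onto paired f32x2 instructions. The real kernel
    # also convert_layout's alpha to match l_i's layout first, and the caller
    # splits the final pair with float2.unpack2 and adds the halves, as in
    # _softmax_tile above.
    l_ij = float2.pack2(p0, p1).sum(axis=1)
    return float2.fma(l_i, float2.pack2(alpha, alpha), l_ij)
```

Compared with the removed `_add_f32x2`-style wrappers, the pairing decision now lives in the language rather than in inline PTX, which is why the per-head-dim heuristics (`use_fadd2_reduce`, `use_ffma2_scale_rowmax`) are no longer needed.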