211 changes: 39 additions & 172 deletions tritonbench/kernels/gluon_attention_persistent_forward.py
@@ -8,6 +8,7 @@
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.language.nvidia.blackwell import (
allocate_tensor_memory,
float2,
get_tmem_32x32b_reg_layout,
mbarrier,
tcgen05_commit,
@@ -243,9 +244,7 @@ class AttentionConfig:
alpha_2d_layout: gl.constexpr

num_kv_buffers: gl.constexpr
use_fadd2_reduce: gl.constexpr
use_exp2_turnstile: gl.constexpr
use_ffma2_scale_rowmax: gl.constexpr

def __init__(
self,
@@ -290,13 +289,13 @@ def __init__(
qk_instr_shape = get_mma_instr_shape(self.qk_shape, gl.float32)
o_instr_shape = get_mma_instr_shape(self.o_shape, gl.float32)
self.qk_tmem_layout = gl.constexpr(
TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), unpacked=True)
TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1)
)
self.o_tmem_layout = gl.constexpr(
TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), unpacked=True)
TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), col_stride=1)
)
self.p_tmem_layout = gl.constexpr(
TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), unpacked=False)
TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1)
)

self.qk_layout = gl.constexpr(
@@ -321,17 +320,13 @@ def __init__(
gl.BlockedLayout([1, 1], [32, 1], [self.num_warps, 1], [0, 1])
)

is_fp16 = dtype.value in [gl.float16, gl.bfloat16]
is_fp16 = self.dtype.value in [gl.float16, gl.bfloat16]
if is_fp16:
self.num_kv_buffers = gl.constexpr(3 if HEAD_DIM == 128 else 6)
else:
self.num_kv_buffers = gl.constexpr(4 if HEAD_DIM == 128 else 8)

self.use_fadd2_reduce = gl.constexpr(HEAD_DIM == 64)
self.use_exp2_turnstile = gl.constexpr(HEAD_DIM == 64)
self.use_ffma2_scale_rowmax = gl.constexpr(
HEAD_DIM == 128 or is_fp16 == (STAGE == 3)
)

@gluon.jit
def get_program(self, pid_m, pid_n):
@@ -421,113 +416,6 @@ def get_loop_bounds(self, STAGE: gl.constexpr):
return lo, hi


# ===-----------------------------------------------------------------------===#
# float2
# ===-----------------------------------------------------------------------===#


@gluon.jit
def _add_f32x2(a, b):
return gl.inline_asm_elementwise(
"""
{
.reg .b64 ra, rb, rc;
mov.b64 ra, { $2, $3 };
mov.b64 rb, { $4, $5 };
add.f32x2 rc, ra, rb;
mov.b64 { $0, $1 }, rc;
}
""",
"=r,=r,r,r,r,r",
[a, b],
dtype=gl.float32,
is_pure=True,
pack=2,
)


@gluon.jit
def _mul_f32x2(a, b):
return gl.inline_asm_elementwise(
"""
{
.reg .b64 ra, rb, rc;
mov.b64 ra, { $2, $3 };
mov.b64 rb, { $4, $5 };
mul.f32x2 rc, ra, rb;
mov.b64 { $0, $1 }, rc;
}
""",
"=r,=r,r,r,r,r",
[a, b],
dtype=gl.float32,
is_pure=True,
pack=2,
)


@gluon.jit
def _fma_f32x2(a, b, c):
return gl.inline_asm_elementwise(
"""
{
.reg .b64 ra, rb, rc, rd;
mov.b64 ra, { $2, $3 };
mov.b64 rb, { $4, $5 };
mov.b64 rc, { $6, $7 };
fma.rn.f32x2 rd, ra, rb, rc;
mov.b64 { $0, $1 }, rd;
}
""",
"=r,=r,r,r,r,r,r,r",
[a, b, c],
dtype=gl.float32,
is_pure=True,
pack=2,
)


@gluon.jit
def _reduce_fadd2(p0a, p1a, p0b, p1b):
return gl.inline_asm_elementwise(
"""
{
.reg .b64 rc, ra, rb;
mov.b64 ra, { $2, $4 };
mov.b64 rb, { $3, $5 };
add.f32x2 rc, ra, rb;
mov.b64 { $0, $1 }, rc;
}
""",
"=r,=r,r,r,r,r",
[p0a, p0b, p1a, p1b],
dtype=[gl.float32, gl.float32],
is_pure=True,
pack=1,
)


@gluon.jit
def _pairwise_fma_f32x2(a0, b0, c0, a1, b1, c1):
return gl.inline_asm_elementwise(
"""
{
.reg .b64 rd, ra, rb, rc;
mov.b64 ra, { $2, $5 };
mov.b64 rb, { $3, $6 };
mov.b64 rc, { $4, $7 };
fma.rn.f32x2 rd, ra, rb, rc;
mov.b64 { $0, $1 }, rd;
}
""",
"=r,=r,r,r,r,r,r,r",
[a0, b0, c0, a1, b1, c1],
dtype=[gl.float32, gl.float32],
is_pure=True,
pack=1,
)
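
For orientation only, not part of this PR: the helpers deleted above wrapped the PTX packed instructions add.f32x2, mul.f32x2 and fma.rn.f32x2 through gl.inline_asm_elementwise; the float2 module imported at the top of the diff appears to supply the same packed-pair math (pack/unpack, pack2/unpack2, fma, full_like) without inline assembly. A minimal NumPy sketch of the reference semantics, stated only in terms of what those PTX instructions compute:

import numpy as np

def fma_f32x2_ref(a, b, c):
    # Pairwise fused multiply-add over fp32 pairs. The pairing changes how many
    # lanes one instruction issues, not the values; elementwise the result is
    # a * b + c. (NumPy rounds the multiply and add separately, so the last bit
    # can differ from a true fused FMA.)
    a, b, c = (np.asarray(x, dtype=np.float32) for x in (a, b, c))
    assert a.shape[-1] % 2 == 0, "operands are consumed as adjacent fp32 pairs"
    return a * b + c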


# ===-----------------------------------------------------------------------===#
# _gluon_attn
# ===-----------------------------------------------------------------------===#
@@ -542,15 +430,15 @@ def _borrow_s_as_p(config, s_tmem):
@gluon.jit
def _borrow_s_as_alpha(config, s_tmem):
alpha_tmem = s_tmem.slice(config.BLOCK_N // 2, 1)
alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True)
alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], col_stride=1)
return alpha_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], alpha_layout)


@gluon.jit
def _borrow_s_for_epilogue(config, s_tmem):
m_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 1, 1)
l_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 2, 1)
layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True)
layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], col_stride=1)
m_i_tmem = m_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
l_i_tmem = l_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
return m_i_tmem, l_i_tmem
@@ -798,8 +686,7 @@ def _softmax_inner_loop(
corr_bar, #
offs_m,
m_i,
l_i0,
l_i1,
l_i,
STAGE: gl.constexpr,
):
lo, hi = prog.get_loop_bounds(STAGE)
@@ -821,11 +708,10 @@
)
mbarrier.arrive(corr_bar, count=1)

if config.use_ffma2_scale_rowmax:
qk = _fma_f32x2(qk, gl.full_like(qk, config.qk_scale), -m_ij[:, None])
else:
qk = _mul_f32x2(qk, gl.full_like(qk, config.qk_scale))
qk = _add_f32x2(qk, -m_ij[:, None])
rowmax = float2.pack(-m_ij[:, None].broadcast_to(qk.shape), axis=1)
qk = float2.pack(qk, axis=1)
qk = float2.fma(qk, float2.full_like(qk, config.qk_scale), rowmax)
qk = float2.unpack(qk, axis=1)

# Force the softmax partitions to take turns in the EX2 section. This
# prevents contention for the EX2 unit and improves utilization.
@@ -844,24 +730,12 @@
if config.use_exp2_turnstile:
mbarrier.arrive(exp_bar, count=1)

if config.use_fadd2_reduce:
p0, p1 = _split_n(p)
l_ij0, l_ij1 = gl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2)
# This is a difference of 1 SASS instruction but it dramatically
# affects instruction scheduling.
alpha = gl.convert_layout(alpha, l_i0.type.layout, assert_trivial=True)
if config.dtype == gl.float8e5:
l_i0, l_i1 = _pairwise_fma_f32x2(l_i0, alpha, l_ij0, l_i1, alpha, l_ij1)
else:
l_i0 = l_i0 * alpha + l_ij0
l_i1 = l_i1 * alpha + l_ij1
else:
l_ij = gl.sum(p, axis=1)
l_i0 = l_i0 * alpha + l_ij

l_ij = float2.pack2(*_split_n(p)).sum(axis=1)
alpha = gl.convert_layout(alpha, l_i.value.type.layout, assert_trivial=True)
l_i = float2.fma(l_i, float2.pack2(alpha, alpha), l_ij)
m_i = m_ij

return m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile
return m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile


@gluon.jit
@@ -876,11 +750,7 @@ def _softmax_tile(
exp_turnstile,
):
qk_slice_dim1: gl.constexpr = gl.SliceLayout(1, config.qk_layout)
sum_layout: gl.constexpr = (
_get_split_n_layout(config.qk_layout)
if config.use_fadd2_reduce
else config.qk_layout
)
sum_layout: gl.constexpr = _get_split_n_layout(config.qk_layout)

s_consumer = s_chnl.create_consumer()
corr_producer = corr_chnl.create_producer()
@@ -894,17 +764,12 @@
offs_m += gl.arange(tile_id * config.SPLIT_M, (1 + tile_id) * config.SPLIT_M)

m_i = gl.full([config.SPLIT_M], -float("inf"), gl.float32, qk_slice_dim1)
l_i0 = gl.full([config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout))
# Accumulate into 2 row-sums so the reduction can be performed with FADD2.
if config.use_fadd2_reduce:
l_i1 = gl.full(
[config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout)
)
else:
l_i1 = 0
l_i = gl.full([config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout))
l_i = float2.pack2(l_i, l_i)

if STAGE & 1:
m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = (
m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile = (
_softmax_inner_loop( #
tile_id,
config,
@@ -915,13 +780,12 @@
corr_bar, #
offs_m,
m_i,
l_i0,
l_i1,
l_i,
STAGE=4 - STAGE,
)
)
if STAGE & 2:
m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = (
m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile = (
_softmax_inner_loop( #
tile_id,
config,
@@ -932,16 +796,12 @@
corr_bar, #
offs_m,
m_i,
l_i0,
l_i1,
l_i,
STAGE=2,
)
)

if config.use_fadd2_reduce:
l_i = l_i0 + l_i1
else:
l_i = l_i0
l_i0, l_i1 = float2.unpack2(l_i)
l_i = l_i0 + l_i1

s_tmem, s_bar, s_consumer = s_consumer.acquire()
m_i_tmem, l_i_tmem = _borrow_s_for_epilogue(config, s_tmem)
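
For readers following the numerics rather than the Gluon plumbing, this is standard online-softmax bookkeeping: m_i is the running row max, l_i the running row sum (kept as a packed pair and collapsed once after the loop, as in l_i = l_i0 + l_i1 above), and alpha the correction factor applied whenever the max grows. A NumPy sketch of one KV-block update, assuming the usual base-2 formulation that the EX2 turnstile comments imply:

import numpy as np

def softmax_stats_step(qk, m_i, l_i, qk_scale):
    # One block update of the running statistics maintained by _softmax_inner_loop.
    m_ij = np.maximum(m_i, (qk * qk_scale).max(axis=1))   # new running row max
    p = np.exp2(qk * qk_scale - m_ij[:, None])            # fma(qk, qk_scale, -m_ij), then exp2
    alpha = np.exp2(m_i - m_ij)                           # rescale factor for the old state
    l_i = l_i * alpha + p.sum(axis=1)                     # fma(l_i, alpha, l_ij)
    return m_ij, l_i, alpha, p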
@@ -1039,11 +899,14 @@ def _attn_fwd_correction_rescale(config, s_tmem, corr_consumer, o_consumer):
mbarrier.arrive(corr_bar, count=1)
alpha = gl.convert_layout(alpha.reshape([config.SPLIT_M]), alpha_layout)

alpha = float2.pack(
alpha[:, None].broadcast_to(config.o_shape[0], config.SPLIT_D), axis=1
)
for i in gl.static_range(config.SPLIT_D_FACTOR):
o_ref = o_tmem.slice(i * config.SPLIT_D, config.SPLIT_D)
o = o_ref.load(config.o_splitn_layout)
o = _mul_f32x2(o, alpha[:, None])
o_ref.store(o)
o = float2.pack(o_ref.load(config.o_splitn_layout), axis=1)
o = o * alpha
o_ref.store(float2.unpack(o, axis=1))
mbarrier.arrive(o_bar, count=1)
return corr_consumer, o_consumer

@@ -1081,12 +944,16 @@ def _attn_fwd_correction_epilogue(
)
SPLIT_N: gl.constexpr = o_smem.type.shape[1] // SPLIT_N_FACTOR

scale = 1 / l_i
scale = float2.pack(
(1 / l_i)[:, None].broadcast_to(config.o_shape[0], SPLIT_N), axis=1
)
for i in gl.static_range(SPLIT_N_FACTOR):
o_ref = o_tmem.slice(i * SPLIT_N, SPLIT_N)
o = o_ref.load(config.o_splitn_layout)
o = _mul_f32x2(o, scale[:, None])
o_smem.slice(i * SPLIT_N, SPLIT_N, dim=1).store(o.to(config.dtype))
o = float2.pack(o_ref.load(config.o_splitn_layout), axis=1)
o = o * scale
o_smem.slice(i * SPLIT_N, SPLIT_N, dim=1).store(
float2.unpack(o, axis=1).to(config.dtype)
)

fence_async_shared()
mbarrier.arrive(epi_bar, count=1)
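
Again for orientation only: ignoring the SPLIT_D/SPLIT_N tiling and the tensor-memory plumbing, the two correction stages above each reduce to a per-row multiply. A NumPy sketch:

import numpy as np

def rescale_partial_output(o, alpha):
    # _attn_fwd_correction_rescale: rescale the partial accumulator by alpha
    # whenever the running row max changes.
    return o * alpha[:, None]

def normalize_output(o, l_i):
    # _attn_fwd_correction_epilogue: divide each output row by its softmax
    # denominator, done above as a packed multiply by scale = 1 / l_i.
    return o * (1.0 / l_i)[:, None]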