Commit 324e0fa

Update gluon_attention_persistent_forward
The implementation has changed in upstream Triton; update it to match.
1 parent fe646da commit 324e0fa
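
In short, the hand-written f32x2 inline-PTX helpers (_add_f32x2, _mul_f32x2, _fma_f32x2, _reduce_fadd2, _pairwise_fma_f32x2) are deleted, and the kernel now uses the float2 helpers imported from triton.experimental.gluon.language.nvidia.blackwell. Below is a minimal sketch of the pack/compute/unpack pattern the new code follows, based only on how float2 is used in this diff; the helper name _scale_rows is made up for illustration, and actually running it requires a Blackwell GPU.

from triton.experimental import gluon
from triton.experimental.gluon.language.nvidia.blackwell import float2


@gluon.jit
def _scale_rows(qk, qk_scale, neg_rowmax):
    # neg_rowmax is assumed to be pre-broadcast to qk's shape.
    # Pack adjacent f32 values along axis 1 into f32x2 pairs, apply one fused
    # multiply-add per pair, then unpack back into ordinary f32 elements.
    qk = float2.pack(qk, axis=1)
    neg_rowmax = float2.pack(neg_rowmax, axis=1)
    qk = float2.fma(qk, float2.full_like(qk, qk_scale), neg_rowmax)
    return float2.unpack(qk, axis=1)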

File tree

1 file changed: +40 −172 lines

tritonbench/kernels/gluon_attention_persistent_forward.py

Lines changed: 40 additions & 172 deletions
@@ -8,6 +8,7 @@
 from triton.experimental.gluon import language as gl
 from triton.experimental.gluon.language.nvidia.blackwell import (
     allocate_tensor_memory,
+    float2,
     get_tmem_32x32b_reg_layout,
     mbarrier,
     tcgen05_commit,
@@ -69,6 +70,7 @@ def increment(self):


 def Channel(T, alloc_fn):
+
     @aggregate
     class ChannelType:
         mem: T
@@ -243,9 +245,7 @@ class AttentionConfig:
     alpha_2d_layout: gl.constexpr

     num_kv_buffers: gl.constexpr
-    use_fadd2_reduce: gl.constexpr
     use_exp2_turnstile: gl.constexpr
-    use_ffma2_scale_rowmax: gl.constexpr

     def __init__(
         self,
@@ -290,13 +290,13 @@ def __init__(
         qk_instr_shape = get_mma_instr_shape(self.qk_shape, gl.float32)
         o_instr_shape = get_mma_instr_shape(self.o_shape, gl.float32)
         self.qk_tmem_layout = gl.constexpr(
-            TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), unpacked=True)
+            TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1)
         )
         self.o_tmem_layout = gl.constexpr(
-            TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), unpacked=True)
+            TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), col_stride=1)
         )
         self.p_tmem_layout = gl.constexpr(
-            TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), unpacked=False)
+            TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1)
         )

         self.qk_layout = gl.constexpr(
@@ -321,17 +321,13 @@ def __init__(
             gl.BlockedLayout([1, 1], [32, 1], [self.num_warps, 1], [0, 1])
         )

-        is_fp16 = dtype.value in [gl.float16, gl.bfloat16]
+        is_fp16 = self.dtype.value in [gl.float16, gl.bfloat16]
         if is_fp16:
             self.num_kv_buffers = gl.constexpr(3 if HEAD_DIM == 128 else 6)
         else:
             self.num_kv_buffers = gl.constexpr(4 if HEAD_DIM == 128 else 8)

-        self.use_fadd2_reduce = gl.constexpr(HEAD_DIM == 64)
         self.use_exp2_turnstile = gl.constexpr(HEAD_DIM == 64)
-        self.use_ffma2_scale_rowmax = gl.constexpr(
-            HEAD_DIM == 128 or is_fp16 == (STAGE == 3)
-        )

     @gluon.jit
     def get_program(self, pid_m, pid_n):
@@ -421,113 +417,6 @@ def get_loop_bounds(self, STAGE: gl.constexpr):
         return lo, hi


-# ===-----------------------------------------------------------------------===#
-# float2
-# ===-----------------------------------------------------------------------===#
-
-
-@gluon.jit
-def _add_f32x2(a, b):
-    return gl.inline_asm_elementwise(
-        """
-        {
-            .reg .b64 ra, rb, rc;
-            mov.b64 ra, { $2, $3 };
-            mov.b64 rb, { $4, $5 };
-            add.f32x2 rc, ra, rb;
-            mov.b64 { $0, $1 }, rc;
-        }
-        """,
-        "=r,=r,r,r,r,r",
-        [a, b],
-        dtype=gl.float32,
-        is_pure=True,
-        pack=2,
-    )
-
-
-@gluon.jit
-def _mul_f32x2(a, b):
-    return gl.inline_asm_elementwise(
-        """
-        {
-            .reg .b64 ra, rb, rc;
-            mov.b64 ra, { $2, $3 };
-            mov.b64 rb, { $4, $5 };
-            mul.f32x2 rc, ra, rb;
-            mov.b64 { $0, $1 }, rc;
-        }
-        """,
-        "=r,=r,r,r,r,r",
-        [a, b],
-        dtype=gl.float32,
-        is_pure=True,
-        pack=2,
-    )
-
-
-@gluon.jit
-def _fma_f32x2(a, b, c):
-    return gl.inline_asm_elementwise(
-        """
-        {
-            .reg .b64 ra, rb, rc, rd;
-            mov.b64 ra, { $2, $3 };
-            mov.b64 rb, { $4, $5 };
-            mov.b64 rc, { $6, $7 };
-            fma.rn.f32x2 rd, ra, rb, rc;
-            mov.b64 { $0, $1 }, rd;
-        }
-        """,
-        "=r,=r,r,r,r,r,r,r",
-        [a, b, c],
-        dtype=gl.float32,
-        is_pure=True,
-        pack=2,
-    )
-
-
-@gluon.jit
-def _reduce_fadd2(p0a, p1a, p0b, p1b):
-    return gl.inline_asm_elementwise(
-        """
-        {
-            .reg .b64 rc, ra, rb;
-            mov.b64 ra, { $2, $4 };
-            mov.b64 rb, { $3, $5 };
-            add.f32x2 rc, ra, rb;
-            mov.b64 { $0, $1 }, rc;
-        }
-        """,
-        "=r,=r,r,r,r,r",
-        [p0a, p0b, p1a, p1b],
-        dtype=[gl.float32, gl.float32],
-        is_pure=True,
-        pack=1,
-    )
-
-
-@gluon.jit
-def _pairwise_fma_f32x2(a0, b0, c0, a1, b1, c1):
-    return gl.inline_asm_elementwise(
-        """
-        {
-            .reg .b64 rd, ra, rb, rc;
-            mov.b64 ra, { $2, $5 };
-            mov.b64 rb, { $3, $6 };
-            mov.b64 rc, { $4, $7 };
-            fma.rn.f32x2 rd, ra, rb, rc;
-            mov.b64 { $0, $1 }, rd;
-        }
-        """,
-        "=r,=r,r,r,r,r,r,r",
-        [a0, b0, c0, a1, b1, c1],
-        dtype=[gl.float32, gl.float32],
-        is_pure=True,
-        pack=1,
-    )
-
-
 # ===-----------------------------------------------------------------------===#
 # _gluon_attn
 # ===-----------------------------------------------------------------------===#
@@ -542,15 +431,15 @@ def _borrow_s_as_p(config, s_tmem):
 @gluon.jit
 def _borrow_s_as_alpha(config, s_tmem):
     alpha_tmem = s_tmem.slice(config.BLOCK_N // 2, 1)
-    alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True)
+    alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], col_stride=1)
     return alpha_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], alpha_layout)


 @gluon.jit
 def _borrow_s_for_epilogue(config, s_tmem):
     m_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 1, 1)
     l_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 2, 1)
-    layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True)
+    layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], col_stride=1)
     m_i_tmem = m_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
     l_i_tmem = l_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
     return m_i_tmem, l_i_tmem
@@ -798,8 +687,7 @@ def _softmax_inner_loop(
     corr_bar,  #
     offs_m,
     m_i,
-    l_i0,
-    l_i1,
+    l_i,
     STAGE: gl.constexpr,
 ):
     lo, hi = prog.get_loop_bounds(STAGE)
@@ -821,11 +709,10 @@
         )
         mbarrier.arrive(corr_bar, count=1)

-        if config.use_ffma2_scale_rowmax:
-            qk = _fma_f32x2(qk, gl.full_like(qk, config.qk_scale), -m_ij[:, None])
-        else:
-            qk = _mul_f32x2(qk, gl.full_like(qk, config.qk_scale))
-            qk = _add_f32x2(qk, -m_ij[:, None])
+        rowmax = float2.pack(-m_ij[:, None].broadcast_to(qk.shape), axis=1)
+        qk = float2.pack(qk, axis=1)
+        qk = float2.fma(qk, float2.full_like(qk, config.qk_scale), rowmax)
+        qk = float2.unpack(qk, axis=1)

         # Force the softmax partitions to take turns in the EX2 section. This
         # prevents contention for the EX2 unit and improves utilization.
@@ -844,24 +731,12 @@
         if config.use_exp2_turnstile:
             mbarrier.arrive(exp_bar, count=1)

-        if config.use_fadd2_reduce:
-            p0, p1 = _split_n(p)
-            l_ij0, l_ij1 = gl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2)
-            # This is a difference of 1 SASS instruction but it dramatically
-            # affects instruction scheduling.
-            alpha = gl.convert_layout(alpha, l_i0.type.layout, assert_trivial=True)
-            if config.dtype == gl.float8e5:
-                l_i0, l_i1 = _pairwise_fma_f32x2(l_i0, alpha, l_ij0, l_i1, alpha, l_ij1)
-            else:
-                l_i0 = l_i0 * alpha + l_ij0
-                l_i1 = l_i1 * alpha + l_ij1
-        else:
-            l_ij = gl.sum(p, axis=1)
-            l_i0 = l_i0 * alpha + l_ij
-
+        l_ij = float2.pack2(*_split_n(p)).sum(axis=1)
+        alpha = gl.convert_layout(alpha, l_i.value.type.layout, assert_trivial=True)
+        l_i = float2.fma(l_i, float2.pack2(alpha, alpha), l_ij)
         m_i = m_ij

-    return m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile
+    return m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile


 @gluon.jit
@@ -876,11 +751,7 @@ def _softmax_tile(
     exp_turnstile,
 ):
     qk_slice_dim1: gl.constexpr = gl.SliceLayout(1, config.qk_layout)
-    sum_layout: gl.constexpr = (
-        _get_split_n_layout(config.qk_layout)
-        if config.use_fadd2_reduce
-        else config.qk_layout
-    )
+    sum_layout: gl.constexpr = _get_split_n_layout(config.qk_layout)

     s_consumer = s_chnl.create_consumer()
     corr_producer = corr_chnl.create_producer()
@@ -894,17 +765,12 @@
     offs_m += gl.arange(tile_id * config.SPLIT_M, (1 + tile_id) * config.SPLIT_M)

     m_i = gl.full([config.SPLIT_M], -float("inf"), gl.float32, qk_slice_dim1)
-    l_i0 = gl.full([config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout))
     # Accumulate into 2 row-sums so the reduction can be performed with FADD2.
-    if config.use_fadd2_reduce:
-        l_i1 = gl.full(
-            [config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout)
-        )
-    else:
-        l_i1 = 0
+    l_i = gl.full([config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout))
+    l_i = float2.pack2(l_i, l_i)

     if STAGE & 1:
-        m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = (
+        m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile = (
             _softmax_inner_loop(  #
                 tile_id,
                 config,
@@ -915,13 +781,12 @@
                 corr_bar,  #
                 offs_m,
                 m_i,
-                l_i0,
-                l_i1,
+                l_i,
                 STAGE=4 - STAGE,
             )
         )
     if STAGE & 2:
-        m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = (
+        m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile = (
             _softmax_inner_loop(  #
                 tile_id,
                 config,
@@ -932,16 +797,12 @@
                 corr_bar,  #
                 offs_m,
                 m_i,
-                l_i0,
-                l_i1,
+                l_i,
                 STAGE=2,
             )
         )
-
-    if config.use_fadd2_reduce:
-        l_i = l_i0 + l_i1
-    else:
-        l_i = l_i0
+    l_i0, l_i1 = float2.unpack2(l_i)
+    l_i = l_i0 + l_i1

     s_tmem, s_bar, s_consumer = s_consumer.acquire()
     m_i_tmem, l_i_tmem = _borrow_s_for_epilogue(config, s_tmem)
@@ -1039,11 +900,14 @@ def _attn_fwd_correction_rescale(config, s_tmem, corr_consumer, o_consumer):
     mbarrier.arrive(corr_bar, count=1)
     alpha = gl.convert_layout(alpha.reshape([config.SPLIT_M]), alpha_layout)

+    alpha = float2.pack(
+        alpha[:, None].broadcast_to(config.o_shape[0], config.SPLIT_D), axis=1
+    )
     for i in gl.static_range(config.SPLIT_D_FACTOR):
         o_ref = o_tmem.slice(i * config.SPLIT_D, config.SPLIT_D)
-        o = o_ref.load(config.o_splitn_layout)
-        o = _mul_f32x2(o, alpha[:, None])
-        o_ref.store(o)
+        o = float2.pack(o_ref.load(config.o_splitn_layout), axis=1)
+        o = o * alpha
+        o_ref.store(float2.unpack(o, axis=1))
     mbarrier.arrive(o_bar, count=1)
     return corr_consumer, o_consumer

@@ -1081,12 +945,16 @@ def _attn_fwd_correction_epilogue(
     )
     SPLIT_N: gl.constexpr = o_smem.type.shape[1] // SPLIT_N_FACTOR

-    scale = 1 / l_i
+    scale = float2.pack(
+        (1 / l_i)[:, None].broadcast_to(config.o_shape[0], SPLIT_N), axis=1
+    )
     for i in gl.static_range(SPLIT_N_FACTOR):
         o_ref = o_tmem.slice(i * SPLIT_N, SPLIT_N)
-        o = o_ref.load(config.o_splitn_layout)
-        o = _mul_f32x2(o, scale[:, None])
-        o_smem.slice(i * SPLIT_N, SPLIT_N, dim=1).store(o.to(config.dtype))
+        o = float2.pack(o_ref.load(config.o_splitn_layout), axis=1)
+        o = o * scale
+        o_smem.slice(i * SPLIT_N, SPLIT_N, dim=1).store(
+            float2.unpack(o, axis=1).to(config.dtype)
+        )

     fence_async_shared()
     mbarrier.arrive(epi_bar, count=1)
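
The same packed-pair idea carries the softmax row-sum state: the two partial row-sums previously tracked as l_i0/l_i1 now live in a single float2-packed accumulator that is rescaled and updated with one packed FMA per iteration and only split at the end. Below is a rough sketch of that accumulator update, using only the float2 calls that appear in the hunks above; the helper name _update_row_sums is hypothetical.

from triton.experimental import gluon
from triton.experimental.gluon.language.nvidia.blackwell import float2


@gluon.jit
def _update_row_sums(l_i, alpha, p0, p1):
    # l_i packs two partial row-sums; p0/p1 are the two column halves of the
    # probability tile. Reduce each half along its rows, then rescale the old
    # sums by alpha and add the new ones in a single packed FMA.
    l_ij = float2.pack2(p0, p1).sum(axis=1)
    return float2.fma(l_i, float2.pack2(alpha, alpha), l_ij)

At the end of the tile, the two halves are recovered with float2.unpack2 and added together, as the final _softmax_tile hunk above shows.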
