
Commit 150c274

[Gluon][Tutorial] Subtile QK TMEM load (#7655)
This improves non-causal fp8 perf by 40-50 TFLOPS:

```
Attention Z=4 H=32 D=128 causal=False:
      N_CTX  triton-fp16   triton-fp8
0    1024.0   850.929442   911.001638
1    2048.0  1154.974053  1237.075799
2    4096.0  1222.287180  1387.649514
3    8192.0  1314.209177  1497.965931
4   16384.0  1229.544372  1581.403686
5   32768.0  1231.508364  1597.774334
6   65536.0  1229.955624  1598.489864
```

It slightly speeds up D64 perf, but not by much.
1 parent 684c7cf commit 150c274
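
For readers of the tutorial, here is a rough host-side sketch of what "subtiling" the QK TMEM load means, using plain PyTorch tensors as a stand-in for the tensor-memory handle. The names `subtiled_qk_load`, `s_tile`, and `split_factor` are illustrative only, not the tutorial's API; the real helper in the diff (`_subtiled_qk_load`) loads each subtile with a narrower register layout and re-joins the pieces with `_join_n`.

```
import torch

# Sketch only: instead of one full [SPLIT_M, BLOCK_N] load of the QK tile,
# read it as narrower [SPLIT_M, BLOCK_N // split_factor] subtiles and stitch
# them back together, so each load is smaller and easier to overlap.
def subtiled_qk_load(s_tile: torch.Tensor, split_factor: int = 2) -> torch.Tensor:
    size = s_tile.shape[1] // split_factor
    parts = [s_tile[:, i * size:(i + 1) * size].clone() for i in range(split_factor)]
    return torch.cat(parts, dim=1)
```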

File tree

1 file changed: +77, -38 lines

python/tutorials/gluon/01-attention-forward.py

Lines changed: 77 additions & 38 deletions
@@ -1,3 +1,4 @@
+import copy
 import torch
 import triton
 import pytest
@@ -194,6 +195,8 @@ class AttentionConfig:
     num_warps: gl.constexpr

     SPLIT_D_FACTOR: gl.constexpr
+    SPLIT_EXP_FACTOR: gl.constexpr
+    SPLIT_QK_LOAD_FACTOR: gl.constexpr
     SPLIT_M: gl.constexpr
     SPLIT_D: gl.constexpr

@@ -218,7 +221,7 @@ class AttentionConfig:
     use_ffma2_scale_rowmax: gl.constexpr

     def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE_N, NUM_SMS, STAGE, dtype,
-                 num_warps, SPLIT_D_FACTOR):
+                 num_warps):
         self.qk_scale = qk_scale
         self.Z = Z
         self.H = H
@@ -232,7 +235,9 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE
         self.dtype = gl.constexpr(dtype)
         self.num_warps = gl.constexpr(num_warps)

-        self.SPLIT_D_FACTOR = gl.constexpr(SPLIT_D_FACTOR)
+        self.SPLIT_D_FACTOR = gl.constexpr(2)
+        self.SPLIT_EXP_FACTOR = 256 // HEAD_DIM
+        self.SPLIT_QK_LOAD_FACTOR = gl.constexpr(2 if STAGE == 1 else 1)
         self.SPLIT_M = gl.constexpr(self.BLOCK_M // 2)
         self.SPLIT_D = gl.constexpr(self.HEAD_DIM // self.SPLIT_D_FACTOR)

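As a worked example, under the D128 benchmark above these constants resolve as follows. The reading of `STAGE == 1` as the non-causal inner loop is an assumption taken from the surrounding tutorial, not something restated in this diff:

```
HEAD_DIM = 128
SPLIT_EXP_FACTOR = 256 // HEAD_DIM             # 2 here, 4 when HEAD_DIM == 64
STAGE = 1                                      # assumed: non-causal path
SPLIT_QK_LOAD_FACTOR = 2 if STAGE == 1 else 1  # 2, so the QK load is split in half
```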
@@ -488,6 +493,44 @@ def _borrow_s_for_epilogue(config, s_tmem):
     return m_i_tmem, l_i_tmem


+@gl.constexpr_function
+def _get_split_n_layout(layout, SPLIT_FACTOR: gl.constexpr = 2):
+    layout = copy.deepcopy(layout)
+    layout.size_per_thread[1] //= SPLIT_FACTOR
+    return layout
+
+
+@gluon.jit
+def _split_n(x, SPLIT_FACTOR: gl.constexpr = 2):
+    if SPLIT_FACTOR == 1:
+        return (x, )
+    else:
+        layout: gl.constexpr = _get_split_n_layout(x.type.layout)
+        x0, x1 = x.reshape([x.shape[0], 2, x.shape[1] // 2]).permute(0, 2, 1).split()
+        x0 = gl.convert_layout(x0, layout, assert_trivial=True)
+        x1 = gl.convert_layout(x1, layout, assert_trivial=True)
+        return _split_n(x0, SPLIT_FACTOR // 2) + _split_n(x1, SPLIT_FACTOR // 2)
+
+
+@gl.constexpr_function
+def _get_join_n_layout(layout, SPLIT_FACTOR: gl.constexpr = 2):
+    layout = copy.deepcopy(layout)
+    layout.size_per_thread[1] *= SPLIT_FACTOR
+    return layout
+
+
+@gluon.jit
+def _join_n(xs):
+    if len(xs) == 1:
+        return xs[0]
+    else:
+        x0 = _join_n(xs[:len(xs) // 2])
+        x1 = _join_n(xs[len(xs) // 2:])
+        layout: gl.constexpr = _get_join_n_layout(x0.type.layout)
+        x = gl.join(x0, x1).permute(0, 2, 1).reshape([x0.shape[0], x0.shape[1] * 2])
+        return gl.convert_layout(x, layout, assert_trivial=True)
+
+
 @gluon.jit
 def _attn_fwd_load(config, chnls, descs, M, STAGE: gl.constexpr):
     q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
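
A plain-PyTorch analogue may help clarify what `_split_n` and `_join_n` do to the N axis. This is a sketch on ordinary tensors; the real helpers operate on Gluon distributed tensors and also halve or double the layout's `size_per_thread` so the split and join are register-trivial:

```
import torch

def split_n(x):
    # reshape to [M, 2, N//2], move the "which half" axis last, split on it;
    # the result is the first and second halves of the N dimension.
    m, n = x.shape
    x0, x1 = x.reshape(m, 2, n // 2).permute(0, 2, 1).unbind(dim=-1)
    return x0, x1

def join_n(x0, x1):
    # inverse of split_n: stack the halves and fold them back into one N axis
    m, half = x0.shape
    return torch.stack((x0, x1), dim=-1).permute(0, 2, 1).reshape(m, 2 * half)

x = torch.arange(8.0).reshape(2, 4)
x0, x1 = split_n(x)
assert torch.equal(x0, x[:, :2]) and torch.equal(x1, x[:, 2:])
assert torch.equal(join_n(x0, x1), x)
```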
@@ -609,6 +652,28 @@ def _apply_causal_mask(qk, col_limit_right):
     return gl.map_elementwise(_mask_scalar, qk, col_limit_right, s, i)


+@gluon.jit
+def _compute_and_store_exp2(config, qk, p_tmem):
+    SIZE: gl.constexpr = p_tmem.shape[1] // config.SPLIT_EXP_FACTOR
+    qks = _split_n(qk, config.SPLIT_EXP_FACTOR)
+    ps = ()
+    for i in gl.static_range(config.SPLIT_EXP_FACTOR):
+        p = gl.exp2(qks[i])
+        p_tmem.slice(i * SIZE, SIZE).store(p.to(config.dtype))
+        ps = ps + (p, )
+    return _join_n(ps)
+
+
+@gluon.jit
+def _subtiled_qk_load(config, s_tmem):
+    SIZE: gl.constexpr = s_tmem.shape[1] // config.SPLIT_QK_LOAD_FACTOR
+    layout: gl.constexpr = _get_split_n_layout(config.qk_layout, config.SPLIT_QK_LOAD_FACTOR)
+    qks = ()
+    for i in gl.static_range(config.SPLIT_QK_LOAD_FACTOR):
+        qks = qks + (s_tmem.slice(i * SIZE, SIZE).load(layout), )
+    return _join_n(qks)
+
+
 @gluon.jit
 def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
                         s_consumer, corr_producer, exp_turnstile, corr_bar, #
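
A host-side sketch of `_compute_and_store_exp2`, again on ordinary tensors: each N-subtile is exponentiated and written out before the next one is computed, so the `p` stores can overlap the remaining EX2 work. The name `p_out` stands in for the borrowed `p` TMEM slice and is illustrative only:

```
import torch

def compute_and_store_exp2(qk: torch.Tensor, p_out: torch.Tensor, split_factor: int):
    size = qk.shape[1] // split_factor
    ps = []
    for i in range(split_factor):
        p = torch.exp2(qk[:, i * size:(i + 1) * size])
        p_out[:, i * size:(i + 1) * size] = p.to(p_out.dtype)  # "store" the subtile
        ps.append(p)
    return torch.cat(ps, dim=1)  # the fp32 copy kept for the row-sum
```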
@@ -617,7 +682,7 @@ def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #

     for start_n in range(lo, hi, config.BLOCK_N):
         s_tmem, s_bar, s_consumer = s_consumer.acquire()
-        qk = s_tmem.load(config.qk_layout)
+        qk = _subtiled_qk_load(config, s_tmem)

         if STAGE == 2:
             col_limit_right = (offs_m - start_n + 1)[:, None]
@@ -635,11 +700,6 @@ def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
         else:
             qk = _mul_f32x2(qk, gl.full_like(qk, config.qk_scale))
         qk = _add_f32x2(qk, -m_ij[:, None])
-        qk0, qk1, = qk.reshape([config.SPLIT_M, 2, config.BLOCK_N // 2]).permute(0, 2, 1).split()
-
-        p_tmem = _borrow_s_as_p(config, s_tmem)
-        BN4: gl.constexpr = config.BLOCK_N // 4
-        BN2: gl.constexpr = config.BLOCK_N // 2

         # Force the softmax partitions to take turns in the EX2 section. This
         # prevents contention for the EX2 unit and improves utilization.
@@ -649,49 +709,27 @@ def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
         # FIXME: When using FADD2 reductions, ptxas misbehaves and spills far
         # below the register limit in the FADD2, FMUL2, EX2 section. Subtile by
         # 4 to minimize the spilling.
-        if config.HEAD_DIM == 64:
-            qk00, qk01 = qk0.reshape([config.SPLIT_M, 2, config.BLOCK_N // 4]).permute(0, 2, 1).split()
-            p00 = gl.exp2(qk00)
-            p_tmem.slice(0, BN4).store(p00.to(config.dtype))
-            p01 = gl.exp2(qk01)
-            p_tmem.slice(BN4, BN4).store(p01.to(config.dtype))
-            p0 = gl.join(p00, p01).permute(0, 2, 1).reshape([config.SPLIT_M, config.BLOCK_N // 2])
-            p0 = gl.convert_layout(p0, config.qk_layout)
-        else:
-            p0 = gl.exp2(qk0)
-            p_tmem.slice(0, BN2).store(p0.to(config.dtype))
-
-        if config.HEAD_DIM == 64:
-            qk10, qk11 = qk1.reshape([config.SPLIT_M, 2, config.BLOCK_N // 4]).permute(0, 2, 1).split()
-            p10 = gl.exp2(qk10)
-            p_tmem.slice(2 * BN4, BN4).store(p10.to(config.dtype))
-            p11 = gl.exp2(qk11)
-            p_tmem.slice(3 * BN4, BN4).store(p11.to(config.dtype))
-            p1 = gl.join(p10, p11).permute(0, 2, 1).reshape([config.SPLIT_M, config.BLOCK_N // 2])
-            p1 = gl.convert_layout(p1, config.qk_layout)
-        else:
-            p1 = gl.exp2(qk1)
-            p_tmem.slice(BN2, BN2).store(p1.to(config.dtype))
+        p_tmem = _borrow_s_as_p(config, s_tmem)
+        p = _compute_and_store_exp2(config, qk, p_tmem)

         mbarrier.arrive(s_bar, count=1)
-
         _, corr_bar, corr_producer = corr_producer.acquire()

-        if config.HEAD_DIM == 64:
+        if config.use_exp2_turnstile:
             mbarrier.arrive(exp_bar, count=1)

         if config.use_fadd2_reduce:
+            p0, p1 = _split_n(p)
             l_ij0, l_ij1 = gl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2)
             # This is a difference of 1 SASS instruction but it dramatically
             # affects instruction scheduling.
+            alpha = gl.convert_layout(alpha, l_i0.type.layout, assert_trivial=True)
             if config.dtype == gl.float8e5:
                 l_i0, l_i1 = _pairwise_fma_f32x2(l_i0, alpha, l_ij0, l_i1, alpha, l_ij1)
             else:
                 l_i0 = l_i0 * alpha + l_ij0
                 l_i1 = l_i1 * alpha + l_ij1
         else:
-            p = gl.join(p0, p1).permute(0, 2, 1).reshape([config.SPLIT_M, config.BLOCK_N])
-            p = gl.convert_layout(p, config.qk_layout)
             l_ij = gl.sum(p, axis=1)
             l_i0 = l_i0 * alpha + l_ij

@@ -704,6 +742,7 @@ def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
 def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr, #
                   s_chnl, corr_chnl, exp_turnstile):
     qk_slice_dim1: gl.constexpr = gl.SliceLayout(1, config.qk_layout)
+    sum_layout: gl.constexpr = _get_split_n_layout(config.qk_layout) if config.use_fadd2_reduce else config.qk_layout

     s_consumer = s_chnl.create_consumer()
     corr_producer = corr_chnl.create_producer()
@@ -717,10 +756,10 @@ def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr,
     offs_m += gl.arange(tile_id * config.SPLIT_M, (1 + tile_id) * config.SPLIT_M, qk_slice_dim1)

     m_i = gl.full([config.SPLIT_M], -float("inf"), gl.float32, qk_slice_dim1)
-    l_i0 = gl.full([config.SPLIT_M], 0.0, gl.float32, qk_slice_dim1)
+    l_i0 = gl.full([config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout))
     # Accumulate into 2 row-sums so the reduction can be performed with FADD2.
     if config.use_fadd2_reduce:
-        l_i1 = gl.full([config.SPLIT_M], 0.0, gl.float32, qk_slice_dim1)
+        l_i1 = gl.full([config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout))
     else:
         l_i1 = 0

@@ -900,7 +939,7 @@ def attention_kernel( #
                      num_warps: gl.constexpr):
     qk_scale = sm_scale * 1.44269504
     config = AttentionConfig(qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE_N, NUM_SMS, STAGE, #
-                             dtype, num_warps, SPLIT_D_FACTOR=2)
+                             dtype, num_warps)

     q_chnl = get_desc_channel(desc_q, num_buffers=2)
     kv_chnl = get_desc_channel(desc_k, num_buffers=config.num_kv_buffers)
