
Commit 0daeb4f

[Gluon][Tutorial] Optimize causal masking (#7546)
I stole this trick from Dao-AILab/flash-attention@bac1001.

With the trick:

```
Attention Z=4 H=32 D=64 causal=True:
     N_CTX  triton-fp16  triton-fp8
0   1024.0   275.777864  287.581485
1   2048.0   470.685276  464.429163
2   4096.0   605.601947  586.069183
3   8192.0   716.620863  689.856917
4  16384.0   753.584386  727.499616
5  32768.0   783.395758  753.602188
6  65536.0   802.394545  768.047070

Attention Z=4 H=32 D=128 causal=True:
     N_CTX  triton-fp16   triton-fp8
0   1024.0   501.329981   552.328105
1   2048.0   791.376216   900.740955
2   4096.0   999.411388  1187.012490
3   8192.0  1153.019916  1421.671343
4  16384.0  1191.816450  1521.984476
5  32768.0  1143.861809  1620.830787
6  65536.0  1081.736229  1651.791420
```

Without the trick:

```
Attention Z=4 H=32 D=64 causal=True:
     N_CTX  triton-fp16  triton-fp8
0   1024.0   259.354966  268.960170
1   2048.0   449.212808  444.264722
2   4096.0   574.672354  572.764276
3   8192.0   683.267710  680.672054
4  16384.0   721.774534  720.585345
5  32768.0   750.475561  754.164559
6  65536.0   757.656602  769.899298

Attention Z=4 H=32 D=128 causal=True:
     N_CTX  triton-fp16   triton-fp8
0   1024.0   483.921292   508.431048
1   2048.0   773.140056   841.316926
2   4096.0   992.567237  1158.317626
3   8192.0  1138.303982  1384.846389
4  16384.0  1122.413679  1501.296101
5  32768.0  1103.324736  1592.760500
6  65536.0  1037.807974  1634.310242
```
1 parent 19eef7c commit 0daeb4f
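The trick: instead of comparing every column index against the causal limit, each 16-column fragment computes a single bit mask `-1 << max(col_limit_right - s, 0)` (with `s` the fragment's base column), and each lane then checks one bit of it. A minimal plain-Python sketch of the equivalence (helper names here are illustrative, not from the kernel):

```python
def visible_direct(col, col_limit_right):
    # Reference check: a column is visible iff it lies left of the causal limit.
    return col < col_limit_right

def visible_bitmask(s, i, col_limit_right):
    # Bits [k, inf) of (-1 << k) are set, with k = max(col_limit_right - s, 0),
    # so bit i is set exactly when s + i >= col_limit_right (masked out).
    mask = -1 << max(col_limit_right - s, 0)
    return mask & (1 << i) == 0

for col_limit_right in range(-8, 72):
    for col in range(64):
        s, i = col & ~0xF, col & 0xF  # fragment base and lane index within it
        assert visible_bitmask(s, i, col_limit_right) == visible_direct(col, col_limit_right)
```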

File tree

1 file changed (+47 lines, -11 lines)

python/tutorials/gluon/01-attention-forward.py

Lines changed: 47 additions & 11 deletions
```diff
@@ -173,6 +173,18 @@ def issue_async_tma_load(smem, bar, desc, offset):
     tma.async_copy_global_to_shared(desc, [offset, 0], bar, smem)
 
 
+@gluon.jit
+def _interleave_n(a, b, size: gl.constexpr, f: gl.constexpr, i: gl.constexpr = 0):
+    if a.shape[1] == size:
+        return f(a, b, i)
+    else:
+        a0, a1 = a.reshape([a.shape[0], 2, a.shape[1] // 2]).permute(0, 2, 1).split()
+        b0, b1 = b.reshape([b.shape[0], 2, b.shape[1] // 2]).permute(0, 2, 1).split()
+        c0 = _interleave_n(a0, b0, size, f, i)
+        c1 = _interleave_n(a1, b1, size, f, i + a.shape[1] // 2)
+        return gl.convert_layout(gl.join(c0, c1).permute(0, 2, 1).reshape(a.shape), a.type.layout)
+
+
 # ===-----------------------------------------------------------------------===#
 # Gluon Attention
 # ===-----------------------------------------------------------------------===#
```
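`_interleave_n` recursively splits both tensors along dim 1 into contiguous halves (the reshape/permute/split idiom above) until each fragment is `size` columns wide, applies `f` with the fragment's starting column, and joins the results back in column order; the FIXME in `_mask_bits` below suggests this fragment-wise structure is what gives the compiler room to interleave instructions. A rough NumPy analogue (NumPy stands in for Gluon here; plain slicing replaces the layout-aware split/join):

```python
import numpy as np

def interleave_n(a, b, size, f, i=0):
    # Recursively split dim 1 into contiguous halves until fragments are
    # `size` wide, apply f(frag_a, frag_b, start_col), then stitch the
    # results back together in column order.
    if a.shape[1] == size:
        return f(a, b, i)
    half = a.shape[1] // 2
    c0 = interleave_n(a[:, :half], b[:, :half], size, f, i)
    c1 = interleave_n(a[:, half:], b[:, half:], size, f, i + half)
    return np.concatenate([c0, c1], axis=1)

# Sanity check: with size=1, f sees every column paired with its own index.
out = interleave_n(np.zeros((2, 8)), np.zeros((2, 8)), 1, lambda x, _, c: x + c)
assert np.array_equal(out[0], np.arange(8))
```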
```diff
@@ -586,22 +598,49 @@ def _attn_fwd_mma(config, chnls, descs, M, STAGE: gl.constexpr):
     tcgen05_mma(p1_tmem, v_smem, o1_tmem, use_acc=o_init, mbarriers=[o1_bar, v_bar, s0_bar, s1_bar])
 
 
+@gluon.jit
+def _mask_inner(qk, mask, i: gl.constexpr):
+    mask_i_bit = mask & (1 << i) == 0
+    return gl.where(mask_i_bit, qk, -float("inf"))
+
+
+@gluon.jit
+def _mask_frag(qk, col_limit_right, s: gl.constexpr):
+    col_limit_right_s = col_limit_right - s
+    col_limit_right_cur = max(col_limit_right_s, 0)
+    mask = -1 << col_limit_right_cur
+    return _interleave_n(qk, mask, 1, _mask_inner)
+
+
+@gluon.jit
+def _mask_bits(qk, col_limit_right):
+    # FIXME: This is a more concise implementation (which compiles faster) but
+    # it results in slightly slower code due to the lack of interleaving.
+    offs_n = gl.arange(0, qk.shape[1], layout=gl.SliceLayout(0, qk.type.layout))[None, :]
+    s = offs_n & ~0xf
+    i = offs_n & 0xf
+
+    col_lim_right_s = col_limit_right - s
+    col_lim_right_cur = max(col_lim_right_s, 0)
+    mask = -1 << col_lim_right_cur
+    mask_i_bit = (mask & (1 << i)) == 0
+    return gl.where(mask_i_bit, qk, -float("inf"))
+
+
 @gluon.jit
 def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
                         s_consumer, corr_producer, exp_turnstile, corr_bar, #
-                        offs_m, offs_n, m_i, l_i0, l_i1, STAGE: gl.constexpr):
+                        offs_m, m_i, l_i0, l_i1, STAGE: gl.constexpr):
     lo, hi = prog.get_loop_bounds(STAGE)
 
     for start_n in range(lo, hi, config.BLOCK_N):
         s_tmem, s_bar, s_consumer = s_consumer.acquire()
         qk = s_tmem.load(config.qk_layout)
 
         if STAGE == 2:
-            # Prevent LLVM from hoisting the partial sums, which triggers spilling.
-            offs_n = gl.inline_asm_elementwise("mov.b32 $0, $0;", "=r,r", [offs_n], dtype=gl.int32, is_pure=True,
-                                               pack=1)
-            mask = offs_m[:, None] < (start_n + offs_n[None, :])
-            qk = gl.where(mask, -1.0e8, qk)
+            col_limit_right = (offs_m - start_n + 1)[:, None].broadcast_to(qk.shape)
+            qk = _interleave_n(qk, col_limit_right, 16, _mask_frag)
+
         m_ij = gl.maximum(m_i, gl.max(qk, 1) * config.qk_scale)
         alpha = gl.exp2(m_i - m_ij)
 
```
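Putting the pieces together: the outer `_interleave_n(qk, col_limit_right, 16, _mask_frag)` hands each 16-column fragment its base offset `s`, `_mask_frag` builds one mask per fragment, and the inner `_interleave_n(..., 1, _mask_inner)` tests a single bit per column. A NumPy cross-check of the flat `_mask_bits` formulation against the comparison-based mask this commit removes (shapes and offsets below are hypothetical):

```python
import numpy as np

def mask_bits_np(qk, col_limit_right):
    # NumPy analogue of _mask_bits: one mask per 16-column fragment,
    # then a single bit test per lane.
    offs_n = np.arange(qk.shape[1])[None, :]
    s = offs_n & ~0xF  # fragment base column
    i = offs_n & 0xF   # lane index within the fragment
    # Clamp the shift to 16 so NumPy's shift stays well-defined; any shift
    # >= 16 already leaves bits 0..15 clear, so lane bits are unaffected.
    shift = np.clip(col_limit_right - s, 0, 16)
    keep = ((-1 << shift) & (1 << i)) == 0
    return np.where(keep, qk, -np.inf)

# Hypothetical tile: 8 query rows starting at row 100, key tile base 64.
rng = np.random.default_rng(0)
qk = rng.standard_normal((8, 64))
offs_m, start_n = np.arange(100, 108), 64
col_limit_right = (offs_m - start_n + 1)[:, None]
ref = np.where(np.arange(64)[None, :] < col_limit_right, qk, -np.inf)
assert np.array_equal(mask_bits_np(qk, col_limit_right), ref)
```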
```diff
@@ -682,11 +721,8 @@ def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
 @gluon.jit
 def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr, #
                   s_chnl, corr_chnl, exp_turnstile):
-    qk_slice_dim0: gl.constexpr = gl.SliceLayout(0, config.qk_layout)
     qk_slice_dim1: gl.constexpr = gl.SliceLayout(1, config.qk_layout)
 
-    offs_n = gl.arange(0, config.BLOCK_N, qk_slice_dim0)
-
     s_consumer = s_chnl.create_consumer()
     corr_producer = corr_chnl.create_producer()
     _, corr_bar, corr_producer = corr_producer.acquire()
```

```diff
@@ -709,11 +745,11 @@ def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr,
     if STAGE & 1:
         m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = _softmax_inner_loop( #
             tile_id, config, prog, s_consumer, corr_producer, exp_turnstile, corr_bar, #
-            offs_m, offs_n, m_i, l_i0, l_i1, STAGE=4 - STAGE)
+            offs_m, m_i, l_i0, l_i1, STAGE=4 - STAGE)
     if STAGE & 2:
         m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = _softmax_inner_loop( #
             tile_id, config, prog, s_consumer, corr_producer, exp_turnstile, corr_bar, #
-            offs_m, offs_n, m_i, l_i0, l_i1, STAGE=2)
+            offs_m, m_i, l_i0, l_i1, STAGE=2)
 
     if config.use_fadd2_reduce:
         l_i = l_i0 + l_i1
```
