
Commit 059bf78

more cleanup

1 parent 2f14297 commit 059bf78
File tree

1 file changed: +8 −112 lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 8 additions & 112 deletions
@@ -56,7 +56,6 @@ def _fwd_kernel(
     Q,
     K,
     V,
-    Bias,
     Out,
     M,
     Lse,
@@ -70,9 +69,6 @@ def _fwd_kernel(
     stride_vb,
     stride_vh,
     stride_vn,
-    stride_bb,
-    stride_bh,
-    stride_bm,
     stride_ob,
     stride_oh,
     stride_om,
@@ -83,7 +79,6 @@ def _fwd_kernel(
     headdim,
     CACHE_KEY_SEQLEN_Q,
     CACHE_KEY_SEQLEN_K,
-    HAS_BIAS: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     EVEN_M: tl.constexpr,
     EVEN_N: tl.constexpr,
@@ -110,9 +105,6 @@ def _fwd_kernel(
         V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
     )
 
-    if HAS_BIAS:
-        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
-
     # maximum
 
     m_ptrs = M + off_hb * seqlen_q_rounded + offs_m
@@ -183,22 +175,8 @@ def _fwd_kernel(
 
         qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
 
-        if HAS_BIAS:
-            if EVEN_N:
-                bias = tl.load(b_ptrs + start_n)
-            else:
-                bias = tl.load(
-                    b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
-                )
-            bias = bias[None, :]
-
-            bias = bias.to(tl.float32)
-            qk = qk * softmax_scale + bias
-            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
-            p = tl.exp(qk - m_ij[:, None])
-        else:
-            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
-            p = tl.exp(qk * softmax_scale - m_ij[:, None])
+        m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
+        p = tl.exp(qk * softmax_scale - m_ij[:, None])
 
         l_ij = tl.sum(p, 1)
 
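With the bias branch gone, the forward kernel always folds softmax_scale into the running-max and exponentiation step. Below is a rough PyTorch reference for the arithmetic each key block now contributes, following the standard Flash Attention forward recurrence; the function name, argument names, and the accumulator bookkeeping around the two kept lines are illustrative, not taken from this kernel.

import torch

def fwd_block_update(qk, softmax_scale, m_i, lse_i, acc_o, v_block):
    # One step of the running (online) softmax over a key block, with softmax_scale
    # folded in unconditionally -- the only path left once the bias branch is removed.
    # qk: (BLOCK_M, BLOCK_N) raw q @ k^T, m_i / lse_i: (BLOCK_M,) running statistics,
    # acc_o: (BLOCK_M, d) unnormalized output accumulator, v_block: (BLOCK_N, d).
    m_ij = torch.maximum(qk.max(dim=1).values * softmax_scale, lse_i)
    p = torch.exp(qk * softmax_scale - m_ij[:, None])
    l_ij = p.sum(dim=1)

    acc_o = acc_o * torch.exp(m_i - m_ij)[:, None] + p @ v_block   # rescale old sum, add new block
    lse_i = m_ij + torch.log(torch.exp(lse_i - m_ij) + l_ij)       # update running logsumexp
    return m_ij, lse_i, acc_o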
@@ -264,7 +242,6 @@ def flash_attn_forward(
     q,
     k,
     v,
-    bias = None,
     o = None,
     m = None,
     lse = None,
@@ -285,23 +262,6 @@ def flash_attn_forward(
 
     softmax_scale = default(softmax_scale, d ** -0.5)
 
-    has_bias = exists(bias)
-
-    if has_bias:
-        assert bias.dtype in [q.dtype, torch.float]
-        assert bias.is_cuda
-
-        if bias.ndim == 2:
-            bias = repeat(bias, 'b j -> b h i j', h = nheads, i = seqlen_q)
-
-        if not is_contiguous(bias):
-            bias = bias.contiguous()
-
-        assert bias.shape[-2:] == (seqlen_q, seqlen_k)
-        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
-
-    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
-
     seqlen_q_rounded = ceil(seqlen_q / 128) * 128
 
     if not exists(lse):
@@ -324,7 +284,6 @@ def flash_attn_forward(
         q,
         k,
         v,
-        bias,
         o,
         m,
         lse,
@@ -338,7 +297,6 @@ def flash_attn_forward(
         v.stride(0),
         v.stride(2),
         v.stride(1),
-        *bias_strides,
         o.stride(0),
         o.stride(2),
         o.stride(1),
@@ -349,7 +307,6 @@ def flash_attn_forward(
         d,
         seqlen_q // 32,
         seqlen_k // 32,
-        has_bias,
         BLOCK_HEADDIM,
         BLOCK_M = BLOCK,
         BLOCK_N = BLOCK,
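A note on the launch arguments kept above: the wrapper passes q.stride(0), q.stride(2), q.stride(1) for the kernel's stride_qb, stride_qh, stride_qm slots, which suggests a (batch, seqlen, heads, headdim) tensor layout. A small sanity check of that mapping, using a hypothetical tensor (illustration only, not part of the commit):

import torch

q = torch.empty(2, 1024, 8, 64)                      # assumed (batch, seqlen, heads, headdim)
stride_b, stride_n, stride_h, stride_d = q.stride()
# stride(0) -> batch, stride(2) -> head, stride(1) -> sequence, matching the kernel's order
assert (q.stride(0), q.stride(2), q.stride(1)) == (stride_b, stride_h, stride_n)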
@@ -445,7 +402,6 @@ def _bwd_kernel_one_col_block(
     Q,
     K,
     V,
-    Bias,
     DO,
     DQ,
     DK,
@@ -456,7 +412,6 @@ def _bwd_kernel_one_col_block(
     stride_qm,
     stride_kn,
     stride_vn,
-    stride_bm,
     stride_dom,
     stride_dqm,
     stride_dkn,
@@ -465,7 +420,6 @@ def _bwd_kernel_one_col_block(
     seqlen_k,
     headdim,
     ATOMIC_ADD: tl.constexpr,
-    BIAS_TYPE: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     EVEN_M: tl.constexpr,
     EVEN_N: tl.constexpr,
@@ -486,10 +440,7 @@ def _bwd_kernel_one_col_block(
     v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
     do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
     dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
-    if BIAS_TYPE == "vector":
-        b_ptrs = Bias + offs_n
-    elif BIAS_TYPE == "matrix":
-        b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
+
     # initialize dv and dk
     dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
@@ -562,33 +513,14 @@ def _bwd_kernel_one_col_block(
 
         qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
 
-        if BIAS_TYPE != "none":
-            tl.debug_barrier()  # Race condition otherwise
-            if BIAS_TYPE == "vector":
-                if EVEN_N:
-                    bias = tl.load(b_ptrs).to(tl.float32)
-                else:
-                    bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
-                bias = bias[None, :]
-            elif BIAS_TYPE == "matrix":
-                if EVEN_M & EVEN_N:
-                    bias = tl.load(b_ptrs).to(tl.float32)
-                else:
-                    bias = tl.load(
-                        b_ptrs,
-                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),
-                        other=0.0,
-                    ).to(tl.float32)
-            qk = qk * softmax_scale + bias
         # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
         # Also wrong for headdim=64.
         if not (EVEN_M & EVEN_HEADDIM):
             tl.debug_barrier()
         lse_i = tl.load(LSE + offs_m_curr)
-        if BIAS_TYPE == "none":
-            p = tl.exp(qk * softmax_scale - lse_i[:, None])
-        else:
-            p = tl.exp(qk - lse_i[:, None])
+
+        p = tl.exp(qk * softmax_scale - lse_i[:, None])
+
         # compute dv
         # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
         # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
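The backward pass is now a single path: probabilities are recomputed from the raw scores and the logsumexp saved by the forward, with softmax_scale applied here since no bias was ever added to qk. A minimal PyTorch sketch of that recomputation, with illustrative names and block shapes (not the kernel's own code):

import torch

def recompute_p(qk, softmax_scale, lse_rows):
    # qk: (BLOCK_M, BLOCK_N) raw q @ k^T for this tile
    # lse_rows: (BLOCK_M,) logsumexp values saved during the forward pass
    return torch.exp(qk * softmax_scale - lse_rows[:, None])

The dv contribution for the key block is then accumulated from p and do (roughly p transposed times do), which is why p must reproduce the forward's normalization exactly.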
@@ -693,8 +625,7 @@ def _bwd_kernel_one_col_block(
         dq_ptrs += BLOCK_M * stride_dqm
         q_ptrs += BLOCK_M * stride_qm
         do_ptrs += BLOCK_M * stride_dom
-        if BIAS_TYPE == "matrix":
-            b_ptrs += BLOCK_M * stride_bm
+
     # write-back
     dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
     dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
@@ -738,7 +669,7 @@ def init_to_zero(name):
         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
     ],
-    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "BLOCK_HEADDIM"],
+    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BLOCK_HEADDIM"],
 )
 @triton.heuristics(
     {
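Since BIAS_TYPE is no longer a kernel parameter, it is also dropped from the backward kernel's autotune cache key. As a toy illustration of how that key behaves (hypothetical kernel, not from this file): configurations are re-benchmarked whenever any value listed in key changes, so removing an entry simply removes one axis of that cache.

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=4),
        triton.Config({"BLOCK": 128}, num_warps=8),
    ],
    key=["N"],  # retune only when N changes
)
@triton.jit
def toy_copy_kernel(x_ptr, y_ptr, N, BLOCK: tl.constexpr):
    # copy N elements, one block per program
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)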
@@ -752,7 +683,6 @@ def _bwd_kernel(
     Q,
     K,
     V,
-    Bias,
     DO,
     DQ,
     DK,
@@ -769,9 +699,6 @@ def _bwd_kernel(
     stride_vb,
     stride_vh,
     stride_vn,
-    stride_bb,
-    stride_bh,
-    stride_bm,
     stride_dob,
     stride_doh,
     stride_dom,
@@ -791,7 +718,6 @@ def _bwd_kernel(
     headdim,
     CACHE_KEY_SEQLEN_Q,
     CACHE_KEY_SEQLEN_K,
-    BIAS_TYPE: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
     EVEN_M: tl.constexpr,
@@ -811,8 +737,6 @@ def _bwd_kernel(
     DQ += off_b * stride_dqb + off_h * stride_dqh
     DK += off_b * stride_dkb + off_h * stride_dkh
     DV += off_b * stride_dvb + off_h * stride_dvh
-    if BIAS_TYPE != "none":
-        Bias += off_b * stride_bb + off_h * stride_bh
     # pointer to row-wise quantities in value-like data
     D += off_hb * seqlen_q_rounded
     LSE += off_hb * seqlen_q_rounded
@@ -824,7 +748,6 @@ def _bwd_kernel(
                 Q,
                 K,
                 V,
-                Bias,
                 DO,
                 DQ,
                 DK,
@@ -835,7 +758,6 @@ def _bwd_kernel(
                 stride_qm,
                 stride_kn,
                 stride_vn,
-                stride_bm,
                 stride_dom,
                 stride_dqm,
                 stride_dkn,
@@ -844,7 +766,6 @@ def _bwd_kernel(
                 seqlen_k,
                 headdim,
                 ATOMIC_ADD=False,
-                BIAS_TYPE=BIAS_TYPE,
                 BLOCK_HEADDIM=BLOCK_HEADDIM,
                 EVEN_M=EVEN_M,
                 EVEN_N=EVEN_N,
@@ -859,7 +780,6 @@ def _bwd_kernel(
             Q,
             K,
             V,
-            Bias,
             DO,
             DQ,
             DK,
@@ -870,7 +790,6 @@ def _bwd_kernel(
             stride_qm,
             stride_kn,
             stride_vn,
-            stride_bm,
             stride_dom,
             stride_dqm,
             stride_dkn,
@@ -879,7 +798,6 @@ def _bwd_kernel(
             seqlen_k,
             headdim,
             ATOMIC_ADD=True,
-            BIAS_TYPE=BIAS_TYPE,
             BLOCK_HEADDIM=BLOCK_HEADDIM,
             EVEN_M=EVEN_M,
             EVEN_N=EVEN_N,
@@ -899,7 +817,6 @@ def flash_attn_backward(
     dk,
     dv,
     delta = None,
-    bias = None,
     softmax_scale = None,
 ):
     # Make sure that the last dimension is contiguous
@@ -944,24 +861,6 @@ def flash_attn_backward(
         BLOCK_HEADDIM=BLOCK_HEADDIM,
     )
 
-    has_bias = bias is not None
-    bias_type = "none"
-    if has_bias:
-        assert bias.dtype in [q.dtype, torch.float]
-        assert bias.is_cuda
-        assert bias.dim() == 4
-        assert bias.stride(-1) == 1
-        if bias.shape[2:] == (1, seqlen_k):
-            bias_type = "vector"
-        elif bias.shape[2:] == (seqlen_q, seqlen_k):
-            bias_type = "matrix"
-        else:
-            raise RuntimeError(
-                "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
-            )
-        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
-    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
-
     # BLOCK_M = 128
     # BLOCK_N = 64
     # num_warps = 4
@@ -973,7 +872,6 @@ def flash_attn_backward(
         q,
         k,
         v,
-        bias,
         do,
         dq_accum,
         dk,
@@ -990,7 +888,6 @@ def flash_attn_backward(
         v.stride(0),
         v.stride(2),
         v.stride(1),
-        *bias_strides,
         do.stride(0),
        do.stride(2),
         do.stride(1),
@@ -1012,7 +909,6 @@ def flash_attn_backward(
         seqlen_k // 32,  # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        bias_type,
         BLOCK_HEADDIM,
         # SEQUENCE_PARALLEL=False,
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
