
Commit dba8d2b

take care of sequence lengths not multiple of block size
1 parent 32da919 commit dba8d2b

File tree

3 files changed: +49 -23 lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 27 additions & 13 deletions
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # taken from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py
 # with fixes for triton 2.3
 
@@ -7,6 +9,7 @@
 
 import torch
 from torch import Tensor
+import torch.nn.functional as F
 
 from einops import repeat, rearrange, reduce
 
@@ -22,6 +25,17 @@ def divisible_by(num, den):
 def round_up_multiple(n, mult):
     return ceil(n / mult) * mult
 
+def pad_at_dim(t, pad: tuple[int, int], *, dim = -1, value = 0.):
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
+def pad_to_multiple(t, mult, *, dim):
+    length = t.shape[dim]
+    padded_length = round_up_multiple(length, mult)
+    remainder = padded_length - length
+    return pad_at_dim(t, (0, remainder), dim = dim)
+
 def is_contiguous(x: Tensor):
     return x.stride(-1) == 1
 
@@ -168,13 +182,13 @@ def forward_kernel(
         if EVEN_HEADDIM:
             q = tl.load(q_ptrs)
         else:
-            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
+            q = tl.load(q_ptrs, mask=offs_d[None, None, :] < headdim, other=0.0)
     else:
         if EVEN_HEADDIM:
-            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
+            q = tl.load(q_ptrs, mask=offs_m[None, :, None] < seqlen_q, other=0.0)
         else:
             q = tl.load(
-                q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
+                q_ptrs, mask=(offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim), other=0.0
             )
 
     q = q.reshape([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM])
@@ -360,13 +374,13 @@ def forward_kernel(
         if EVEN_HEADDIM:
             tl.store(out_ptrs, acc_o)
         else:
-            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
+            tl.store(out_ptrs, acc_o, mask=offs_d[None, None, :] < headdim)
     else:
         if EVEN_HEADDIM:
-            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
+            tl.store(out_ptrs, acc_o, mask=offs_m[None, :, None] < seqlen_q)
         else:
             tl.store(
-                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
+                out_ptrs, acc_o, mask=(offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim)
            )
 
 def native_sparse_attn_forward(
@@ -672,11 +686,11 @@ def backward_kernel_one_col_block(
         q = tl.load(q_ptrs)
     else:
         if EVEN_HEADDIM:
-            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
+            q = tl.load(q_ptrs, mask=offs_m[None, :, None] < seqlen_q, other=0.0)
         else:
             q = tl.load(
                 q_ptrs,
-                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+                mask=(offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
                 other=0.0,
             )
     # recompute p = softmax(qk, dim=-1).T
@@ -717,7 +731,7 @@ def backward_kernel_one_col_block(
         # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
         do = tl.load(
             do_ptrs,
-            mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+            mask=(offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
             other=0.0,
         )
 
@@ -887,12 +901,12 @@ def backward_kernel_one_col_block(
         tl.atomic_add(dq_ptrs, dq, sem = 'relaxed')
     else:
         if EVEN_HEADDIM:
-            tl.atomic_add(dq_ptrs, dq, mask=offs_m[:, None] < seqlen_q, sem = 'relaxed')
+            tl.atomic_add(dq_ptrs, dq, mask=offs_m[None, :, None] < seqlen_q, sem = 'relaxed')
         else:
             tl.atomic_add(
                 dq_ptrs,
                 dq,
-                mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+                mask = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
                 sem = 'relaxed',
             )
 
@@ -1248,7 +1262,7 @@ def native_sparse_attend(
     fmask,
     return_lse = False
 ):
-    assert divisible_by(fq.shape[-2], block_size)
+    seq_len = fq.shape[-2]
 
     out, lse = _native_sparse_attend(
         fq, fk, fv,
@@ -1260,4 +1274,4 @@ def native_sparse_attend(
     if not return_lse:
         return out
 
-    return out, lse
+    return out, lse[..., :seq_len]
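For context, a minimal standalone sketch of how the padding helpers added above behave when the sequence length is not a multiple of the block size. It mirrors the `pad_at_dim` / `pad_to_multiple` definitions from this diff; the tensor shapes and the block size of 16 are illustrative values borrowed from the test file, not requirements.

# minimal sketch of the padding helpers added above (illustrative shapes)
from math import ceil
import torch
import torch.nn.functional as F

def round_up_multiple(n, mult):
    return ceil(n / mult) * mult

def pad_at_dim(t, pad: tuple[int, int], *, dim = -1, value = 0.):
    # translate `dim` into the number of trailing (0, 0) pairs F.pad expects
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

def pad_to_multiple(t, mult, *, dim):
    length = t.shape[dim]
    padded_length = round_up_multiple(length, mult)
    remainder = padded_length - length
    return pad_at_dim(t, (0, remainder), dim = dim)

q = torch.randn(2, 4, 511, 64)                # 511 is not a multiple of 16
q_padded = pad_to_multiple(q, 16, dim = -2)   # pad only on the right of the sequence dim

assert q_padded.shape == (2, 4, 512, 64)      # rounded up to the next block boundary
assert torch.equal(q_padded[..., :511, :], q) # original values are untouched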

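The mask changes in the kernel hunks above (for example `offs_m[:, None]` becoming `offs_m[None, :, None]`) appear to follow from the loaded tile being three-dimensional, `(QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM)`, before the `q.reshape(...)` call, so the boundary masks need a leading axis to broadcast across the query-head groups. A plain PyTorch sketch of that broadcasting, with made-up tile sizes rather than the kernel's actual block configuration:

# plain-torch illustration of the extra leading axis on the boundary masks;
# tile sizes here are made up, not the kernel's real BLOCK / BLOCK_HEADDIM
import torch

groups, block, head_dim = 2, 4, 8
seqlen_q, headdim = 3, 6                     # pretend both dims are ragged

offs_m = torch.arange(block)                 # query offsets within the block
offs_d = torch.arange(head_dim)              # feature offsets

# old 2d form broadcasts against a (block, head_dim) tile
mask_2d = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)

# new 3d form broadcasts against a (groups, block, head_dim) tile
mask_3d = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim)

tile = torch.randn(groups, block, head_dim)
masked = tile * mask_3d                      # the group axis broadcasts over dim 0

assert mask_2d.shape == (block, head_dim)
assert mask_3d.shape == (1, block, head_dim)
assert masked.shape == (groups, block, head_dim)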
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.51"
+version = "0.0.52"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

test_triton_nsa.py

Lines changed: 21 additions & 9 deletions
@@ -1,5 +1,6 @@
+from math import ceil
 import torch
-from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend
+from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend, round_up_multiple, pad_to_multiple
 
 import einx
 from einops import rearrange, einsum, repeat
@@ -25,10 +26,12 @@ def regular_attend(
     q_heads, seq_len, kv_heads, device = q.shape[1], q.shape[-2], k.shape[1], q.device
     assert divisible_by(q_heads, kv_heads)
 
+    q, k, v = tuple(pad_to_multiple(t, block_size, dim = -2) for t in (q, k, v))
+    indices, mask = tuple(pad_to_multiple(t, block_size, dim = -2) for t in (indices, mask))
+
     g = q_heads // kv_heads # `g` stands for `g`roups of query heads per kv head
 
-    assert divisible_by(seq_len, block_size)
-    w = seq_len // block_size
+    w = ceil(seq_len / block_size)
 
     q, k, v = tuple(rearrange(t, 'b h (w n) d -> b h w n d', n = block_size) for t in (q, k, v))
 
@@ -83,22 +86,31 @@ def regular_attend(
 
     out = rearrange(out, 'b h g w n d -> b (h g) (w n) d')
 
+    out = out[..., :seq_len, :]
+
     if not return_lse:
         return out
 
     lse = sim.logsumexp(dim = -1)
-    return out, rearrange(lse, 'b g h w n -> b (g h) (w n)')
+    lse = rearrange(lse, 'b g h w n -> b (g h) (w n)')
+    lse = lse[..., :seq_len]
+
+    return out, lse
 
 # mock inputs
 
+seq_len = 511
+q_heads = 4
+kv_heads = 2
 fine_block_size = 16
+num_sel = 1
 
-q = torch.randn(2, 4, 512, 64).cuda()
-k = torch.randn(2, 2, 512, 64).cuda()
-v = torch.randn(2, 2, 512, 64).cuda()
+q = torch.randn(2, q_heads, seq_len, 64).cuda()
+k = torch.randn(2, kv_heads, seq_len, 64).cuda()
+v = torch.randn(2, kv_heads, seq_len, 64).cuda()
 
-indices = torch.zeros(2, 2, 512, 0).long().cuda()
-mask = torch.randint(0, 2, (2, 2, 512, 0)).bool().cuda()
+indices = torch.zeros(2, kv_heads, seq_len, num_sel).long().cuda()
+mask = torch.randint(0, 2, (2, kv_heads, seq_len, num_sel)).bool().cuda()
 
 # both regular and nsa pathways `r` and `n`
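Below is a hedged end-to-end sketch of what this commit enables: calling the Triton path with a sequence length that is not a multiple of the fine block size. The positional argument order `(q, k, v, block_size, indices, mask)` is assumed from the test setup above and from the `native_sparse_attend` parameters visible in the first file's diff; if the actual call signature differs, adjust accordingly.

# hedged sketch: argument order assumed from the test and wrapper signature above
import torch
from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend

seq_len, fine_block_size, num_sel = 511, 16, 1   # 511 is deliberately not a multiple of 16

q = torch.randn(2, 4, seq_len, 64).cuda()
k = torch.randn(2, 2, seq_len, 64).cuda()
v = torch.randn(2, 2, seq_len, 64).cuda()

indices = torch.zeros(2, 2, seq_len, num_sel).long().cuda()
mask = torch.randint(0, 2, (2, 2, seq_len, num_sel)).bool().cuda()

# before this commit the wrapper asserted divisibility by the block size;
# now it records seq_len and slices the returned lse back down to it
out, lse = native_sparse_attend(q, k, v, fine_block_size, indices, mask, return_lse = True)

assert out.shape[-2] == seq_len
assert lse.shape[-1] == seq_len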
