
Commit 446a033

expand key / value by the head group size for now, when query_heads_share_selection is turned on
1 parent 6bf3a1c commit 446a033

File tree

5 files changed, +27 −13 lines changed


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 1 addition & 0 deletions

@@ -443,6 +443,7 @@ def forward(
         gates = repeat(gates, 'b h ... -> b (h qh) ...', qh = fine_num_grouped_queries)

         if self.use_triton_kernel and not disable_triton_kernel:
+
             from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend

             fmask = selected_importance_values > 1e-10
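The context lines around this one-line change already show the pattern the rest of the commit builds on: per-kv-head gates are broadcast to every query head in their group, and the selection importance scores are thresholded into a boolean fine mask. Below is a minimal standalone sketch of that behaviour; the shapes and tensor names are invented here, and only the two quoted lines come from the file.

```python
import torch
from einops import repeat

batch, kv_heads, fine_num_grouped_queries = 1, 2, 2
seq_len, num_sel = 8, 4

# assumed toy inputs: one gate per kv head, one importance score per selected block
gates = torch.rand(batch, kv_heads, seq_len)
selected_importance_values = torch.rand(batch, kv_heads * fine_num_grouped_queries, seq_len, num_sel)

# broadcast each kv head's gate to its group of query heads (quoted context line)
gates = repeat(gates, 'b h ... -> b (h qh) ...', qh = fine_num_grouped_queries)

# selected blocks with (near) zero importance are masked out of fine attention (quoted context line)
fmask = selected_importance_values > 1e-10

assert gates.shape == (batch, kv_heads * fine_num_grouped_queries, seq_len)
assert fmask.shape == selected_importance_values.shape
```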

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 9 additions & 2 deletions

@@ -578,7 +578,6 @@ def backward_kernel_one_col_block(
     seqlen_k,
     seqlen_q_rounded,
     headdim,
-    ATOMIC_ADD: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     EVEN_M: tl.constexpr,
     EVEN_N: tl.constexpr,

@@ -1053,7 +1052,6 @@ def backward_kernel(
         seqlen_k,
         seqlen_q_rounded,
         headdim,
-        ATOMIC_ADD = False,
         BLOCK_HEADDIM = BLOCK_HEADDIM,
         EVEN_M = EVEN_M,
         EVEN_N = EVEN_N,

@@ -1263,6 +1261,15 @@ def native_sparse_attend(
     return_lse = False
 ):
     seq_len = fq.shape[-2]
+    q_heads, kv_heads, sel_heads = fq.shape[1], fk.shape[1], selected_block_indices.shape[1]
+
+    assert divisible_by(q_heads, kv_heads)
+    assert sel_heads in (q_heads, kv_heads)
+
+    # query heads within each group to attend to different segments
+
+    if kv_heads != sel_heads:
+        fk, fv = tuple(repeat(t, 'b h ... -> b (h gh) ...', gh = q_heads // kv_heads) for t in (fk, fv))

     out, lse = _native_sparse_attend(
         fq, fk, fv,
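The last hunk is the heart of the commit: when the selection indices carry one head per query head (query_heads_share_selection turned off), the key / value tensors are expanded by the GQA group size before being handed to the kernel, so all head dimensions line up. A small self-contained sketch of that expansion follows; the shapes are invented, and only the head checks and the `repeat` pattern mirror the diff.

```python
import torch
from einops import repeat

def divisible_by(num, den):
    return (num % den) == 0

batch, q_heads, kv_heads, seq_len, dim = 1, 4, 2, 32, 16

fq = torch.randn(batch, q_heads, seq_len, dim)
fk = torch.randn(batch, kv_heads, seq_len, dim)
fv = torch.randn(batch, kv_heads, seq_len, dim)

# assume each query head made its own block selection (one index head per query head)
num_sel = 2
selected_block_indices = torch.zeros(batch, q_heads, seq_len, num_sel, dtype = torch.long)

sel_heads = selected_block_indices.shape[1]
assert divisible_by(q_heads, kv_heads)
assert sel_heads in (q_heads, kv_heads)

if kv_heads != sel_heads:
    # repeat every kv head `gh` times so k / v match the per-query-head selection
    fk, fv = tuple(repeat(t, 'b h ... -> b (h gh) ...', gh = q_heads // kv_heads) for t in (fk, fv))

assert fk.shape[1] == fv.shape[1] == sel_heads
```

The "for now" in the commit message suggests this materialized repeat is a stopgap until the kernel itself can handle per-query-head selection against unexpanded grouped key / value heads.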

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.52"
+version = "0.0.53"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

test_triton_nsa.py

Lines changed: 10 additions & 8 deletions

@@ -27,7 +27,6 @@ def regular_attend(
     assert divisible_by(q_heads, kv_heads)

     q, k, v = tuple(pad_to_multiple(t, block_size, dim = -2) for t in (q, k, v))
-    indices, mask = tuple(pad_to_multiple(t, block_size, dim = -2) for t in (indices, mask))

     g = q_heads // kv_heads # `g` stands for `g`roups of query heads per kv head


@@ -52,6 +51,8 @@ def regular_attend(
     has_sel_kv_blocks = num_sel_kv_blocks > 0

     if has_sel_kv_blocks:
+        indices, mask = tuple(pad_to_multiple(t, block_size, dim = -2) for t in (indices, mask))
+
         bk, bv = k, v
         sel_bk = einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', bk, indices)
         sel_bv = einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', bv, indices)

@@ -99,18 +100,19 @@ def regular_attend(

 # mock inputs

+batch = 2
 seq_len = 511
 q_heads = 4
-kv_heads = 2
+kv_heads = 4
 fine_block_size = 16
-num_sel = 1
+num_sel = 2

-q = torch.randn(2, q_heads, seq_len, 64).cuda()
-k = torch.randn(2, kv_heads, seq_len, 64).cuda()
-v = torch.randn(2, kv_heads, seq_len, 64).cuda()
+q = torch.randn(batch, q_heads, seq_len, 64).cuda()
+k = torch.randn(batch, kv_heads, seq_len, 64).cuda()
+v = torch.randn(batch, kv_heads, seq_len, 64).cuda()

-indices = torch.zeros(2, kv_heads, seq_len, num_sel).long().cuda()
-mask = torch.randint(0, 2, (2, kv_heads, seq_len, num_sel)).bool().cuda()
+indices = torch.zeros(batch, kv_heads, seq_len, num_sel).long().cuda()
+mask = torch.randint(0, 2, (batch, kv_heads, seq_len, num_sel)).bool().cuda()

 # both regular and nsa pathways `r` and `n`
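In the reference path used by this test, the selected key / value blocks are gathered per query position with `einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', ...)`. As a rough plain-PyTorch illustration of what that gather computes (toy shapes, not the test's code):

```python
import torch

b, h, w, n, d = 1, 2, 4, 16, 8    # batch, heads, kv blocks, block size, head dim
i, sel = 32, 2                    # query positions, selected blocks per position

bk = torch.randn(b, h, w, n, d)                  # keys grouped into blocks
indices = torch.randint(0, w, (b, h, i, sel))    # block ids chosen per query position

# out[b, h, i, s] = bk[b, h, indices[b, h, i, s]], then flatten the (sel, n) axes
gather_idx = indices[..., None, None].expand(b, h, i, sel, n, d)
bk_per_query = bk[:, :, None].expand(b, h, i, w, n, d)
sel_bk = bk_per_query.gather(3, gather_idx)      # (b, h, i, sel, n, d)
sel_bk = sel_bk.reshape(b, h, i, sel * n, d)     # matches einx's '(sel n)' output axis
```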

train.py

Lines changed: 6 additions & 2 deletions

@@ -34,7 +34,7 @@
 USE_SPARSE_ATTN = True
 USE_TRITON_NSA = True
 USE_FLEX_FOR_FINE_SELECTION = False # will push flex a bit, won't be efficient as each layer needs sparsity dynmically generated, but may be enough just to compare to full attention before going all-in on triton kernels
-QUERY_HEADS_SHARE_SELECTION = False # if set to False, each query head can look at a different segment of their corresponding key / value head in GQA
+QUERY_HEADS_SHARE_SELECTION = True # if set to False, each query head can look at a different segment of their corresponding key / value head in GQA

 # sparse attention related


@@ -99,7 +99,11 @@ def base_decoding(
     sample_num_times = max(0, seq_len - prompt_seq_len)

     for _ in tqdm(range(sample_num_times)):
-        logits = net(out, disable_flex = True)
+        logits = net(
+            out,
+            disable_flex = True,
+            disable_triton_kernel = True
+        )

         logits = logits[:, -1]
         logits = top_k(logits, thres = filter_thres)
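Decoding then filters the last-step logits with `top_k(logits, thres = filter_thres)` before sampling. One common way such a threshold-style top-k filter is written (an assumption about the helper, not necessarily this repo's exact code):

```python
import torch

def top_k(logits: torch.Tensor, thres: float = 0.9) -> torch.Tensor:
    # keep the top (1 - thres) fraction of the vocabulary, mask the rest with -inf
    k = max(1, int((1 - thres) * logits.shape[-1]))
    vals, idx = logits.topk(k, dim = -1)
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(-1, idx, vals)
    return filtered

logits = torch.randn(2, 256)          # (batch, vocab)
probs = top_k(logits, thres = 0.95).softmax(dim = -1)
token = torch.multinomial(probs, 1)   # sample from the filtered distribution
```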
