Commit 6f73ceb

oops, think they pair up the query heads with the kv heads in gqa differently
1 parent: 6eb45de

2 files changed (+6, -6 lines)

native_sparse_attention_pytorch/native_sparse_attention.py (5 additions, 5 deletions)
@@ -318,7 +318,7 @@ def forward(
 
     # for gqa, we will average the compressed attention across each grouped queries (per key / values)
 
-    importance_scores = reduce(importance_scores, 'b (grouped_queries h) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)
+    importance_scores = reduce(importance_scores, 'b (h grouped_queries) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)
 
     # handle if compress block size does not equal to the fine block size
     # cannot parse their equation, so will just improvise
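
As an aside (not part of the commit), a minimal sketch of what the axis-order change in the reduce above does. The shapes and values are made up for illustration, torch and einops are assumed installed, and the short axis name g stands in for grouped_queries:

import torch
from einops import reduce

# per-query-head importance scores for 2 kv heads x 3 grouped queries,
# with the query heads laid out kv-head-major:
# [kv0_q0, kv0_q1, kv0_q2, kv1_q0, kv1_q1, kv1_q2]
scores = torch.tensor([[10., 20., 30., 40., 50., 60.]])  # (batch, query heads)

# new pattern - pools the consecutive query heads that share a kv head
print(reduce(scores, 'b (h g) -> b h', 'mean', g = 3))   # tensor([[20., 50.]])

# old pattern - would pool a strided set of query heads instead
print(reduce(scores, 'b (g h) -> b h', 'mean', g = 3))   # tensor([[30., 40.]])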
@@ -349,7 +349,7 @@ def forward(
     if exists(fine_selection_flex_mask):
         # flex attention for the selection for fine attention
 
-        fk, fv, selected_block_indices = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, selected_block_indices))
+        fk, fv, selected_block_indices = tuple(repeat(t, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, selected_block_indices))
 
         fine_block_mask = fine_selection_flex_mask(selected_block_indices)
 
@@ -413,7 +413,7 @@ def forward(
 
     # fine attention
 
-    fk, fv, fmask = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, fmask))
+    fk, fv, fmask = tuple(repeat(t, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, fmask))
 
     fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale
 
@@ -430,7 +430,7 @@ def forward(
     seq_len = fk.shape[-2]
     fmask = causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril()
 
-    fk, fv = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv))
+    fk, fv = tuple(repeat(t, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv))
 
     fsim = einsum(fq, fk, 'b h i d, b h j d -> b h i j') * self.scale
 
@@ -449,7 +449,7 @@ def forward(
     if exists(sliding_window_flex_mask):
         sliding_window_attn_out = flex_attention(sq, sk, sv, block_mask = sliding_window_flex_mask, enable_gqa = True)
     else:
-        sk, sv = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (sk, sv))
+        sk, sv = tuple(repeat(t, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries = self.num_grouped_queries) for t in (sk, sv))
 
         sliding_window_attn_out = self.sliding_window(sq, sk, sv)
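
The repeat changes above go in the other direction: the key / value heads are expanded out to match the query heads, and the axis order decides which kv head each query head ends up reading. Another illustrative sketch (same assumptions as before, 2 kv heads and 3 grouped queries, g standing in for num_grouped_queries):

import torch
from einops import repeat

kv_head_ids = torch.arange(2)  # pretend each kv head just carries its own index

# new pattern - tensor([0, 0, 0, 1, 1, 1]): query head q reads kv head q // 3,
# so consecutive query heads share a kv head, matching the reduce above
print(repeat(kv_head_ids, 'h -> (h g)', g = 3))

# old pattern - tensor([0, 1, 0, 1, 0, 1]): query head q would read kv head q % 2
print(repeat(kv_head_ids, 'h -> (g h)', g = 3))

Which layout is right depends on how the query projection lays out its heads; this commit moves everything to the consecutive (kv-head-major) convention.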

pyproject.toml (1 addition, 1 deletion)
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.30"
+version = "0.0.31"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
