Skip to content

Commit 75d9d14

Browse files
committed
Merge parallel_nsa with parallel_nsa_with_compression
1 parent dc9c539 commit 75d9d14

File tree

1 file changed

+20
-87
lines changed

1 file changed

+20
-87
lines changed

native_sparse_attention/ops/parallel.py

Lines changed: 20 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

44
import math
5+
import warnings
56
from typing import Optional, Union
67

78
import torch
@@ -1455,89 +1456,14 @@ def parallel_nsa_compression(
14551456

14561457

14571458
def parallel_nsa(
1458-
q: torch.Tensor,
1459-
k: torch.Tensor,
1460-
v: torch.Tensor,
1461-
g_slc: torch.Tensor,
1462-
g_swa: torch.Tensor,
1463-
block_indices: torch.LongTensor,
1464-
block_counts: Optional[Union[torch.LongTensor, int]] = None,
1465-
block_size: int = 64,
1466-
window_size: int = 0,
1467-
scale: Optional[float] = None,
1468-
cu_seqlens: Optional[torch.LongTensor] = None,
1469-
head_first: bool = False
1470-
) -> torch.Tensor:
1471-
r"""
1472-
Args:
1473-
q (torch.Tensor):
1474-
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
1475-
k (torch.Tensor):
1476-
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
1477-
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
1478-
v (torch.Tensor):
1479-
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
1480-
g_slc (torch.Tensor):
1481-
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
1482-
g_swa (torch.Tensor):
1483-
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
1484-
block_indices (torch.LongTensor):
1485-
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
1486-
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
1487-
block_counts (Union[torch.LongTensor, int]):
1488-
Number of selected blocks for each token.
1489-
If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`,
1490-
each token can select the same number of blocks.
1491-
If not provided, it will default to `S`, Default: `None`
1492-
block_size (int):
1493-
Selected block size. Default: 64.
1494-
window_size (int):
1495-
Sliding window size. Default: 0.
1496-
scale (Optional[float]):
1497-
Scale factor for attention scores.
1498-
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
1499-
head_first (Optional[bool]):
1500-
Whether the inputs are in the head-first format. Default: `False`.
1501-
cu_seqlens (torch.LongTensor):
1502-
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
1503-
consistent with the FlashAttention API.
1504-
1505-
Returns:
1506-
o (torch.Tensor):
1507-
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
1508-
"""
1509-
if scale is None:
1510-
scale = k.shape[-1] ** -0.5
1511-
if cu_seqlens is not None:
1512-
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
1513-
if head_first:
1514-
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v, block_indices))
1515-
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
1516-
if isinstance(block_counts, torch.Tensor):
1517-
block_counts = rearrange(block_counts, 'b h t -> b t h')
1518-
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
1519-
1520-
if isinstance(block_counts, int):
1521-
block_indices = block_indices[:, :, :, :block_counts]
1522-
block_counts = None
1523-
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
1524-
if window_size > 0:
1525-
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
1526-
else:
1527-
o = o_slc * g_slc.unsqueeze(-1)
1528-
if head_first:
1529-
o = rearrange(o, 'b t h d -> b h t d')
1530-
return o
1531-
1532-
1533-
def parallel_nsa_with_compression(
15341459
q: torch.Tensor,
15351460
k: torch.Tensor,
15361461
v: torch.Tensor,
15371462
g_cmp: torch.Tensor,
15381463
g_slc: torch.Tensor,
15391464
g_swa: torch.Tensor,
1540-
block_counts: Union[torch.LongTensor, int],
1465+
block_indices: Optional[torch.LongTensor] = None,
1466+
block_counts: Union[torch.LongTensor, int] = 16,
15411467
block_size: int = 64,
15421468
window_size: int = 0,
15431469
scale: Optional[float] = None,
@@ -1559,10 +1485,15 @@ def parallel_nsa_with_compression(
15591485
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
15601486
g_swa (torch.Tensor):
15611487
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
1488+
block_indices (torch.LongTensor):
1489+
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
1490+
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
1491+
If `g_cmp` is provided, the passed `block_indices` will be ignored.
15621492
block_counts (Union[torch.LongTensor, int]):
15631493
Number of selected blocks for each query.
15641494
If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`,
15651495
each query can select the same number of blocks.
1496+
If not provided, it will default to 16.
15661497
block_size (int):
15671498
Selected block size. Default: 64.
15681499
window_size (int):
@@ -1587,7 +1518,7 @@ def parallel_nsa_with_compression(
15871518
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
15881519
if head_first:
15891520
q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
1590-
g_cmp, g_slc = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_cmp, g_slc))
1521+
g_cmp, g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h') if x is not None else None, (g_cmp, g_slc, g_swa))
15911522
if not isinstance(block_counts, int):
15921523
block_counts = rearrange(block_counts, 'b h t -> b t h')
15931524
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
@@ -1603,15 +1534,17 @@ def parallel_nsa_with_compression(
16031534
scale=scale,
16041535
offsets=cu_seqlens
16051536
)
1606-
block_indices = parallel_nsa_topk(
1607-
q=q,
1608-
k=k_cmp,
1609-
lse=lse_cmp,
1610-
block_counts=block_counts,
1611-
block_size=block_size,
1612-
scale=scale,
1613-
offsets=cu_seqlens
1614-
)
1537+
if block_indices is not None:
1538+
warnings.warn("`block_indices` will be ignored when `g_cmp` is provided")
1539+
block_indices = parallel_nsa_topk(
1540+
q=q,
1541+
k=k_cmp,
1542+
lse=lse_cmp,
1543+
block_counts=block_counts,
1544+
block_size=block_size,
1545+
scale=scale,
1546+
offsets=cu_seqlens
1547+
)
16151548
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
16161549
o = o_slc * g_slc.unsqueeze(-1)
16171550
if o_cmp is not None:

0 commit comments

Comments
 (0)