Skip to content

Commit ed321e4

Browse files
committed
Merge parallel_nsa with parallel_nsa_with_compression
1 parent dc9c539 commit ed321e4

File tree

4 files changed

+52
-97
lines changed

4 files changed

+52
-97
lines changed

native_sparse_attention/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
from native_sparse_attention.configuration_nsa import NSAConfig
66
from native_sparse_attention.modeling_nsa import NSAForCausalLM, NSAModel
7-
from native_sparse_attention.ops.parallel import (
8-
parallel_nsa, parallel_nsa_with_compression)
7+
from native_sparse_attention.ops.parallel import parallel_nsa
98

109
AutoConfig.register(NSAConfig.model_type, NSAConfig)
1110
AutoModel.register(NSAConfig, NSAModel)
@@ -15,7 +14,6 @@
1514
__all__ = [
1615
'NSAConfig', 'NSAModel', 'NSAForCausalLM',
1716
'parallel_nsa',
18-
'parallel_nsa_with_compression'
1917
]
2018

2119

native_sparse_attention/modeling_nsa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from fla.modules import GatedMLP as NSAMLP
2323
from fla.modules import RMSNorm, RotaryEmbedding
2424
from native_sparse_attention.configuration_nsa import NSAConfig
25-
from native_sparse_attention.ops.parallel import parallel_nsa_with_compression
25+
from native_sparse_attention.ops.parallel import parallel_nsa
2626

2727
if TYPE_CHECKING:
2828
from transformers.processing_utils import Unpack
@@ -128,7 +128,7 @@ def forward(
128128
k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
129129
v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
130130

131-
o, _ = parallel_nsa_with_compression(
131+
o = parallel_nsa(
132132
q=q,
133133
k=k,
134134
v=v,
Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
# -*- coding: utf-8 -*-
22

3-
from .naive import naive_nsa, naive_nsa_with_compression
4-
from .parallel import parallel_nsa, parallel_nsa_with_compression
3+
from .naive import naive_nsa
4+
from .parallel import parallel_nsa
55

66
__all__ = [
77
'naive_nsa',
88
'parallel_nsa',
9-
'naive_nsa_with_compression',
10-
'parallel_nsa_with_compression'
119
]

native_sparse_attention/ops/parallel.py

Lines changed: 47 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

44
import math
5+
import warnings
56
from typing import Optional, Union
67

78
import torch
@@ -16,6 +17,15 @@
1617
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
1718
from native_sparse_attention.ops.utils import _bitonic_merge
1819

20+
try:
21+
from flash_attn import flash_attn_func, flash_attn_varlen_func
22+
except ImportError:
23+
warnings.warn(
24+
"Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
25+
category=ImportWarning
26+
)
27+
flash_attn_func = None
28+
1929

2030
@triton.heuristics({
2131
'USE_OFFSETS': lambda args: args['offsets'] is not None
@@ -1455,89 +1465,14 @@ def parallel_nsa_compression(
14551465

14561466

14571467
def parallel_nsa(
1458-
q: torch.Tensor,
1459-
k: torch.Tensor,
1460-
v: torch.Tensor,
1461-
g_slc: torch.Tensor,
1462-
g_swa: torch.Tensor,
1463-
block_indices: torch.LongTensor,
1464-
block_counts: Optional[Union[torch.LongTensor, int]] = None,
1465-
block_size: int = 64,
1466-
window_size: int = 0,
1467-
scale: Optional[float] = None,
1468-
cu_seqlens: Optional[torch.LongTensor] = None,
1469-
head_first: bool = False
1470-
) -> torch.Tensor:
1471-
r"""
1472-
Args:
1473-
q (torch.Tensor):
1474-
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
1475-
k (torch.Tensor):
1476-
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
1477-
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
1478-
v (torch.Tensor):
1479-
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
1480-
g_slc (torch.Tensor):
1481-
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
1482-
g_swa (torch.Tensor):
1483-
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
1484-
block_indices (torch.LongTensor):
1485-
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
1486-
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
1487-
block_counts (Union[torch.LongTensor, int]):
1488-
Number of selected blocks for each token.
1489-
If a tensor is provided, with shape `[B, H, T]` if `head_first=True` else `[B, T, H]`,
1490-
each token can select a different number of blocks.
1491-
If not provided, it will default to `S`, Default: `None`
1492-
block_size (int):
1493-
Selected block size. Default: 64.
1494-
window_size (int):
1495-
Sliding window size. Default: 0.
1496-
scale (Optional[float]):
1497-
Scale factor for attention scores.
1498-
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
1499-
head_first (Optional[bool]):
1500-
Whether the inputs are in the head-first format. Default: `False`.
1501-
cu_seqlens (torch.LongTensor):
1502-
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
1503-
consistent with the FlashAttention API.
1504-
1505-
Returns:
1506-
o (torch.Tensor):
1507-
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
1508-
"""
1509-
if scale is None:
1510-
scale = k.shape[-1] ** -0.5
1511-
if cu_seqlens is not None:
1512-
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
1513-
if head_first:
1514-
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v, block_indices))
1515-
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
1516-
if isinstance(block_counts, torch.Tensor):
1517-
block_counts = rearrange(block_counts, 'b h t -> b t h')
1518-
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
1519-
1520-
if isinstance(block_counts, int):
1521-
block_indices = block_indices[:, :, :, :block_counts]
1522-
block_counts = None
1523-
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
1524-
if window_size > 0:
1525-
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
1526-
else:
1527-
o = o_slc * g_slc.unsqueeze(-1)
1528-
if head_first:
1529-
o = rearrange(o, 'b t h d -> b h t d')
1530-
return o
1531-
1532-
1533-
def parallel_nsa_with_compression(
15341468
q: torch.Tensor,
15351469
k: torch.Tensor,
15361470
v: torch.Tensor,
15371471
g_cmp: torch.Tensor,
15381472
g_slc: torch.Tensor,
15391473
g_swa: torch.Tensor,
1540-
block_counts: Union[torch.LongTensor, int],
1474+
block_indices: Optional[torch.LongTensor] = None,
1475+
block_counts: Union[torch.LongTensor, int] = 16,
15411476
block_size: int = 64,
15421477
window_size: int = 0,
15431478
scale: Optional[float] = None,
@@ -1559,10 +1494,15 @@ def parallel_nsa_with_compression(
15591494
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
15601495
g_swa (torch.Tensor):
15611496
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
1497+
block_indices (torch.LongTensor):
1498+
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
1499+
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
1500+
If `g_cmp` is provided, the passed `block_indices` will be ignored.
15621501
block_counts (Optional[Union[torch.LongTensor, int]]):
15631502
Number of selected blocks for each query.
15641503
If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`,
15651504
each query can select a different number of blocks.
1505+
If not provided, it will default to 16.
15661506
block_size (int):
15671507
Selected block size. Default: 64.
15681508
window_size (int):
@@ -1587,7 +1527,7 @@ def parallel_nsa_with_compression(
15871527
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
15881528
if head_first:
15891529
q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
1590-
g_cmp, g_slc = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_cmp, g_slc))
1530+
g_cmp, g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h') if x is not None else None, (g_cmp, g_slc, g_swa))
15911531
if not isinstance(block_counts, int):
15921532
block_counts = rearrange(block_counts, 'b h t -> b t h')
15931533
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
@@ -1603,20 +1543,39 @@ def parallel_nsa_with_compression(
16031543
scale=scale,
16041544
offsets=cu_seqlens
16051545
)
1606-
block_indices = parallel_nsa_topk(
1607-
q=q,
1608-
k=k_cmp,
1609-
lse=lse_cmp,
1610-
block_counts=block_counts,
1611-
block_size=block_size,
1612-
scale=scale,
1613-
offsets=cu_seqlens
1614-
)
1615-
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
1546+
if block_indices is not None:
1547+
warnings.warn("`block_indices` will be ignored when `g_cmp` is provided")
1548+
block_indices = parallel_nsa_topk(
1549+
q=q,
1550+
k=k_cmp,
1551+
lse=lse_cmp,
1552+
block_counts=block_counts,
1553+
block_size=block_size,
1554+
scale=scale,
1555+
offsets=cu_seqlens
1556+
)
1557+
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, 0, scale, cu_seqlens)
16161558
o = o_slc * g_slc.unsqueeze(-1)
16171559
if o_cmp is not None:
16181560
o = torch.addcmul(o, o_cmp, g_cmp.unsqueeze(-1))
16191561
if window_size > 0:
1562+
if cu_seqlens is not None:
1563+
max_seqlen = prepare_lens(cu_seqlens)
1564+
o = flash_attn_varlen_func(
1565+
q.squeeze(0), k.squeeze(0), v.squeeze(0),
1566+
cu_seqlens_q=cu_seqlens,
1567+
cu_seqlens_k=cu_seqlens,
1568+
max_seqlen_q=max_seqlen,
1569+
max_seqlen_k=max_seqlen,
1570+
causal=True,
1571+
window_size=(window_size-1, 0)
1572+
).unsqueeze(0)
1573+
else:
1574+
o = flash_attn_func(
1575+
q, k, v,
1576+
causal=True,
1577+
window_size=(window_size-1, 0)
1578+
)
16201579
o = torch.addcmul(o, o_swa, g_swa.unsqueeze(-1))
16211580
if head_first:
16221581
o = rearrange(o, 'b t h d -> b h t d')

0 commit comments

Comments
 (0)