22# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33
44import math
5+ import warnings
56from typing import Optional , Union
67
78import torch
1617from fla .utils import autocast_custom_bwd , autocast_custom_fwd , contiguous
1718from native_sparse_attention .ops .utils import _bitonic_merge
1819
# Optional dependency: flash-attn provides the sliding-window attention path.
# Both entry points are imported together; if the package is missing, BOTH are
# set to None so later `is None` checks (or a clear TypeError) fire instead of
# a confusing NameError. The original only nulled `flash_attn_func`, leaving
# `flash_attn_varlen_func` undefined on the varlen (`cu_seqlens`) path.
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
except ImportError:
    warnings.warn(
        "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
        category=ImportWarning
    )
    flash_attn_func = None
    flash_attn_varlen_func = None

1929
2030@triton .heuristics ({
2131 'USE_OFFSETS' : lambda args : args ['offsets' ] is not None
@@ -1455,89 +1465,14 @@ def parallel_nsa_compression(
14551465
14561466
14571467def parallel_nsa (
1458- q : torch .Tensor ,
1459- k : torch .Tensor ,
1460- v : torch .Tensor ,
1461- g_slc : torch .Tensor ,
1462- g_swa : torch .Tensor ,
1463- block_indices : torch .LongTensor ,
1464- block_counts : Optional [Union [torch .LongTensor , int ]] = None ,
1465- block_size : int = 64 ,
1466- window_size : int = 0 ,
1467- scale : Optional [float ] = None ,
1468- cu_seqlens : Optional [torch .LongTensor ] = None ,
1469- head_first : bool = False
1470- ) -> torch .Tensor :
1471- r"""
1472- Args:
1473- q (torch.Tensor):
1474- queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
1475- k (torch.Tensor):
1476- keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
1477- GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
1478- v (torch.Tensor):
1479- values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
1480- g_slc (torch.Tensor):
1481- Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
1482- g_swa (torch.Tensor):
1483- Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
1484- block_indices (torch.LongTensor):
1485- Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
1486- `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
1487- block_counts (Union[torch.LongTensor, int]):
1488- Number of selected blocks for each token.
1489- If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`,
1490- each token can select the same number of blocks.
1491- If not provided, it will default to `S`, Default: `None`
1492- block_size (int):
1493- Selected block size. Default: 64.
1494- window_size (int):
1495- Sliding window size. Default: 0.
1496- scale (Optional[int]):
1497- Scale factor for attention scores.
1498- If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
1499- head_first (Optional[bool]):
1500- Whether the inputs are in the head-first format. Default: `False`.
1501- cu_seqlens (torch.LongTensor):
1502- Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
1503- consistent with the FlashAttention API.
1504-
1505- Returns:
1506- o (torch.Tensor):
1507- Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
1508- """
1509- if scale is None :
1510- scale = k .shape [- 1 ] ** - 0.5
1511- if cu_seqlens is not None :
1512- assert q .shape [0 ] == 1 , "batch size must be 1 when cu_seqlens are provided"
1513- if head_first :
1514- q , k , v , block_indices = map (lambda x : rearrange (x , 'b h t d -> b t h d' ), (q , k , v , block_indices ))
1515- g_slc , g_swa = map (lambda x : rearrange (x , 'b h t -> b t h' ), (g_slc , g_swa ))
1516- if isinstance (block_counts , torch .Tensor ):
1517- block_counts = rearrange (block_counts , 'b h t -> b t h' )
1518- assert q .shape [2 ] % (k .shape [2 ] * 16 ) == 0 , "Group size must be a multiple of 16 in NSA"
1519-
1520- if isinstance (block_counts , int ):
1521- block_indices = block_indices [:, :, :, :block_counts ]
1522- block_counts = None
1523- o_slc , o_swa = ParallelNSAFunction .apply (q , k , v , block_indices , block_counts , block_size , window_size , scale , cu_seqlens )
1524- if window_size > 0 :
1525- o = torch .addcmul (o_slc * g_slc .unsqueeze (- 1 ), o_swa , g_swa .unsqueeze (- 1 ))
1526- else :
1527- o = o_slc * g_slc .unsqueeze (- 1 )
1528- if head_first :
1529- o = rearrange (o , 'b t h d -> b h t d' )
1530- return o
1531-
1532-
1533- def parallel_nsa_with_compression (
15341468 q : torch .Tensor ,
15351469 k : torch .Tensor ,
15361470 v : torch .Tensor ,
15371471 g_cmp : torch .Tensor ,
15381472 g_slc : torch .Tensor ,
15391473 g_swa : torch .Tensor ,
1540- block_counts : Union [torch .LongTensor , int ],
1474+ block_indices : Optional [torch .LongTensor ] = None ,
1475+ block_counts : Union [torch .LongTensor , int ] = 16 ,
15411476 block_size : int = 64 ,
15421477 window_size : int = 0 ,
15431478 scale : Optional [float ] = None ,
@@ -1559,10 +1494,15 @@ def parallel_nsa_with_compression(
15591494 Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
15601495 g_swa (torch.Tensor):
15611496 Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
1497+ block_indices (torch.LongTensor):
1498+ Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
1499+ `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
1500+ If `g_cmp` is provided, the passed `block_indices` will be ignored.
15621501 block_counts (Optional[Union[torch.LongTensor, int]]):
15631502 Number of selected blocks for each query.
15641503 If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`,
15651504 each query can select the same number of blocks.
1505+ If not provided, it will default to 16.
15661506 block_size (int):
15671507 Selected block size. Default: 64.
15681508 window_size (int):
@@ -1587,7 +1527,7 @@ def parallel_nsa_with_compression(
15871527 assert q .shape [0 ] == 1 , "batch size must be 1 when cu_seqlens are provided"
15881528 if head_first :
15891529 q , k , v = map (lambda x : rearrange (x , 'b h t d -> b t h d' ), (q , k , v ))
1590- g_cmp , g_slc = map (lambda x : rearrange (x , 'b h t -> b t h' ), (g_cmp , g_slc ))
1530+ g_cmp , g_slc , g_swa = map (lambda x : rearrange (x , 'b h t -> b t h' ) if x is not None else None , (g_cmp , g_slc , g_swa ))
15911531 if not isinstance (block_counts , int ):
15921532 block_counts = rearrange (block_counts , 'b h t -> b t h' )
15931533 assert q .shape [2 ] % (k .shape [2 ] * 16 ) == 0 , "Group size must be a multiple of 16 in NSA"
@@ -1603,20 +1543,39 @@ def parallel_nsa_with_compression(
16031543 scale = scale ,
16041544 offsets = cu_seqlens
16051545 )
1606- block_indices = parallel_nsa_topk (
1607- q = q ,
1608- k = k_cmp ,
1609- lse = lse_cmp ,
1610- block_counts = block_counts ,
1611- block_size = block_size ,
1612- scale = scale ,
1613- offsets = cu_seqlens
1614- )
1615- o_slc , o_swa = ParallelNSAFunction .apply (q , k , v , block_indices , block_counts , block_size , window_size , scale , cu_seqlens )
1546+ if block_indices is not None :
1547+ warnings .warn ("`block_indices` will be ignored when `g_cmp` is provided" )
1548+ block_indices = parallel_nsa_topk (
1549+ q = q ,
1550+ k = k_cmp ,
1551+ lse = lse_cmp ,
1552+ block_counts = block_counts ,
1553+ block_size = block_size ,
1554+ scale = scale ,
1555+ offsets = cu_seqlens
1556+ )
1557+ o_slc , o_swa = ParallelNSAFunction .apply (q , k , v , block_indices , block_counts , block_size , 0 , scale , cu_seqlens )
16161558 o = o_slc * g_slc .unsqueeze (- 1 )
16171559 if o_cmp is not None :
16181560 o = torch .addcmul (o , o_cmp , g_cmp .unsqueeze (- 1 ))
16191561 if window_size > 0 :
1562+ if cu_seqlens is not None :
1563+ max_seqlen = prepare_lens (cu_seqlens )
1564+ o = flash_attn_varlen_func (
1565+ q .squeeze (0 ), k .squeeze (0 ), v .squeeze (0 ),
1566+ cu_seqlens_q = cu_seqlens ,
1567+ cu_seqlens_k = cu_seqlens ,
1568+ max_seqlen_q = max_seqlen ,
1569+ max_seqlen_k = max_seqlen ,
1570+ causal = True ,
1571+ window_size = (window_size - 1 , 0 )
1572+ ).unsqueeze (0 )
1573+ else :
1574+ o = flash_attn_func (
1575+ q , k , v ,
1576+ causal = True ,
1577+ window_size = (window_size - 1 , 0 )
1578+ )
16201579 o = torch .addcmul (o , o_swa , g_swa .unsqueeze (- 1 ))
16211580 if head_first :
16221581 o = rearrange (o , 'b t h d -> b h t d' )
0 commit comments