22# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33
44import math
5+ import warnings
56from typing import Optional , Union
67
78import torch
@@ -1455,89 +1456,14 @@ def parallel_nsa_compression(
14551456
14561457
14571458def parallel_nsa (
1458- q : torch .Tensor ,
1459- k : torch .Tensor ,
1460- v : torch .Tensor ,
1461- g_slc : torch .Tensor ,
1462- g_swa : torch .Tensor ,
1463- block_indices : torch .LongTensor ,
1464- block_counts : Optional [Union [torch .LongTensor , int ]] = None ,
1465- block_size : int = 64 ,
1466- window_size : int = 0 ,
1467- scale : Optional [float ] = None ,
1468- cu_seqlens : Optional [torch .LongTensor ] = None ,
1469- head_first : bool = False
1470- ) -> torch .Tensor :
1471- r"""
1472- Args:
1473- q (torch.Tensor):
1474- queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
1475- k (torch.Tensor):
1476- keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
1477- GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
1478- v (torch.Tensor):
1479- values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
1480- g_slc (torch.Tensor):
1481- Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
1482- g_swa (torch.Tensor):
1483- Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
1484- block_indices (torch.LongTensor):
1485- Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
1486- `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
1487- block_counts (Union[torch.LongTensor, int]):
1488- Number of selected blocks for each token.
1489- If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`,
1490- each token can select a different number of blocks.
1491- If not provided, it will default to `S`. Default: `None`.
1492- block_size (int):
1493- Selected block size. Default: 64.
1494- window_size (int):
1495- Sliding window size. Default: 0.
1496- scale (Optional[float]):
1497- Scale factor for attention scores.
1498- If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
1499- head_first (Optional[bool]):
1500- Whether the inputs are in the head-first format. Default: `False`.
1501- cu_seqlens (torch.LongTensor):
1502- Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
1503- consistent with the FlashAttention API.
1504-
1505- Returns:
1506- o (torch.Tensor):
1507- Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
1508- """
1509- if scale is None :
1510- scale = k .shape [- 1 ] ** - 0.5
1511- if cu_seqlens is not None :
1512- assert q .shape [0 ] == 1 , "batch size must be 1 when cu_seqlens are provided"
1513- if head_first :
1514- q , k , v , block_indices = map (lambda x : rearrange (x , 'b h t d -> b t h d' ), (q , k , v , block_indices ))
1515- g_slc , g_swa = map (lambda x : rearrange (x , 'b h t -> b t h' ), (g_slc , g_swa ))
1516- if isinstance (block_counts , torch .Tensor ):
1517- block_counts = rearrange (block_counts , 'b h t -> b t h' )
1518- assert q .shape [2 ] % (k .shape [2 ] * 16 ) == 0 , "Group size must be a multiple of 16 in NSA"
1519-
1520- if isinstance (block_counts , int ):
1521- block_indices = block_indices [:, :, :, :block_counts ]
1522- block_counts = None
1523- o_slc , o_swa = ParallelNSAFunction .apply (q , k , v , block_indices , block_counts , block_size , window_size , scale , cu_seqlens )
1524- if window_size > 0 :
1525- o = torch .addcmul (o_slc * g_slc .unsqueeze (- 1 ), o_swa , g_swa .unsqueeze (- 1 ))
1526- else :
1527- o = o_slc * g_slc .unsqueeze (- 1 )
1528- if head_first :
1529- o = rearrange (o , 'b t h d -> b h t d' )
1530- return o
1531-
1532-
1533- def parallel_nsa_with_compression (
15341459 q : torch .Tensor ,
15351460 k : torch .Tensor ,
15361461 v : torch .Tensor ,
15371462 g_cmp : torch .Tensor ,
15381463 g_slc : torch .Tensor ,
15391464 g_swa : torch .Tensor ,
1540- block_counts : Union [torch .LongTensor , int ],
1465+ block_indices : Optional [torch .LongTensor ] = None ,
1466+ block_counts : Union [torch .LongTensor , int ] = 16 ,
15411467 block_size : int = 64 ,
15421468 window_size : int = 0 ,
15431469 scale : Optional [float ] = None ,
@@ -1559,10 +1485,15 @@ def parallel_nsa_with_compression(
15591485 Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
15601486 g_swa (torch.Tensor):
15611487 Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
1488+ block_indices (torch.LongTensor):
1489+ Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
1490+ `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
1491+ If `g_cmp` is provided, the passed `block_indices` will be ignored.
15621492 block_counts (Union[torch.LongTensor, int]):
15631493 Number of selected blocks for each query.
15641494 If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`,
15651495 each query can select a different number of blocks.
1496+ If not provided, it will default to 16.
15661497 block_size (int):
15671498 Selected block size. Default: 64.
15681499 window_size (int):
@@ -1587,7 +1518,7 @@ def parallel_nsa_with_compression(
15871518 assert q .shape [0 ] == 1 , "batch size must be 1 when cu_seqlens are provided"
15881519 if head_first :
15891520 q , k , v = map (lambda x : rearrange (x , 'b h t d -> b t h d' ), (q , k , v ))
1590- g_cmp , g_slc = map (lambda x : rearrange (x , 'b h t -> b t h' ), (g_cmp , g_slc ))
1521+ g_cmp , g_slc , g_swa = map (lambda x : rearrange (x , 'b h t -> b t h' ) if x is not None else None , (g_cmp , g_slc , g_swa ))
15911522 if not isinstance (block_counts , int ):
15921523 block_counts = rearrange (block_counts , 'b h t -> b t h' )
15931524 assert q .shape [2 ] % (k .shape [2 ] * 16 ) == 0 , "Group size must be a multiple of 16 in NSA"
@@ -1603,15 +1534,17 @@ def parallel_nsa_with_compression(
16031534 scale = scale ,
16041535 offsets = cu_seqlens
16051536 )
1606- block_indices = parallel_nsa_topk (
1607- q = q ,
1608- k = k_cmp ,
1609- lse = lse_cmp ,
1610- block_counts = block_counts ,
1611- block_size = block_size ,
1612- scale = scale ,
1613- offsets = cu_seqlens
1614- )
1537+ if block_indices is not None :
1538+ warnings .warn ("`block_indices` will be ignored when `g_cmp` is provided" )
1539+ block_indices = parallel_nsa_topk (
1540+ q = q ,
1541+ k = k_cmp ,
1542+ lse = lse_cmp ,
1543+ block_counts = block_counts ,
1544+ block_size = block_size ,
1545+ scale = scale ,
1546+ offsets = cu_seqlens
1547+ )
16151548 o_slc , o_swa = ParallelNSAFunction .apply (q , k , v , block_indices , block_counts , block_size , window_size , scale , cu_seqlens )
16161549 o = o_slc * g_slc .unsqueeze (- 1 )
16171550 if o_cmp is not None :
0 commit comments