Skip to content

Commit a60d291

Browse files
committed
Support varlen mean pooling compression
1 parent 7f43d53 commit a60d291

File tree

4 files changed

+22
-24
lines changed

native_sparse_attention/modeling_nsa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def forward(
9898
g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=3)
9999
g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)
100100

101-
# equivalent to cu_seqlens in `flash_attn`
102101
cu_seqlens = kwargs.get('cu_seqlens', None)
103102

104103
seqlen_offset, max_seqlen = 0, seq_len
@@ -138,6 +137,7 @@ def forward(
138137
block_size=self.block_size,
139138
block_counts=self.block_counts,
140139
window_size=self.window_size,
140+
cu_seqlens=cu_seqlens,
141141
head_first=False
142142
)
143143
o = o.reshape(batch_size, seq_len, -1)

native_sparse_attention/ops/naive.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,25 @@
55
from typing import Optional, Union
66

77
import torch
8+
import torch.nn.functional as F
89
from einops import rearrange, repeat
910

10-
from native_sparse_attention.ops.parallel import compression
11+
12+
@torch.compile
13+
def compression(
14+
k: torch.Tensor,
15+
v: torch.Tensor,
16+
block_size: int
17+
) -> torch.Tensor:
18+
# Currently, we set mean pooling as our basic compression function.
19+
B, T, H = k.shape[:3]
20+
num_block = math.ceil(T / block_size)
21+
if k.shape[1] % block_size != 0:
22+
k = F.pad(k, (0, 0, 0, 0, 0, num_block * block_size - T))
23+
v = F.pad(v, (0, 0, 0, 0, 0, num_block * block_size - T))
24+
k_cmp = k.view(B, num_block, block_size, H, -1).mean(dim=2)
25+
v_cmp = v.view(B, num_block, block_size, H, -1).mean(dim=2)
26+
return k_cmp, v_cmp
1127

1228

1329
def naive_nsa(

native_sparse_attention/ops/parallel.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

4-
import math
54
import warnings
65
from typing import Optional, Union
76

87
import torch
9-
import torch.nn.functional as F
108
import triton
119
import triton.language as tl
1210
import triton.language.core as core
1311
from einops import rearrange
1412

1513
from fla.ops.common.utils import (prepare_chunk_indices, prepare_chunk_offsets,
1614
prepare_lens, prepare_token_indices)
15+
from fla.ops.utils import mean_pooling
1716
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
1817
from native_sparse_attention.ops.utils import _bitonic_merge
1918

@@ -812,23 +811,6 @@ def parallel_nsa_bwd_kernel_dkv(
812811
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
813812

814813

815-
@torch.compile
816-
def compression(
817-
k: torch.Tensor,
818-
v: torch.Tensor,
819-
block_size: int
820-
) -> torch.Tensor:
821-
# Currently, we set mean pooling as our basic compression function.
822-
B, T, H = k.shape[:3]
823-
num_block = math.ceil(T / block_size)
824-
if k.shape[1] % block_size != 0:
825-
k = F.pad(k, (0, 0, 0, 0, 0, num_block * block_size - T))
826-
v = F.pad(v, (0, 0, 0, 0, 0, num_block * block_size - T))
827-
k_cmp = k.view(B, num_block, block_size, H, -1).mean(dim=2)
828-
v_cmp = v.view(B, num_block, block_size, H, -1).mean(dim=2)
829-
return k_cmp, v_cmp
830-
831-
832814
def parallel_nsa_compression_fwd(
833815
q: torch.Tensor,
834816
k: torch.Tensor,
@@ -1411,7 +1393,7 @@ def parallel_nsa(
14111393
block_counts = rearrange(block_counts, 'b h t -> b t h')
14121394
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
14131395

1414-
k_cmp, v_cmp = compression(k, v, block_size)
1396+
k_cmp, v_cmp = mean_pooling(k, block_size, cu_seqlens), mean_pooling(v, block_size, cu_seqlens)
14151397
o_cmp, lse_cmp = None, None
14161398
if g_cmp is not None:
14171399
o_cmp, lse_cmp = parallel_nsa_compression(
@@ -1439,7 +1421,7 @@ def parallel_nsa(
14391421
o = torch.addcmul(o, o_cmp, g_cmp.unsqueeze(-1))
14401422
if window_size > 0:
14411423
if cu_seqlens is not None:
1442-
max_seqlen = prepare_lens(cu_seqlens)
1424+
max_seqlen = q.shape[1]
14431425
o_swa = flash_attn_varlen_func(
14441426
q.squeeze(0), k.squeeze(0), v.squeeze(0),
14451427
cu_seqlens_q=cu_seqlens,

0 commit comments

Comments
 (0)