|
1 | 1 | # -*- coding: utf-8 -*- |
2 | 2 | # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang |
3 | 3 |
|
4 | | -import math |
5 | 4 | import warnings |
6 | 5 | from typing import Optional, Union |
7 | 6 |
|
8 | 7 | import torch |
9 | | -import torch.nn.functional as F |
10 | 8 | import triton |
11 | 9 | import triton.language as tl |
12 | 10 | import triton.language.core as core |
13 | 11 | from einops import rearrange |
14 | 12 |
|
15 | 13 | from fla.ops.common.utils import (prepare_chunk_indices, prepare_chunk_offsets, |
16 | 14 | prepare_lens, prepare_token_indices) |
| 15 | +from fla.ops.utils import mean_pooling |
17 | 16 | from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous |
18 | 17 | from native_sparse_attention.ops.utils import _bitonic_merge |
19 | 18 |
|
@@ -812,23 +811,6 @@ def parallel_nsa_bwd_kernel_dkv( |
812 | 811 | tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) |
813 | 812 |
|
814 | 813 |
|
@torch.compile
def compression(
    k: torch.Tensor,
    v: torch.Tensor,
    block_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compress keys/values by mean-pooling blocks along the time axis.

    Currently mean pooling is used as the basic compression function: the
    time dimension ``T`` is split into ``ceil(T / block_size)`` blocks
    (zero-padded on the right when ``T`` is not a multiple of ``block_size``)
    and each block is reduced to its mean.

    Args:
        k: key tensor; leading layout assumed ``(B, T, H, ...)`` — the code
            reads ``shape[:3]`` as ``B, T, H`` and pools over ``T``.
        v: value tensor with the same leading ``(B, T, H)`` layout as ``k``.
        block_size: number of consecutive time steps pooled into one block.

    Returns:
        Tuple ``(k_cmp, v_cmp)``, each of shape
        ``(B, ceil(T / block_size), H, ...)``.
    """
    B, T, H = k.shape[:3]
    num_block = math.ceil(T / block_size)
    pad_len = num_block * block_size - T
    if pad_len != 0:
        # Zero-pad the time axis so it divides evenly into blocks.
        # NOTE(review): the padded zeros are averaged into the final block,
        # biasing its mean toward zero for ragged tails — confirm this is
        # the intended behavior for partial trailing blocks.
        k = F.pad(k, (0, 0, 0, 0, 0, pad_len))
        v = F.pad(v, (0, 0, 0, 0, 0, pad_len))
    k_cmp = k.view(B, num_block, block_size, H, -1).mean(dim=2)
    v_cmp = v.view(B, num_block, block_size, H, -1).mean(dim=2)
    return k_cmp, v_cmp
831 | | - |
832 | 814 | def parallel_nsa_compression_fwd( |
833 | 815 | q: torch.Tensor, |
834 | 816 | k: torch.Tensor, |
@@ -1411,7 +1393,7 @@ def parallel_nsa( |
1411 | 1393 | block_counts = rearrange(block_counts, 'b h t -> b t h') |
1412 | 1394 | assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" |
1413 | 1395 |
|
1414 | | - k_cmp, v_cmp = compression(k, v, block_size) |
| 1396 | + k_cmp, v_cmp = mean_pooling(k, block_size, cu_seqlens), mean_pooling(v, block_size, cu_seqlens) |
1415 | 1397 | o_cmp, lse_cmp = None, None |
1416 | 1398 | if g_cmp is not None: |
1417 | 1399 | o_cmp, lse_cmp = parallel_nsa_compression( |
|
0 commit comments