|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +from dataclasses import dataclass |
| 5 | + |
| 6 | +import torch |
| 7 | +from torch import nn |
| 8 | + |
| 9 | +from kvpress.presses.base_press import BasePress |
| 10 | +from kvpress.presses.scorer_press import ScorerPress |
| 11 | + |
| 12 | + |
| 13 | +@dataclass |
| 14 | +class ChunkKVPress(BasePress): |
| 15 | + """ |
| 16 | + Wrapper class for any ScorerPress. |
| 17 | + First calculates global scores using the ScorerPress, |
| 18 | + then selects tokens chunk by chunk based on these global scores. |
| 19 | + This method was proposed in |
| 20 | + ChunkKV: Semantic-Preserving KV Cache Compression for Efficient Long-Context LLM Inference |
| 21 | + https://arxiv.org/abs/2502.00299 |
| 22 | + """ |
| 23 | + |
| 24 | + press: ScorerPress |
| 25 | + chunk_length: int = 20 |
| 26 | + |
| 27 | + def __post_init__(self): |
| 28 | + assert isinstance(self.press, ScorerPress), "ChunkKVPress requires a ScorerPress as input" |
| 29 | + |
| 30 | + @property |
| 31 | + def compression_ratio(self): |
| 32 | + return self.press.compression_ratio |
| 33 | + |
| 34 | + @compression_ratio.setter |
| 35 | + def compression_ratio(self, value): |
| 36 | + self.press.compression_ratio = value |
| 37 | + |
| 38 | + def compress( |
| 39 | + self, |
| 40 | + module: nn.Module, |
| 41 | + hidden_states: torch.Tensor, |
| 42 | + keys: torch.Tensor, |
| 43 | + values: torch.Tensor, |
| 44 | + attentions: torch.Tensor, |
| 45 | + kwargs: dict, |
| 46 | + ) -> tuple[torch.Tensor, torch.Tensor]: |
| 47 | + |
| 48 | + if self.press.compression_ratio == 0: |
| 49 | + return keys, values |
| 50 | + |
| 51 | + assert attentions is None, "ChunkPress does not support attentions." |
| 52 | + |
| 53 | + kv_len = keys.shape[2] |
| 54 | + |
| 55 | + # 1. Calculate global scores first |
| 56 | + global_scores = self.press.score( |
| 57 | + module, |
| 58 | + hidden_states, |
| 59 | + keys, |
| 60 | + values, |
| 61 | + attentions, |
| 62 | + kwargs, |
| 63 | + ) |
| 64 | + |
| 65 | + # 2. Calculate actual number of complete chunks and remaining tokens |
| 66 | + num_complete_chunks = kv_len // self.chunk_length |
| 67 | + remaining_tokens = kv_len % self.chunk_length |
| 68 | + |
| 69 | + # If we have no complete chunks, delegate to the underlying scorer press |
| 70 | + if num_complete_chunks == 0: |
| 71 | + return self.press.compress(module, hidden_states, keys, values, attentions, kwargs) |
| 72 | + |
| 73 | + # Reshape complete chunks for score calculation |
| 74 | + if num_complete_chunks > 0: |
| 75 | + main_scores = global_scores[..., : num_complete_chunks * self.chunk_length] |
| 76 | + main_chunk_scores = main_scores.sum(dim=1).view(-1, num_complete_chunks, self.chunk_length) |
| 77 | + main_chunk_scores = main_chunk_scores.mean(dim=-1) |
| 78 | + else: |
| 79 | + main_chunk_scores = torch.empty((global_scores.shape[0], 0), device=global_scores.device) |
| 80 | + |
| 81 | + # Handle remaining tokens if any |
| 82 | + if remaining_tokens > 0: |
| 83 | + remaining_scores = global_scores[..., -remaining_tokens:] |
| 84 | + remaining_chunk_score = remaining_scores.sum(dim=1).mean(dim=-1, keepdim=True) |
| 85 | + chunk_scores = torch.cat([main_chunk_scores, remaining_chunk_score], dim=-1) |
| 86 | + else: |
| 87 | + chunk_scores = main_chunk_scores |
| 88 | + |
| 89 | + # 3. Calculate number of chunks to keep |
| 90 | + n_chunks_kept = max(1, int((num_complete_chunks + (remaining_tokens > 0)) * (1 - self.press.compression_ratio))) |
| 91 | + top_chunks = chunk_scores.topk(n_chunks_kept, dim=-1) |
| 92 | + |
| 93 | + # 4. Create indices for selected chunks |
| 94 | + indices = [] |
| 95 | + for chunk_idx in top_chunks.indices[0]: |
| 96 | + if chunk_idx < num_complete_chunks: |
| 97 | + # For complete chunks |
| 98 | + start_idx = chunk_idx * self.chunk_length |
| 99 | + chunk_indices = torch.arange(start_idx, start_idx + self.chunk_length, device=keys.device) |
| 100 | + else: |
| 101 | + # For the remaining partial chunk |
| 102 | + chunk_indices = torch.arange(num_complete_chunks * self.chunk_length, kv_len, device=keys.device) |
| 103 | + indices.append(chunk_indices) |
| 104 | + |
| 105 | + indices = torch.cat(indices).sort()[0] |
| 106 | + indices = indices.view(1, 1, -1, 1).expand(keys.shape[0], keys.shape[1], -1, module.head_dim) |
| 107 | + |
| 108 | + # 5. Use gather to collect selected keys and values |
| 109 | + keys = keys.gather(2, indices).contiguous() |
| 110 | + values = values.gather(2, indices).contiguous() |
| 111 | + |
| 112 | + return keys, values |
0 commit comments