Skip to content

Commit fd16d8b

Browse files
add ChunkKV
1 parent a94a78d commit fd16d8b

File tree

10 files changed

+150
-20
lines changed

10 files changed

+150
-20
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ Finally we provide wrapper presses that can be combined with other presses:
7777
- `PerLayerCompressionPress` ([source](kvpress/presses/per_layer_compression_press.py)): compress each layer with a different compression ratio (experimental)
7878
- `ComposedPress` ([source](kvpress/presses/composed_press.py)): compose multiple presses together by chaining their forward hooks
7979
- `KeyRerotationPress` ([source](kvpress/presses/key_rerotation_press.py)): rerotate pruned keys to have continuous RoPE embeddings
80+
- `ChunkKVPress` ([source](kvpress/presses/chunkkv_press.py), [paper](https://arxiv.org/abs/2502.00299)): compresses by selecting important chunks, preserving semantic coherence
8081
- `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can yield more uniform compression across long sequences
8182
- `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.
8283

@@ -175,4 +176,4 @@ with press(model):
175176

176177
However, the `generate` method does not allow excluding the question from the compression, which would artificially favor methods such as SnapKV. Ideally, we want a compression method that works whatever comes after the context (_e.g._ for use cases such as chat or document question answering). Finally, the `generate` method does not allow generating answers for multiple questions at once.
177178

178-
</details>
179+
</details>

evaluation/evaluate.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
from zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer
1818

1919
from kvpress import (
20-
CriticalKVPress,
21-
CriticalAdaKVPress,
2220
AdaKVPress,
21+
ChunkKVPress,
22+
CriticalAdaKVPress,
23+
CriticalKVPress,
24+
DuoAttentionPress,
2325
ExpectedAttentionPress,
2426
KnormPress,
2527
ObservedAttentionPress,
@@ -28,7 +30,6 @@
2830
StreamingLLMPress,
2931
ThinKPress,
3032
TOVAPress,
31-
DuoAttentionPress,
3233
)
3334

3435
logger = logging.getLogger(__name__)
@@ -64,6 +65,7 @@
6465
"think": ThinKPress(),
6566
"tova": TOVAPress(),
6667
"duo_attention": DuoAttentionPress(),
68+
"chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
6769
}
6870

6971

kvpress/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
from kvpress.presses.adakv_press import AdaKVPress
88
from kvpress.presses.base_press import BasePress
99
from kvpress.presses.chunk_press import ChunkPress
10+
from kvpress.presses.chunkkv_press import ChunkKVPress
1011
from kvpress.presses.composed_press import ComposedPress
12+
from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress
13+
from kvpress.presses.duo_attention_press import DuoAttentionPress
1114
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
1215
from kvpress.presses.key_rerotation_press import KeyRerotationPress
1316
from kvpress.presses.knorm_press import KnormPress
@@ -20,8 +23,6 @@
2023
from kvpress.presses.streaming_llm_press import StreamingLLMPress
2124
from kvpress.presses.think_press import ThinKPress
2225
from kvpress.presses.tova_press import TOVAPress
23-
from kvpress.presses.criticalkv_press import CriticalKVPress, CriticalAdaKVPress
24-
from kvpress.presses.duo_attention_press import DuoAttentionPress
2526

2627
# Patch the attention functions to support head-wise compression
2728
patch_attention_functions()
@@ -47,4 +48,5 @@
4748
"KeyRerotationPress",
4849
"ChunkPress",
4950
"DuoAttentionPress",
51+
"ChunkKVPress",
5052
]

kvpress/presses/chunk_press.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ def compress(
5050
assert attentions is None, "ChunkPress does not support attentions."
5151

5252
kv_len = keys.shape[2]
53-
5453
indices = []
5554
for i in range(0, kv_len, self.chunk_length):
5655
chunk_scores = self.press.score(

kvpress/presses/chunkkv_press.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from dataclasses import dataclass
5+
6+
import torch
7+
from torch import nn
8+
9+
from kvpress.presses.base_press import BasePress
10+
from kvpress.presses.scorer_press import ScorerPress
11+
12+
13+
@dataclass
class ChunkKVPress(BasePress):
    """
    Wrapper press that keeps whole chunks of tokens instead of individual tokens.

    The wrapped ScorerPress first computes per-token scores over the full sequence.
    Tokens are then grouped into fixed-length chunks, each chunk is scored by the
    mean of its (head-summed) token scores, and only the highest-scoring chunks are
    kept — preserving the semantic coherence of contiguous spans.

    This method was proposed in
    ChunkKV: Semantic-Preserving KV Cache Compression for Efficient Long-Context LLM Inference
    https://arxiv.org/abs/2502.00299

    Parameters
    ----------
    press : ScorerPress
        Underlying press used to compute per-token importance scores.
    chunk_length : int, default 20
        Number of consecutive tokens grouped into one chunk.
    """

    press: ScorerPress
    chunk_length: int = 20

    def __post_init__(self):
        assert isinstance(self.press, ScorerPress), "ChunkKVPress requires a ScorerPress as input"

    @property
    def compression_ratio(self) -> float:
        # The effective compression ratio is delegated to the wrapped ScorerPress.
        return self.press.compression_ratio

    @compression_ratio.setter
    def compression_ratio(self, value: float):
        self.press.compression_ratio = value

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Keep the top-scoring chunks of the KV cache.

        Returns the compressed ``(keys, values)``, both of shape
        ``(batch_size, num_kv_heads, n_kept_tokens, head_dim)``. The kept token
        count is a multiple of ``chunk_length`` (plus an optional partial tail),
        so the realized ratio only approximates ``compression_ratio``.
        """
        if self.press.compression_ratio == 0:
            return keys, values

        assert attentions is None, "ChunkKVPress does not support attentions."

        kv_len = keys.shape[2]

        # 1. Per-token scores from the wrapped press
        global_scores = self.press.score(
            module,
            hidden_states,
            keys,
            values,
            attentions,
            kwargs,
        )

        # 2. Split the sequence into complete chunks plus an optional partial tail
        num_complete_chunks = kv_len // self.chunk_length
        remaining_tokens = kv_len % self.chunk_length

        # Sequence shorter than one chunk: fall back to plain token-level compression
        if num_complete_chunks == 0:
            return self.press.compress(module, hidden_states, keys, values, attentions, kwargs)

        # Score each complete chunk by the mean (over tokens) of the head-summed scores.
        # (No empty-chunk branch is needed here: num_complete_chunks >= 1 at this point.)
        main_scores = global_scores[..., : num_complete_chunks * self.chunk_length]
        main_chunk_scores = main_scores.sum(dim=1).view(-1, num_complete_chunks, self.chunk_length)
        main_chunk_scores = main_chunk_scores.mean(dim=-1)

        # The partial tail chunk, if any, is scored the same way
        if remaining_tokens > 0:
            remaining_scores = global_scores[..., -remaining_tokens:]
            remaining_chunk_score = remaining_scores.sum(dim=1).mean(dim=-1, keepdim=True)
            chunk_scores = torch.cat([main_chunk_scores, remaining_chunk_score], dim=-1)
        else:
            chunk_scores = main_chunk_scores

        # 3. Number of chunks to keep (always at least one)
        n_chunks = num_complete_chunks + (remaining_tokens > 0)
        n_chunks_kept = max(1, int(n_chunks * (1 - self.press.compression_ratio)))
        top_chunks = chunk_scores.topk(n_chunks_kept, dim=-1)

        # 4. Expand chunk indices to token indices.
        # NOTE(review): selection uses the first batch element only, so the same
        # tokens are kept for every sequence in the batch — assumes batch size 1.
        indices = []
        for chunk_idx in top_chunks.indices[0]:
            if chunk_idx < num_complete_chunks:
                # Complete chunk: chunk_length consecutive positions
                start_idx = chunk_idx * self.chunk_length
                chunk_indices = torch.arange(start_idx, start_idx + self.chunk_length, device=keys.device)
            else:
                # Partial tail chunk: covers the remainder of the sequence
                chunk_indices = torch.arange(num_complete_chunks * self.chunk_length, kv_len, device=keys.device)
            indices.append(chunk_indices)

        indices = torch.cat(indices).sort()[0]
        indices = indices.view(1, 1, -1, 1).expand(keys.shape[0], keys.shape[1], -1, module.head_dim)

        # 5. Gather the selected keys and values
        keys = keys.gather(2, indices).contiguous()
        values = values.gather(2, indices).contiguous()

        return keys, values

kvpress/presses/criticalkv_press.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from transformers.models.llama.modeling_llama import repeat_kv
99

1010
from kvpress.presses.base_press import BasePress
11-
from kvpress.presses.scorer_press import ScorerPress
1211
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
12+
from kvpress.presses.scorer_press import ScorerPress
1313

1414
logger = logging.getLogger(__name__)
1515

@@ -49,7 +49,7 @@ def vwl1norm(values, module):
4949
# Future kernel fusion optimization could eliminate this intermediate variables to enhance performance.
5050
head_WoV_norm_list = []
5151
for head in range(V.size(1)):
52-
head_WoV = V[: , head, : , ...].matmul(Wo[head, ...].unsqueeze(0))
52+
head_WoV = V[:, head, :, ...].matmul(Wo[head, ...].unsqueeze(0))
5353
head_WoV_norm = torch.norm(head_WoV, p=1, dim=-1)
5454
head_WoV_norm_list.append(head_WoV_norm)
5555

kvpress/presses/duo_attention_press.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from io import StringIO
5-
from dataclasses import dataclass, field
64
from contextlib import contextmanager
5+
from dataclasses import dataclass, field
6+
from io import StringIO
77

8-
import torch
9-
import requests # type: ignore[import-untyped]
108
import numpy as np
9+
import requests # type: ignore[import-untyped]
10+
import torch
1111

1212
from kvpress.presses.base_press import BasePress
1313

14-
1514
PATTERNS_DICT = {
1615
"togethercomputer/Llama-2-7B-32K-Instruct": "Llama-2-7B-32K-Instruct/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501
1716
"gradientai//Llama-3-8B-Instruct-Gradient-1048k": "Llama-3-8B-Instruct-Gradient-1048k/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501

tests/default_presses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55

66
from kvpress import (
7+
DuoAttentionPress,
78
ExpectedAttentionPress,
89
KnormPress,
910
RandomPress,
@@ -12,7 +13,6 @@
1213
StreamingLLMPress,
1314
ThinKPress,
1415
TOVAPress,
15-
DuoAttentionPress,
1616
)
1717

1818

tests/presses/test_duo_attention_press.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from kvpress.presses.duo_attention_press import DuoAttentionPress, PATTERNS_DICT
1+
from kvpress.presses.duo_attention_press import PATTERNS_DICT, DuoAttentionPress
22

33

44
def test_load_attention_pattern():

tests/presses/test_presses.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,17 @@
88
from transformers import DynamicCache
99

1010
from kvpress import (
11-
CriticalKVPress,
12-
CriticalAdaKVPress,
1311
AdaKVPress,
12+
ChunkKVPress,
1413
ChunkPress,
1514
ComposedPress,
15+
CriticalAdaKVPress,
16+
CriticalKVPress,
1617
KeyRerotationPress,
1718
KnormPress,
1819
ObservedAttentionPress,
1920
ScorerPress,
21+
SnapKVPress,
2022
ThinKPress,
2123
)
2224
from tests.default_presses import default_presses
@@ -43,9 +45,22 @@ def test_chunk_press(unit_test_model): # noqa: F811
4345
assert cache.get_seq_length() == 128
4446

4547

48+
def test_chunkkv_press(unit_test_model):  # noqa: F811
    """A 0.5 compression ratio should halve the 256-token cache for every chunk length."""
    scorer = SnapKVPress(compression_ratio=0.5)
    for chunk_length in (2, 4, 8, 128):
        wrapped_press = ChunkKVPress(press=scorer, chunk_length=chunk_length)
        with wrapped_press(unit_test_model):
            token_ids = torch.randint(0, 1024, (1, 256))
            cache = DynamicCache()
            unit_test_model(token_ids, past_key_values=cache).past_key_values
            assert cache.get_seq_length() == 128
57+
58+
4659
@pytest.mark.parametrize("press_dict", default_presses)
47-
@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress, AdaKVPress, ChunkPress,
48-
CriticalKVPress, CriticalAdaKVPress])
60+
@pytest.mark.parametrize(
61+
"wrapper_press",
62+
[None, ComposedPress, KeyRerotationPress, AdaKVPress, ChunkPress, CriticalKVPress, CriticalAdaKVPress],
63+
)
4964
def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811
5065
cls = press_dict["cls"]
5166
for kwargs in press_dict["kwargs"]:

0 commit comments

Comments
 (0)