Skip to content

Commit 50c2ae5

Browse files
authored
Add KVzapPress and ThresholdPress (#171)
1 parent a86861a commit 50c2ae5

24 files changed

+1079
-42
lines changed

README.md

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-blue)](https://huggingface.co/spaces/nvidia/kvpress)
55
[![Blog post](https://img.shields.io/badge/🤗%20Hugging%20Face-Blog-blue)](https://huggingface.co/blog/nvidia/kvpress)
66
[![Hugging Face Leaderboard](https://img.shields.io/badge/🤗%20HuggingFace-Leaderboard-orange)](https://huggingface.co/spaces/nvidia/kvpress-leaderboard)
7-
[![Paper](https://img.shields.io/badge/📄%20arXiv-Paper-red)](https://arxiv.org/abs/2510.00636v1)
7+
[![arXiv](https://img.shields.io/badge/arXiv-2510.00636-b31b1b.svg)](https://arxiv.org/abs/2510.00636v1)
8+
89

910
![kvpress](kvpress.jpg)
1011

@@ -54,10 +55,8 @@ KVPress provides a set of "presses" that compress the KV cache during the prefil
5455
from transformers import pipeline
5556
from kvpress import ExpectedAttentionPress
5657

57-
device = "cuda:0"
58-
model = "meta-llama/Llama-3.1-8B-Instruct"
59-
model_kwargs = {"attn_implementation": "flash_attention_2"}
60-
pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
58+
model = "Qwen/Qwen3-8B"
59+
pipe = pipeline("kv-press-text-generation", model=model, device_map="auto", dtype="auto")
6160

6261
context = "A very long text you want to compress once and for all"
6362
question = "\nA question about the compressed context" # optional
@@ -71,7 +70,7 @@ In the snippet above, the compression is only applied on the context tokens so t
7170
<details><summary>
7271
Decoding Compression
7372
</summary>
74-
By default, KVPress applies compression during the pre-filling phase. As a new (experimental) feature, we now support decoding compression via the `DecodingPress` wrapper. `DecodingPress` compresses the KV cache periodically during token generation, optionally maintaining a buffer of recent hidden states. `DecodingPress` supports the following parameters:
73+
By default, KVPress applies compression during the prefilling phase. As a new (experimental) feature, we now support decoding compression via the `DecodingPress` wrapper. `DecodingPress` compresses the KV cache periodically during token generation, optionally maintaining a buffer of recent hidden states. `DecodingPress` supports the following parameters:
7574

7675
- `base_press`: Any ScorerPress (e.g., `KNormPress`, `CriticalKVPress`)
7776
- `compression_interval`: Steps between compressions (default: 10)
@@ -122,7 +121,7 @@ Several presses inherit from `ScorerPress` ([source](kvpress/presses/scorer_pres
122121
- `ExpectedAttentionPress` ([source](kvpress/presses/expected_attention_press.py), [notebook](notebooks/expected_attention.ipynb)): expected attention weight during the generation phase
123122
- `StreamingLLMPress` ([source](kvpress/presses/streaming_llm_press.py), [paper](https://arxiv.org/abs/2309.17453)): keep only the initial and recent tokens
124123
- `TOVAPress` ([source](kvpress/presses/tova_press.py), [paper](https://arxiv.org/abs/2401.06104)): attention weight of the last query averaged across heads
125-
- `ObservedAttentionPress` ([source](kvpress/presses/observed_attention_press.py), [paper](https://arxiv.org/abs/2306.14048)): average attention weight observed during in pre-filling phase
124+
- `ObservedAttentionPress` ([source](kvpress/presses/observed_attention_press.py), [paper](https://arxiv.org/abs/2306.14048)): average attention weight observed during in prefilling phase
126125
- `QFilterPress` ([source](kvpress/presses/qfilter_press.py), [paper](https://arxiv.org/abs/2503.02812)): project the Key representations on the main SVD component of the Query vectors to approximate the attention scores.
127126
- `PyramidKVPress` ([source](kvpress/presses/pyramidkv_press.py), [paper](https://arxiv.org/abs/2406.02069)): maintain pyramid-like cache sizes, allocating more cache budget to lower layers and less to higher layers
128127
- `LagKVPress` ([source](kvpress/presses/lagkv_press.py), [paper](https://arxiv.org/abs/2504.04704)): leverage on the KV lag-relative information to compress. It's query free, attention-weight free, and flash-attention compatible.
@@ -131,9 +130,10 @@ Several presses inherit from `ScorerPress` ([source](kvpress/presses/scorer_pres
131130
- `LeverageScorePress` ([source](kvpress/presses/leverage_press.py), [paper](https://arxiv.org/abs/2507.08143)): evicts tokens based on approximate statistical leverage (i.e we preserve outliers in the key space).
132131
- `CompactorPress` ([source](kvpress/presses/compactor_press.py), [paper](https://arxiv.org/abs/2507.08143)): blends `NonCausalAttnPress` and `LeverageScorePress` based on the compression_ratio.
133132
- `CURPress` ([source](kvpress/presses/cur_press.py), [paper](https://arxiv.org/abs/2509.15038)): prune keys and values based on the CUR decomposition using approximate leverage scores.
133+
- `KVzapPress` ([source](kvpress/presses/kvzap/kvzap_press.py), [paper](https://arxiv.org/abs/2601.07891), [training](kvzap)): approximate KVzip+ using a fast surrogate model. To be used in conjunction with the `ThresholdPress`.
134134

135135
Some presses rely on a different logic:
136-
- `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/pdf/2407.21018)): compress the dimensions of the keys based on the channel attention score on the last queries
136+
- `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/abs/2407.21018)): compress the dimensions of the keys based on the channel attention score on the last queries
137137
- `SimLayerKVPress` ([source](kvpress/presses/simlayerkv_press.py), [paper](https://arxiv.org/abs/2410.13846)): identify "lazy" layers, and apply the StreamingLLM approach to them
138138
- `DuoAttentionPress` ([source](kvpress/presses/duo_attention_press.py), [paper](https://arxiv.org/abs/2410.10819)): split heads into retrieval heads (no compression) and streaming heads (StreamingLLM approach)
139139
- `FinchPress` ([source](kvpress/presses/finch_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): similar to SnapKV with a dynamic window size and key value re-rotation
@@ -148,8 +148,9 @@ Finally we provide wrapper presses that can be combined with other presses:
148148
- `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can yield to more uniform compression across long sequences
149149
- `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.
150150
- `BlockPress` ([source](kvpress/presses/block_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments input sequence into non-overlapping blocks and compresses iteratively.
151-
- `DecodingPress` ([source](kvpress/presses/decoding_press.py)): Allows for compression during decoding, see decoding section in this README.
152-
- `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): Allows to compress both during prefilling and during decoding.
151+
- `DecodingPress` ([source](kvpress/presses/decoding_press.py)): allows for compression during decoding, see decoding section in this README.
152+
- `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): allows to compress both during prefilling and during decoding.
153+
- `ThresholdPress` ([source](kvpress/presses/threshold_press.py)): evict keys and values with scores below a given threshold of any `ScorerPress` instead of relying on top-k scores. Support both prefilling and decoding (if decoding=True).
153154

154155
For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
155156

@@ -164,11 +165,6 @@ Please refer to the [evaluation](evaluation/README.md) directory in this repo fo
164165

165166
Below we report the average performance on the RULER dataset with 4k context length for different presses, from our [![Hugging Face Leaderboard](https://img.shields.io/badge/🤗%20HuggingFace-Leaderboard-orange)](https://huggingface.co/spaces/nvidia/kvpress-leaderboard)
166167

167-
<p>
168-
<img src="leaderboard_plot_score.png" alt="Leaderboard">
169-
</p>
170-
171-
172168
## Quantization
173169

174170
We support KV cache quantization through the transformers `QuantizedCache` class (see [HF blog post](https://huggingface.co/blog/kv-cache-quantization#how-to-use-quantized-kv-cache-in-%F0%9F%A4%97-transformers)). To use it, simply pass a cache object to your pipeline:
@@ -242,7 +238,7 @@ Memory usage should be reduced by around `compression_ratio * kv_cache_size`. As
242238

243239
### How does a press work ? </summary>
244240

245-
A press registers a forward hook (`press.forward_hook` method) to each attention layer during the pre-filling phase. Registration can be applied using the press as a context manager (`press.__call__` method):
241+
A press registers a forward hook (`press.forward_hook` method) to each attention layer during the prefilling phase. Registration can be applied using the press as a context manager (`press.__call__` method):
246242

247243
```python
248244
import torch

evaluation/benchmarks/aime25/calculate_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
def extract_boxed(pred_answer):
88
try:
9-
return str(pred_answer.split("boxed{")[1].split("}")[0])
9+
return str(pred_answer.split("boxed{")[-1].split("}")[0])
1010
except IndexError:
1111
return None
1212

evaluation/benchmarks/longbench/calculate_metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ def scorer(dataset, predictions, answers, all_classes):
6060
for prediction, ground_truths in zip(predictions, answers):
6161
score = 0.0
6262
if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
63-
prediction = prediction.lstrip("\n").split("\n")[0]
63+
prediction = prediction.lstrip().split("\n")[0]
6464
for ground_truth in ground_truths:
65-
score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
65+
score = max(score, dataset2metric[dataset](prediction.lstrip(), ground_truth, all_classes=all_classes))
6666
total_score += score
6767
return round(100 * total_score / len(predictions), 2)
6868

evaluation/evaluate.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from evaluate_registry import DATASET_REGISTRY, PRESS_REGISTRY, SCORER_REGISTRY
1919
from fire import Fire
2020
from tqdm import tqdm
21-
from transformers import Pipeline, pipeline
21+
from transformers import FineGrainedFP8Config, Pipeline, pipeline
2222

2323
from kvpress import (
2424
ComposedPress,
@@ -28,6 +28,7 @@
2828
ObservedAttentionPress,
2929
ScorerPress,
3030
ThinKPress,
31+
ThresholdPress,
3132
)
3233

3334
logger = logging.getLogger(__name__)
@@ -45,6 +46,7 @@ class EvaluationConfig:
4546
press_name: str = "knorm"
4647
compression_ratio: float = 1.0
4748
key_channel_compression_ratio: Optional[float] = None
49+
threshold: Optional[float] = None
4850

4951
# Dataset and generation parameters
5052
fraction: float = 1.0
@@ -71,6 +73,9 @@ class EvaluationConfig:
7173
# For reproducibility
7274
seed: int = 42
7375

76+
# Quantization
77+
fp8: bool = False
78+
7479
def __post_init__(self):
7580
"""Validate configuration after initialization."""
7681
# Validate dataset
@@ -85,11 +90,6 @@ def __post_init__(self):
8590
logger.info("Using 'no_press' configuration. Overriding compression_ratio to 0.0")
8691
self.compression_ratio = 0.0
8792

88-
# Validate compression ratios
89-
assert (
90-
0.0 <= self.compression_ratio <= 1.0
91-
), f"compression_ratio must be between 0.0 and 1.0, got {self.compression_ratio}"
92-
9393
# Only validate key_channel_compression_ratio if it's not None
9494
if self.key_channel_compression_ratio is not None:
9595
assert (
@@ -115,8 +115,6 @@ def get_results_dir(self, output_dir: Path) -> Path:
115115
----------
116116
output_dir : Path
117117
The output directory path
118-
press
119-
The press instance to check for ThinKPress components
120118
121119
Returns
122120
-------
@@ -132,6 +130,8 @@ def get_results_dir(self, output_dir: Path) -> Path:
132130
f"{self.compression_ratio:.2f}",
133131
]
134132

133+
if self.threshold is not None:
134+
components[-1] = f"{self.threshold:.2f}"
135135
if self.fraction < 1.0:
136136
components.append(f"fraction{self.fraction:.3f}")
137137
if self.max_context_length is not None:
@@ -256,6 +256,10 @@ def _setup_press(self):
256256
if isinstance(press, DuoAttentionPress):
257257
press.head_compression_ratio = compression_ratio
258258
logger.info(f"Set DuoAttentionPress head_compression_ratio to {compression_ratio}")
259+
elif isinstance(press, ThresholdPress):
260+
assert self.config.threshold is not None, "threshold must be set for ThresholdPress"
261+
press.threshold = self.config.threshold
262+
logger.info(f"Set ThresholdPress threshold to {press.threshold}")
259263
elif isinstance(press, ComposedPress):
260264
for ps in press.presses:
261265
if isinstance(ps, ThinKPress):
@@ -349,6 +353,11 @@ def _setup_model_pipeline(self):
349353
logger.info(f"No device specified, auto-detected device: {device}")
350354

351355
model_kwargs = self.config.model_kwargs or {}
356+
357+
if self.config.fp8:
358+
model_kwargs["quantization_config"] = FineGrainedFP8Config()
359+
logger.info("FP8 quantization enabled.")
360+
352361
if isinstance(self.press, ObservedAttentionPress):
353362
model_kwargs["attn_implementation"] = "eager"
354363
logger.info("ObservedAttentionPress detected, setting attn_implementation to 'eager'.")

evaluation/evaluate_config.yaml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
77
dataset: "ruler" # see DATASET_REGISTRY in evaluate_registry.py
88
data_dir: "4096" # Subdirectory of the dataset (if applicable) else leave "null"
99

10-
press_name: "knorm" # see PRESS_REGISTRY in evaluate_registry.py
11-
compression_ratio: 0.5 # Compression ratio for the press (0.0 to 1.0)
12-
key_channel_compression_ratio: null # For ThinKPress and ComposedPress (0.0 to 1.0)
10+
press_name: "knorm" # see PRESS_REGISTRY in evaluate_registry.py
11+
compression_ratio: 0.5 # Compression ratio for the press (0.0 to 1.0)
12+
key_channel_compression_ratio: null # For ThinKPress and ComposedPress (0.0 to 1.0)
13+
threshold: null # For ThresholdPress
1314

1415
fraction: 1.0 # Fraction of dataset to evaluate (0.0 to 1.0), for quick testing
1516
max_new_tokens: null # Maximum new tokens to generate (null = use dataset default)
@@ -18,6 +19,7 @@ query_aware: false # Whether to include question
1819
needle_depth: null # Depth (int or list of ints) percentage of the needle in the haystack (0 to 100), only for needle_in_haystack dataset
1920

2021
device: null # Device to use (null = auto-detect, "cuda:0", "cpu", etc.)
22+
fp8: false # Whether to use FP8 quantization (FineGrainedFP8Config() from transformers)
2123

2224
# You can add any model kwargs here.
2325
model_kwargs:

evaluation/evaluate_registry.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,19 @@
2626
FinchPress,
2727
KeyDiffPress,
2828
KnormPress,
29+
KVzapPress,
2930
KVzipPress,
3031
ObservedAttentionPress,
3132
PyramidKVPress,
3233
QFilterPress,
3334
RandomPress,
3435
SnapKVPress,
3536
StreamingLLMPress,
37+
ThresholdPress,
3638
ThinKPress,
3739
TOVAPress,
40+
CURPress,
41+
LagKVPress,
3842
)
3943

4044
# These dictionaries define the available datasets, scorers, and KVPress methods for evaluation.
@@ -67,22 +71,26 @@
6771

6872

6973
PRESS_REGISTRY = {
70-
"adakv_expected_attention": AdaKVPress(ExpectedAttentionPress()),
71-
"adakv_expected_attention_e2": AdaKVPress(ExpectedAttentionPress(epsilon=1e-2)),
7274
"adakv_snapkv": AdaKVPress(SnapKVPress()),
7375
"block_keydiff": BlockPress(press=KeyDiffPress(), block_size=128),
7476
"chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
7577
"critical_adakv_expected_attention": CriticalAdaKVPress(ExpectedAttentionPress(use_vnorm=False)),
7678
"critical_adakv_snapkv": CriticalAdaKVPress(SnapKVPress()),
7779
"critical_expected_attention": CriticalKVPress(ExpectedAttentionPress(use_vnorm=False)),
7880
"critical_snapkv": CriticalKVPress(SnapKVPress()),
81+
"cur": CURPress(),
7982
"duo_attention": DuoAttentionPress(),
8083
"duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True),
81-
"expected_attention": ExpectedAttentionPress(),
84+
"expected_attention": AdaKVPress(ExpectedAttentionPress(epsilon=1e-2)),
8285
"finch": FinchPress(),
8386
"keydiff": KeyDiffPress(),
8487
"kvzip": KVzipPress(),
8588
"kvzip_plus": KVzipPress(kvzip_plus_normalization=True),
89+
"kvzap_linear": ThresholdPress(press=KVzapPress(model_type="linear")),
90+
"kvzap_mlp": ThresholdPress(press=KVzapPress(model_type="mlp")),
91+
"kvzap_mlp_head": KVzapPress(model_type="mlp"),
92+
"kvzap_mlp_layer": AdaKVPress(KVzapPress(model_type="mlp")),
93+
"lagkv": LagKVPress(),
8694
"knorm": KnormPress(),
8795
"observed_attention": ObservedAttentionPress(),
8896
"pyramidkv": PyramidKVPress(),

evaluation/leaderboard.sh

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# Script to run the leaderboard evaluation on 4 GPUs
5+
dataset="ruler"
6+
data_dir="4096"
7+
model="Qwen/Qwen3-8B"
8+
output_dir="./results_lb"
9+
10+
# Loop 1: presses not requiring to include the questions in the compression
11+
press_names=("random" "knorm" "snapkv" "expected_attention" "streaming_llm" "tova" "observed_attention" "qfilter" "pyramidkv" "lagkv" "keydiff" "adakv_compactor" "cur" "duo_attention" "duo_attention_on_the_fly" "kvzip")
12+
13+
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name no_press --compression_ratio 0.00 --output_dir $output_dir --device "cuda:0"
14+
15+
for press in "${press_names[@]}"; do
16+
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.25 --output_dir $output_dir --device "cuda:0" &
17+
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.50 --output_dir $output_dir --device "cuda:1" &
18+
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.75 --output_dir $output_dir --device "cuda:2" &
19+
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.875 --output_dir $output_dir --device "cuda:3" &
20+
wait
21+
done
22+
23+
# Use -3, -4, -5, -6 for Qwen3-8B and -6, -7, -8, -9 for Llama-3.1-8B-Instruct
24+
for press in "kvzap_linear" "kvzap_mlp"; do
25+
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --threshold -3 --output_dir $output_dir --device "cuda:0" &
26+
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --threshold -4 --output_dir $output_dir --device "cuda:1" &
27+
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --threshold -5 --output_dir $output_dir --device "cuda:2" &
28+
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --threshold -6 --output_dir $output_dir --device "cuda:3" &
29+
wait
30+
done
31+
32+
# Loop 2: presses requiring to compress questions
33+
press_names=("snapkv" "adakv_snapkv" "finch" "chunkkv")
34+
for press in "${press_names[@]}"; do
35+
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.25 --output_dir $output_dir --device "cuda:0" --query_aware &
36+
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.50 --output_dir $output_dir --device "cuda:1" --query_aware &
37+
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.75 --output_dir $output_dir --device "cuda:2" --query_aware &
38+
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.875 --output_dir $output_dir --device "cuda:3" --query_aware &
39+
wait
40+
done

0 commit comments

Comments
 (0)