Commit 595c972

Rename ThresholdPress to DMSPress
Signed-off-by: SimJeg <sjegou@nvidia.com>
1 parent 50c2ae5 commit 595c972

12 files changed: +34 -33 lines changed

README.md

Lines changed: 2 additions & 2 deletions
@@ -130,7 +130,7 @@ Several presses inherit from `ScorerPress` ([source](kvpress/presses/scorer_pres
 - `LeverageScorePress` ([source](kvpress/presses/leverage_press.py), [paper](https://arxiv.org/abs/2507.08143)): evicts tokens based on approximate statistical leverage (i.e we preserve outliers in the key space).
 - `CompactorPress` ([source](kvpress/presses/compactor_press.py), [paper](https://arxiv.org/abs/2507.08143)): blends `NonCausalAttnPress` and `LeverageScorePress` based on the compression_ratio.
 - `CURPress` ([source](kvpress/presses/cur_press.py), [paper](https://arxiv.org/abs/2509.15038)): prune keys and values based on the CUR decomposition using approximate leverage scores.
-- `KVzapPress` ([source](kvpress/presses/kvzap/kvzap_press.py), [paper](https://arxiv.org/abs/2601.07891), [training](kvzap)): approximate KVzip+ using a fast surrogate model. To be used in conjunction with the `ThresholdPress`.
+- `KVzapPress` ([source](kvpress/presses/kvzap/kvzap_press.py), [paper](https://arxiv.org/abs/2601.07891), [training](kvzap)): approximate KVzip+ using a fast surrogate model. To be used in conjunction with the `DMSPress`.
 
 Some presses rely on a different logic:
 - `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/abs/2407.21018)): compress the dimensions of the keys based on the channel attention score on the last queries
@@ -150,7 +150,7 @@ Finally we provide wrapper presses that can be combined with other presses:
 - `BlockPress` ([source](kvpress/presses/block_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments input sequence into non-overlapping blocks and compresses iteratively.
 - `DecodingPress` ([source](kvpress/presses/decoding_press.py)): allows for compression during decoding, see decoding section in this README.
 - `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): allows to compress both during prefilling and during decoding.
-- `ThresholdPress` ([source](kvpress/presses/threshold_press.py)): evict keys and values with scores below a given threshold of any `ScorerPress` instead of relying on top-k scores. Support both prefilling and decoding (if decoding=True).
+- `DMSPress` ([source](kvpress/presses/dms_press.py), [paper](https://arxiv.org/abs/2506.05345)): evict keys and values with scores below a given threshold of any `ScorerPress` instead of relying on top-k scores. Support both prefilling and decoding (if decoding=True).
 
 For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)

evaluation/evaluate.py

Lines changed: 4 additions & 4 deletions
@@ -28,7 +28,7 @@
     ObservedAttentionPress,
     ScorerPress,
     ThinKPress,
-    ThresholdPress,
+    DMSPress,
 )
 
 logger = logging.getLogger(__name__)
@@ -256,10 +256,10 @@ def _setup_press(self):
         if isinstance(press, DuoAttentionPress):
             press.head_compression_ratio = compression_ratio
             logger.info(f"Set DuoAttentionPress head_compression_ratio to {compression_ratio}")
-        elif isinstance(press, ThresholdPress):
-            assert self.config.threshold is not None, "threshold must be set for ThresholdPress"
+        elif isinstance(press, DMSPress):
+            assert self.config.threshold is not None, "threshold must be set for DMSPress"
             press.threshold = self.config.threshold
-            logger.info(f"Set ThresholdPress threshold to {press.threshold}")
+            logger.info(f"Set DMSPress threshold to {press.threshold}")
         elif isinstance(press, ComposedPress):
             for ps in press.presses:
                 if isinstance(ps, ThinKPress):

evaluation/evaluate_config.yaml

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ data_dir: "4096" # Subdirectory of the dataset
 press_name: "knorm" # see PRESS_REGISTRY in evaluate_registry.py
 compression_ratio: 0.5 # Compression ratio for the press (0.0 to 1.0)
 key_channel_compression_ratio: null # For ThinKPress and ComposedPress (0.0 to 1.0)
-threshold: null # For ThresholdPress
+threshold: null # For DMSPress
 
 fraction: 1.0 # Fraction of dataset to evaluate (0.0 to 1.0), for quick testing
 max_new_tokens: null # Maximum new tokens to generate (null = use dataset default)

evaluation/evaluate_registry.py

Lines changed: 3 additions & 3 deletions
@@ -34,7 +34,7 @@
     RandomPress,
     SnapKVPress,
     StreamingLLMPress,
-    ThresholdPress,
+    DMSPress,
     ThinKPress,
     TOVAPress,
     CURPress,
@@ -86,8 +86,8 @@
     "keydiff": KeyDiffPress(),
     "kvzip": KVzipPress(),
     "kvzip_plus": KVzipPress(kvzip_plus_normalization=True),
-    "kvzap_linear": ThresholdPress(press=KVzapPress(model_type="linear")),
-    "kvzap_mlp": ThresholdPress(press=KVzapPress(model_type="mlp")),
+    "kvzap_linear": DMSPress(press=KVzapPress(model_type="linear")),
+    "kvzap_mlp": DMSPress(press=KVzapPress(model_type="mlp")),
     "kvzap_mlp_head": KVzapPress(model_type="mlp"),
     "kvzap_mlp_layer": AdaKVPress(KVzapPress(model_type="mlp")),
     "lagkv": LagKVPress(),

kvpress/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@
 from kvpress.presses.snapkv_press import SnapKVPress
 from kvpress.presses.streaming_llm_press import StreamingLLMPress
 from kvpress.presses.think_press import ThinKPress
-from kvpress.presses.threshold_press import ThresholdPress
+from kvpress.presses.dms_press import DMSPress
 from kvpress.presses.tova_press import TOVAPress
 
 # Patch the attention functions to support head-wise compression
@@ -80,5 +80,5 @@
     "LeverageScorePress",
     "NonCausalAttnPress",
     "KVzapPress",
-    "ThresholdPress",
+    "DMSPress",
 ]

kvpress/pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@
 from kvpress.presses.finch_press import FinchPress
 from kvpress.presses.key_rerotation_press import KeyRerotationPress
 from kvpress.presses.prefill_decoding_press import PrefillDecodingPress
-from kvpress.presses.threshold_press import ThresholdPress
+from kvpress.presses.dms_press import DMSPress
 
 logger = logging.getLogger(__name__)
 
@@ -224,7 +224,7 @@ def _forward(
 
         # We only perform decoding compression if the press is a decoding or prefill decoding press
         perform_decoding_compression = press is not None and isinstance(press, (DecodingPress, PrefillDecodingPress))
-        if isinstance(press, ThresholdPress):
+        if isinstance(press, DMSPress):
             perform_decoding_compression = press.decoding
         with press(self.model) if perform_decoding_compression else contextlib.nullcontext():
             # Greedy decoding for each question
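
The hunk above decides whether the press stays active while answers are generated: decoding presses always are, whereas a `DMSPress` defers to its own `decoding` flag. The following is a minimal, self-contained sketch of that gating pattern. `ToyPress`, `ToyDecodingPress`, `ToyDMSPress` and `run_decoding` are hypothetical stand-ins written for illustration, not the kvpress classes or pipeline code.

```python
import contextlib
from dataclasses import dataclass, field


@dataclass
class ToyPress:
    """Hypothetical press: calling it returns a context that marks it active."""

    active: bool = field(default=False, init=False)

    @contextlib.contextmanager
    def __call__(self, model):
        self.active = True
        try:
            yield self
        finally:
            self.active = False


class ToyDecodingPress(ToyPress):
    """Hypothetical press that always compresses during decoding."""


@dataclass
class ToyDMSPress(ToyPress):
    """Hypothetical threshold press with an opt-in decoding flag."""

    decoding: bool = False


def run_decoding(press, model="model"):
    """Return True if the press was active while 'decoding' ran."""
    perform = isinstance(press, ToyDecodingPress)
    if isinstance(press, ToyDMSPress):
        # The threshold press decides for itself via its flag.
        perform = press.decoding
    with press(model) if perform else contextlib.nullcontext():
        return press.active if press is not None else False
```

Under these assumptions, a plain `ToyPress` or a `ToyDMSPress` left at `decoding=False` runs generation without compression, and setting `decoding=True` is the only change needed to enable it.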
Lines changed: 3 additions & 2 deletions
@@ -13,11 +13,12 @@
 
 
 @dataclass
-class ThresholdPress(BasePress):
+class DMSPress(BasePress):
     """
+    Based on Dynamic Memory Sparsification (DMS, https://arxiv.org/abs/2506.05345) inference.
     Wraps a ScorerPress and evicts keys/values with scores below a given threshold.
 
-    Unlike most presses that use a fixed compression_ratio, ThresholdPress uses a score threshold
+    Unlike most presses that use a fixed compression_ratio, DMSPress uses a score threshold
     to determine which KV pairs to evict. This allows for adaptive compression where the actual
     compression ratio depends on the input content.
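
The docstring above contrasts threshold-based eviction with a fixed compression ratio. A small pure-Python sketch may make the difference concrete. The function names and the toy scores are assumptions made for illustration, and keeping scores equal to the threshold is also an assumption; none of this is the kvpress implementation.

```python
from typing import List


def keep_mask_threshold(scores: List[float], threshold: float) -> List[bool]:
    """Keep every KV pair whose score is not below the threshold."""
    return [s >= threshold for s in scores]


def keep_mask_top_k(scores: List[float], compression_ratio: float) -> List[bool]:
    """Keep the highest-scoring pairs so that a fixed fraction is evicted."""
    n_kept = int(len(scores) * (1 - compression_ratio))
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    kept = set(ranked[:n_kept])
    return [i in kept for i in range(len(scores))]


def compression_ratio(mask: List[bool]) -> float:
    """Fraction of KV pairs that were evicted."""
    return 1 - sum(mask) / len(mask)


# Two toy inputs with different score distributions (made-up values).
low_scores = [-6.0, -5.0, -1.0, 0.5]
high_scores = [-3.0, -2.0, -1.0, 0.5]

# Threshold eviction adapts to the input: 0.5 for low_scores, 0.0 for high_scores.
adaptive = [compression_ratio(keep_mask_threshold(s, -4.0)) for s in (low_scores, high_scores)]

# Top-k eviction removes the same fraction regardless of the scores: 0.5 for both.
fixed = [compression_ratio(keep_mask_top_k(s, 0.5)) for s in (low_scores, high_scores)]
```

With a threshold, an input whose scores are all high loses nothing, while top-k eviction discards half of the pairs in both cases.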

kvpress/presses/kvzap_press.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ class KVzapPress(ScorerPress):
     KVzap (https://arxiv.org/abs/2601.07891) is a fast approximation of KVzip that works
     in both prefilling and decoding. It applies a lightweight surrogate model to the hidden
     states to predict importance scores for every KV pair.
-    KVzapPress is designed to be used in conjunction with the ThresholdPress
+    KVzapPress is designed to be used in conjunction with the DMSPress
     model_type can be "linear" or "mlp".
     """

kvzap/README.md

Lines changed: 4 additions & 4 deletions
@@ -7,16 +7,16 @@
 
 ## Usage
 
-KVzap is designed to be used by combining the `KVzapPress` and the `ThresholdPress` from kvpress:
+KVzap is designed to be used by combining the `KVzapPress` and the `DMSPress` from kvpress:
 
 ```python
 import requests
 from transformers import pipeline
-from kvpress import KVzapPress, ThresholdPress
+from kvpress import KVzapPress, DMSPress
 
 model = "Qwen/Qwen3-8B"
 pipe = pipeline("kv-press-text-generation", model=model, device_map="auto", dtype="auto")
-press = ThresholdPress(KVzapPress(model_type="mlp"), threshold=-4)
+press = DMSPress(KVzapPress(model_type="mlp"), threshold=-4)
 
 # Prefilling compression only, thinking disabled
 press.decoding = False
@@ -32,7 +32,7 @@ answer = pipe(prompt, press=press, enable_thinking=True, max_new_tokens=2000)["a
 print(f"Compression ratio: {press.compression_ratio:.2%}\nAnswer: {answer}")
 ```
 
-The `KVzapPress` inherits from the `ScorerPress` class and only predicts the scores for every KV pair. The `ThresholdPress` then prunes the KV pairs with a score below a given threshold, rather than using a fixed compression ratio.
+The `KVzapPress` inherits from the `ScorerPress` class and only predicts the scores for every KV pair. The `DMSPress` then prunes the KV pairs with a score below a given threshold, rather than using a fixed compression ratio.
 
 Supported base models are provided in the [KVzap collection](https://huggingface.co/collections/nvidia/kvzap) but can easily be extended to any other model following the instructions in the [training section](#training).

kvzap/evaluate_aime.py

Lines changed: 4 additions & 4 deletions
@@ -10,7 +10,7 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from datasets import load_dataset
 
-from kvpress import KVzapPress, ThresholdPress
+from kvpress import KVzapPress, DMSPress
 
 
 def calculate_metrics(df):
@@ -56,11 +56,11 @@ def evaluate(
     """
 
     # Create press
-    press: ThresholdPress | type[nullcontext[None]]
+    press: DMSPress | type[nullcontext[None]]
     if kvzap_model_type == "no_press":
         press = nullcontext
     else:
-        press = ThresholdPress(
+        press = DMSPress(
             KVzapPress(model_type=kvzap_model_type),
             threshold=threshold,
             decoding=True,
@@ -86,7 +86,7 @@ def evaluate(
         )
         answer = tokenizer.decode(output_tokens[0, tokens.shape[1] :])
         df.loc[idx, "predicted_answer"] = answer
-        if isinstance(press, ThresholdPress):
+        if isinstance(press, DMSPress):
             df.loc[idx, "compression_ratio"] = press.compression_ratio
         else:
             df.loc[idx, "compression_ratio"] = 0
