4 changes: 1 addition & 3 deletions evaluation/evaluate.py
@@ -329,9 +329,7 @@ def _load_and_prepare_dataset(self):
             # FinchPress uses a delimiter token to separate context and question
             # So we need to update the tokenizer and the model embeddings.
             logger.info("FinchPress detected, updating model and tokenizer with delimiter token.")
-            self.press.update_model_and_tokenizer(
-                self.pipeline.model, self.pipeline.tokenizer
-            )  # type: ignore[attr-defined]
+            self.press.update_model_and_tokenizer(self.pipeline.model, self.pipeline.tokenizer)  # type: ignore[attr-defined]
             df["context"] = df["context"] + self.press.delimiter_token  # type: ignore[attr-defined, index]

         if self.config.compress_questions:
3 changes: 3 additions & 0 deletions kvpress/presses/adakv_press.py
@@ -39,6 +39,9 @@ def __post_init__(self):
         assert isinstance(self.press, ScorerPress), "AdaKVPress requires a ScorerPress as input"
        assert 0 <= self.alpha_safeguard <= 1, "alpha_safeguard should be in [0, 1]"

+    def post_init_from_model(self, model):
+        self.press.post_init_from_model(model)
+
     @property
     def compression_ratio(self):
         return self.press.compression_ratio
7 changes: 7 additions & 0 deletions kvpress/presses/base_press.py
@@ -46,6 +46,12 @@ class BasePress:
     The compression is applied only during pre-filling (not during generation).
     """

+    def post_init_from_model(self, model: PreTrainedModel):
+        """
+        Optional method to initialize press parameters from the model
+        """
+        pass
+
     def compress(
         self,
         module: nn.Module,
@@ -179,6 +185,7 @@ def __call__(self, model: PreTrainedModel) -> Generator:
         if isinstance(model, Gemma3ForConditionalGeneration):
             logger.warning_once("Compression in Gemma3 is only applied to layer without sliding window attention")

+        self.post_init_from_model(model)
         hooks = []
         try:
             language_model = model.model.language_model if hasattr(model.model, "language_model") else model.model
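
With this hook in place, `BasePress.__call__` initializes any model-dependent state before registering its forward hooks, so presses no longer need to override `__call__`. A minimal usage sketch of the resulting lifecycle (the model name and press arguments below are illustrative assumptions, not taken from this PR):

import torch
from transformers import AutoModelForCausalLM, DynamicCache

from kvpress import ExpectedAttentionPress  # import path assumed

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")  # illustrative model
press = ExpectedAttentionPress(compression_ratio=0.5)  # constructor args assumed

input_ids = torch.randint(0, 1024, (1, 128), device=model.device)
with press(model):  # BasePress.__call__ now runs self.post_init_from_model(model) before registering hooks
    model(input_ids, past_key_values=DynamicCache())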
3 changes: 3 additions & 0 deletions kvpress/presses/block_press.py
@@ -36,6 +36,9 @@ class BlockPress(BasePress):
     def __post_init__(self):
         assert isinstance(self.press, ScorerPress), "BlockPress requires a ScorerPress"

+    def post_init_from_model(self, model):
+        self.press.post_init_from_model(model)
+
     @property
     def compression_ratio(self):
         return self.press.compression_ratio
3 changes: 3 additions & 0 deletions kvpress/presses/chunk_press.py
@@ -36,6 +36,9 @@ class ChunkPress(BasePress):
     def __post_init__(self):
         assert isinstance(self.press, ScorerPress), "ChunkPress requires a ScorerPress as input"

+    def post_init_from_model(self, model):
+        self.press.post_init_from_model(model)
+
     @property
     def compression_ratio(self):
         return self.press.compression_ratio
3 changes: 3 additions & 0 deletions kvpress/presses/chunkkv_press.py
@@ -37,6 +37,9 @@ class ChunkKVPress(BasePress):
     def __post_init__(self):
         assert isinstance(self.press, ScorerPress), "ChunkKVPress requires a ScorerPress as input"

+    def post_init_from_model(self, model):
+        self.press.post_init_from_model(model)
+
     @property
     def compression_ratio(self):
         return self.press.compression_ratio
4 changes: 4 additions & 0 deletions kvpress/presses/composed_press.py
@@ -49,6 +49,10 @@ def __post_init__(self):
             isinstance(press, (AdaKVPress, KVzipPress)) for press in self.presses
         ), "ComposedPress cannot contains AdaKVPress or KVzipPress"

+    def post_init_from_model(self, model):
+        for press in self.presses:
+            press.post_init_from_model(model)
+
     def forward_hook(self, module, input, kwargs, output):
         retained_fraction = 1.0
         for press in self.presses:
3 changes: 3 additions & 0 deletions kvpress/presses/criticalkv_press.py
@@ -43,6 +43,9 @@ def __init__(self, press: ScorerPress, epsilon: float = 1e-4, first_stage_ratio:
         if isinstance(self.press, ExpectedAttentionPress) and self.press.use_vnorm:
             logger.warning("use_vnorm should be disabled for CriticalKVPress")

+    def post_init_from_model(self, model):
+        self.press.post_init_from_model(model)
+
     @property  # type: ignore[misc]
     def compression_ratio(self):  #
         return self.press.compression_ratio
3 changes: 3 additions & 0 deletions kvpress/presses/decoding_press.py
@@ -60,6 +60,9 @@ def __post_init__(self):
                 f"This will be overridden by the decoding press."
             )

+    def post_init_from_model(self, model):
+        self.base_press.post_init_from_model(model)
+
     def compress(
         self,
         module: nn.Module,
11 changes: 2 additions & 9 deletions kvpress/presses/duo_attention_press.py
@@ -1,7 +1,6 @@
 # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0

-from contextlib import contextmanager
 from dataclasses import dataclass, field
 from io import StringIO

@@ -70,7 +69,7 @@ class DuoAttentionPress(BasePress):
     sink_size: int = field(init=False, default=None)
     streaming_mask: torch.Tensor = field(init=False, default=None)

-    def __post_init_from_model__(self, model):
+    def post_init_from_model(self, model):
         """
         Initialize sink_size, recent_size, and streaming_mask from a model
         """
@@ -101,7 +100,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
         assert module.config._attn_implementation != "eager", "eager mode not supported"
         if self.streaming_mask is None:
             raise ValueError(
-                "Streaming mask not initialized. Make sure to call __post_init_from_model__ to initialize this press."
+                "Streaming mask not initialized. Make sure to call post_init_from_model to initialize this press."
             )
         k_len = keys.shape[2]

@@ -141,12 +140,6 @@ def load_attention_pattern(model):

         return config["sink_size"], config["recent_size"], head_scores

-    @contextmanager
-    def __call__(self, model):
-        self.__post_init_from_model__(model)
-        with super().__call__(model):
-            yield
-

 @cached(cache, key=lambda model, num_samples=50, q_len=500: (model.config.name_or_path, num_samples, q_len))
 def duo_attention_on_the_fly(model, num_samples=50, q_len=500):
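
The removed `@contextmanager __call__` override is exactly the boilerplate the new hook eliminates: a press that needs model-derived state now just overrides `post_init_from_model`. A sketch of the pattern for a hypothetical custom press (class name, field, and scoring logic are invented for illustration):

from dataclasses import dataclass, field

from transformers import PreTrainedModel

from kvpress import ScorerPress  # import path assumed


@dataclass
class NormScalePress(ScorerPress):  # hypothetical press, for illustration only
    scale: float = field(init=False, default=1.0)  # initialized in post_init_from_model

    def post_init_from_model(self, model: PreTrainedModel):
        # Derive state from the model once; BasePress.__call__ now invokes this
        # automatically when the press context is entered.
        head_dim = model.config.hidden_size // model.config.num_attention_heads
        self.scale = head_dim**-0.5

    def score(self, module, hidden_states, keys, values, attentions, kwargs):
        # Rank KV pairs by (negated) scaled key norm — illustrative scoring only.
        return -keys.norm(dim=-1) * self.scale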
12 changes: 3 additions & 9 deletions kvpress/presses/expected_attention_with_stats.py
@@ -52,8 +52,8 @@ class ExpectedAttentionStatsPress(ExpectedAttentionPress):
     dataset_name: str = "kmfoda/booksum"
     stats_folder: Optional[str] = None

-    mu: torch.Tensor = field(init=False, default=None)  # initialized in __post_init_from_model__
-    cov: torch.Tensor = field(init=False, default=None)  # initialized in __post_init_from_model__
+    mu: torch.Tensor = field(init=False, default=None)  # initialized in post_init_from_model
+    cov: torch.Tensor = field(init=False, default=None)  # initialized in post_init_from_model

     def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor):
         """
@@ -69,7 +69,7 @@ def available_stats():
         collection = get_collection("alessiodevoto/expectedattentionstats-68b0248d519303713320e2cf")
         return [x.item_id for x in collection.items]

-    def __post_init_from_model__(self, model):
+    def post_init_from_model(self, model):
         """
         Automatically load or compute query statistics for the model.
         """
@@ -104,12 +104,6 @@ def _maybe_load_stats_from_hub(self, model: PreTrainedModel):
                 "```"
             )

-    @contextmanager
-    def __call__(self, model):
-        self.__post_init_from_model__(model)
-        with super().__call__(model):
-            yield
-

 class ExpectedAttentionStats(torch.nn.Module, PyTorchModelHubMixin):
     """
3 changes: 3 additions & 0 deletions kvpress/presses/key_rerotation_press.py
@@ -36,6 +36,9 @@ class KeyRerotationPress(BasePress):
     def __post_init__(self):
         assert isinstance(self.press, ScorerPress)

+    def post_init_from_model(self, model):
+        self.press.post_init_from_model(model)
+
     @property
     def compression_ratio(self):
         return self.press.compression_ratio
6 changes: 6 additions & 0 deletions kvpress/presses/prefill_decoding_press.py
@@ -36,6 +36,12 @@ class PrefillDecodingPress(BasePress):
     prefilling_press: Optional[BasePress] = None
     decoding_press: Optional[DecodingPress] = None

+    def post_init_from_model(self, model):
+        if self.prefilling_press is not None:
+            self.prefilling_press.post_init_from_model(model)
+        if self.decoding_press is not None:
+            self.decoding_press.post_init_from_model(model)
+
     def compress(
         self,
         module: nn.Module,
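
Because `PrefillDecodingPress` wraps two presses, its `post_init_from_model` fans out to both, so entering a single press context initializes the whole stack. A usage sketch (model name and press arguments are assumptions; the `DecodingPress` arguments mirror those used in tests/test_decoding_compression.py below):

from transformers import AutoModelForCausalLM

from kvpress import DecodingPress, ExpectedAttentionPress, PrefillDecodingPress  # import paths assumed

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")  # illustrative model

press = PrefillDecodingPress(
    prefilling_press=ExpectedAttentionPress(compression_ratio=0.5),  # args assumed
    decoding_press=DecodingPress(base_press=ExpectedAttentionPress(), compression_interval=3, target_size=48),
)

# Entering the context calls post_init_from_model on both wrapped presses.
with press(model):
    ...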
11 changes: 2 additions & 9 deletions kvpress/presses/qfilter_press.py
@@ -1,7 +1,6 @@
 # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0

-from contextlib import contextmanager
 from dataclasses import dataclass, field
 from functools import cache

@@ -51,7 +50,7 @@ class QFilterPress(ScorerPress):

     q_filters: QFilters = field(init=False, default=None)

-    def __post_init_from_model__(self, model):
+    def post_init_from_model(self, model):
         model_name = model.config.name_or_path.split("/")[-1]
         self.q_filters = self.load_q_filters(model_name)
         self.q_filters = self.q_filters.to(model.dtype)
@@ -75,15 +74,9 @@ def available_qfilters():
     def score(self, module, hidden_states, keys, values, attentions, kwargs):
         if self.q_filters is None:
             raise ValueError(
-                "Q-filters not loaded. If you are using a wrapper press, make sure to call __post_init_from_model__."
+                "Q-filters not loaded. If you are using a wrapper press, make sure to call post_init_from_model."
             )
         q_filter = self.q_filters[module.layer_idx][None, :, None]  # type: ignore
         q_filter = q_filter.to(keys.device)
         scores = -(q_filter * keys).sum(dim=-1)
         return scores
-
-    @contextmanager
-    def __call__(self, model):
-        self.__post_init_from_model__(model)
-        with super().__call__(model):
-            yield
10 changes: 5 additions & 5 deletions tests/presses/test_presses.py
@@ -67,8 +67,8 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press):  # noqa: F811
     for kwargs in press_dict["kwargs"]:
         press = cls(**kwargs)
         if wrapper_press is not None:
-            if hasattr(press, "__post_init_from_model__"):
-                press.__post_init_from_model__(unit_test_model)
+            if hasattr(press, "post_init_from_model"):
+                press.post_init_from_model(unit_test_model)
             if issubclass(wrapper_press, ComposedPress):
                 if isinstance(press, KVzipPress):  # KVzipPress is currently not compatible with ComposedPress
                     return
@@ -80,9 +80,9 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press):  # noqa: F811
            elif issubclass(wrapper_press, ChunkPress):
                press = ChunkPress(press=press, chunk_length=24)

-        # TODO: Handle __post_init_from_model__ differently
-        if hasattr(press, "__post_init_from_model__"):
-            press.__post_init_from_model__(unit_test_model)
+        # TODO: Handle post_init_from_model differently
+        if hasattr(press, "post_init_from_model"):
+            press.post_init_from_model(unit_test_model)
         with press(unit_test_model):
             input_ids = torch.randint(0, 1024, (1, 128), device=unit_test_model.device)
             unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
67 changes: 0 additions & 67 deletions tests/presses/test_wrappers.py

This file was deleted.

2 changes: 0 additions & 2 deletions tests/test_decoding_compression.py
@@ -239,8 +239,6 @@ def test_all_presses_work_with_decoding_press(press_config):
         # CompactorPress -> Meant for prefill scenario.
         logger.info(f"Press {press_cls.__name__} is not supported, skipping test")
         return
-    if hasattr(base_press, "__post_init_from_model__"):
-        base_press.__post_init_from_model__(pipe.model)

     # Create DecodingPress with this base press
     decoding_press = DecodingPress(base_press=base_press, compression_interval=3, target_size=48)
1 change: 1 addition & 0 deletions tests/test_pipeline.py
@@ -47,6 +47,7 @@ class TestPipelineFA2:
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available")
     @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed")
     @pytest.mark.parametrize("compression_ratio", [0.0, 0.2])
+    @pytest.mark.xfail(reason="Known issue not related to kvpress", strict=False)
     def test_pipeline_fa2(self, kv_press_llama3_2_flash_attn_pipeline, compression_ratio):  # noqa: F811
         context = "This is a test article. It was written on 2022-01-01."
         questions = ["Repeat the last sentence"]