Skip to content

Commit 8306602

Browse files
authored
Add post_init_from_model to BasePress (#163)
1 parent dafafcb commit 8306602

18 files changed

+52
-104
lines changed

evaluation/evaluate.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -329,9 +329,7 @@ def _load_and_prepare_dataset(self):
329329
# FinchPress uses a delimiter token to separate context and question
330330
# So we need to update the tokenizer and the model embeddings.
331331
logger.info("FinchPress detected, updating model and tokenizer with delimiter token.")
332-
self.press.update_model_and_tokenizer(
333-
self.pipeline.model, self.pipeline.tokenizer
334-
) # type: ignore[attr-defined]
332+
self.press.update_model_and_tokenizer(self.pipeline.model, self.pipeline.tokenizer) # type: ignore[attr-defined]
335333
df["context"] = df["context"] + self.press.delimiter_token # type: ignore[attr-defined, index]
336334

337335
if self.config.compress_questions:

kvpress/presses/adakv_press.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,9 @@ def __post_init__(self):
3939
assert isinstance(self.press, ScorerPress), "AdaKVPress requires a ScorerPress as input"
4040
assert 0 <= self.alpha_safeguard <= 1, "alpha_safeguard should be in [0, 1]"
4141

42+
def post_init_from_model(self, model):
43+
self.press.post_init_from_model(model)
44+
4245
@property
4346
def compression_ratio(self):
4447
return self.press.compression_ratio

kvpress/presses/base_press.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,12 @@ class BasePress:
4646
The compression is applied only during pre-filling (not during generation).
4747
"""
4848

49+
def post_init_from_model(self, model: PreTrainedModel):
50+
"""
51+
Optional method to initialize press parameters from the model
52+
"""
53+
pass
54+
4955
def compress(
5056
self,
5157
module: nn.Module,
@@ -179,6 +185,7 @@ def __call__(self, model: PreTrainedModel) -> Generator:
179185
if isinstance(model, Gemma3ForConditionalGeneration):
180186
logger.warning_once("Compression in Gemma3 is only applied to layer without sliding window attention")
181187

188+
self.post_init_from_model(model)
182189
hooks = []
183190
try:
184191
language_model = model.model.language_model if hasattr(model.model, "language_model") else model.model

kvpress/presses/block_press.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,9 @@ class BlockPress(BasePress):
3636
def __post_init__(self):
3737
assert isinstance(self.press, ScorerPress), "BlockPress requires a ScorerPress"
3838

39+
def post_init_from_model(self, model):
40+
self.press.post_init_from_model(model)
41+
3942
@property
4043
def compression_ratio(self):
4144
return self.press.compression_ratio

kvpress/presses/chunk_press.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,9 @@ class ChunkPress(BasePress):
3636
def __post_init__(self):
3737
assert isinstance(self.press, ScorerPress), "ChunkPress requires a ScorerPress as input"
3838

39+
def post_init_from_model(self, model):
40+
self.press.post_init_from_model(model)
41+
3942
@property
4043
def compression_ratio(self):
4144
return self.press.compression_ratio

kvpress/presses/chunkkv_press.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,9 @@ class ChunkKVPress(BasePress):
3737
def __post_init__(self):
3838
assert isinstance(self.press, ScorerPress), "ChunkKVPress requires a ScorerPress as input"
3939

40+
def post_init_from_model(self, model):
41+
self.press.post_init_from_model(model)
42+
4043
@property
4144
def compression_ratio(self):
4245
return self.press.compression_ratio

kvpress/presses/composed_press.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,10 @@ def __post_init__(self):
4949
isinstance(press, (AdaKVPress, KVzipPress)) for press in self.presses
5050
), "ComposedPress cannot contains AdaKVPress or KVzipPress"
5151

52+
def post_init_from_model(self, model):
53+
for press in self.presses:
54+
press.post_init_from_model(model)
55+
5256
def forward_hook(self, module, input, kwargs, output):
5357
retained_fraction = 1.0
5458
for press in self.presses:

kvpress/presses/criticalkv_press.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,9 @@ def __init__(self, press: ScorerPress, epsilon: float = 1e-4, first_stage_ratio:
4343
if isinstance(self.press, ExpectedAttentionPress) and self.press.use_vnorm:
4444
logger.warning("use_vnorm should be disabled for CriticalKVPress")
4545

46+
def post_init_from_model(self, model):
47+
self.press.post_init_from_model(model)
48+
4649
@property # type: ignore[misc]
4750
def compression_ratio(self): #
4851
return self.press.compression_ratio

kvpress/presses/decoding_press.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,9 @@ def __post_init__(self):
6060
f"This will be overridden by the decoding press."
6161
)
6262

63+
def post_init_from_model(self, model):
64+
self.base_press.post_init_from_model(model)
65+
6366
def compress(
6467
self,
6568
module: nn.Module,

kvpress/presses/duo_attention_press.py

Lines changed: 2 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from contextlib import contextmanager
54
from dataclasses import dataclass, field
65
from io import StringIO
76

@@ -70,7 +69,7 @@ class DuoAttentionPress(BasePress):
7069
sink_size: int = field(init=False, default=None)
7170
streaming_mask: torch.Tensor = field(init=False, default=None)
7271

73-
def __post_init_from_model__(self, model):
72+
def post_init_from_model(self, model):
7473
"""
7574
Initialize sink_size, recent_size, and streaming_mask from a model
7675
"""
@@ -101,7 +100,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
101100
assert module.config._attn_implementation != "eager", "eager mode not supported"
102101
if self.streaming_mask is None:
103102
raise ValueError(
104-
"Streaming mask not initialized. Make sure to call __post_init_from_model__ to initialize this press."
103+
"Streaming mask not initialized. Make sure to call post_init_from_model to initialize this press."
105104
)
106105
k_len = keys.shape[2]
107106

@@ -141,12 +140,6 @@ def load_attention_pattern(model):
141140

142141
return config["sink_size"], config["recent_size"], head_scores
143142

144-
@contextmanager
145-
def __call__(self, model):
146-
self.__post_init_from_model__(model)
147-
with super().__call__(model):
148-
yield
149-
150143

151144
@cached(cache, key=lambda model, num_samples=50, q_len=500: (model.config.name_or_path, num_samples, q_len))
152145
def duo_attention_on_the_fly(model, num_samples=50, q_len=500):

0 commit comments

Comments
 (0)