Commit f5d640f

Refactor ObservedAttention (#166)
1 parent ea3a08e commit f5d640f

File tree

kvpress/pipeline.py
kvpress/presses/observed_attention_press.py
tests/fixtures.py

3 files changed: +3 -25 lines changed


kvpress/pipeline.py

Lines changed: 0 additions & 9 deletions
@@ -15,7 +15,6 @@
 from kvpress.presses.decoding_press import DecodingPress
 from kvpress.presses.finch_press import FinchPress
 from kvpress.presses.key_rerotation_press import KeyRerotationPress
-from kvpress.presses.observed_attention_press import ObservedAttentionPress
 from kvpress.presses.prefill_decoding_press import PrefillDecodingPress

 logger = logging.getLogger(__name__)

@@ -210,7 +209,6 @@ def _forward(
         self.model.model(
             input_ids=context_ids,
             past_key_values=cache,
-            output_attentions=self.output_attentions(press),
         )

         logger.debug(f"Context Length: {context_length}")

@@ -306,13 +304,6 @@ def generate_answer(
         answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True)
         return answer

-    def output_attentions(self, press: BasePress):
-        if isinstance(press, ObservedAttentionPress):
-            return True
-        if hasattr(press, "press") and isinstance(press.press, ObservedAttentionPress):
-            return True
-        return False
-
     def postprocess(self, model_outputs, single_question):
         if single_question:
             return {"answer": model_outputs[0]}

kvpress/presses/observed_attention_press.py

Lines changed: 2 additions & 15 deletions
@@ -2,16 +2,13 @@
 # SPDX-License-Identifier: Apache-2.0


-import logging
 from dataclasses import dataclass

 import torch
 from torch import nn

 from kvpress.presses.scorer_press import ScorerPress

-logger = logging.getLogger(__name__)
-

 @dataclass
 class ObservedAttentionPress(ScorerPress):

@@ -22,27 +19,17 @@ class ObservedAttentionPress(ScorerPress):
     forward pass. Score for each key-value pair is the average attention weight
     it receives from all query tokens.

-    Requires: output_attentions=True and attn_implementation="eager".
+    Requires: attn_implementation="eager".

     Related to H2O (https://arxiv.org/abs/2306.14048).

     Parameters
     ----------
     compression_ratio : float, default=0.0
         Fraction of key-value pairs to remove during compression.
-    output_attentions : bool, default=True
-        Whether to output the attention weights. Must be set True but we keep it for backward compatibility.
     """

     compression_ratio: float = 0.0
-    output_attentions: bool = True
-
-    def __post_init__(self):
-        if not self.output_attentions:
-            # keep for backward compatibility, remove in version 1.0
-            raise ValueError(
-                "With transformers >= 4.54, " "ObservedAttentionPress will only work with output_attentions=True"
-            )

     def score(
         self,

@@ -53,7 +40,7 @@ def score(
         attentions: torch.Tensor,
         kwargs,
     ) -> torch.Tensor:
-        assert attentions is not None, 'Set output_attentions=True and attn_implementation="eager" to use this hook'
+        assert attentions is not None, 'Set attn_implementation="eager" to use this hook'
         scores = attentions.sum(2)
         bsz, num_key_value_heads, n_tokens, _ = keys.shape
         n_tokens_in_sum = torch.arange(n_tokens, 0, -1).to(attentions.device, attentions.dtype)
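The score computed in the hunk above is the average attention weight each key receives, corrected for the causal mask: key position i is only visible to the n_tokens - i query positions at or after it. Below is a standalone sketch of that averaging, covering only the lines shown here (the grouped-query-attention head handling that follows in the full file is omitted); the function name is illustrative, not the press's code.

import torch


def average_observed_attention(attentions: torch.Tensor) -> torch.Tensor:
    # attentions: (batch, num_heads, n_queries, n_keys) attention weights
    # produced by an eager attention forward pass over the prefill context.
    scores = attentions.sum(dim=2)  # total attention mass received by each key
    n_tokens = attentions.shape[-1]
    # Under a causal mask, key i is attended by n_tokens - i queries,
    # so divide by [n_tokens, n_tokens - 1, ..., 1] to get a per-key average.
    n_tokens_in_sum = torch.arange(n_tokens, 0, -1).to(attentions.device, attentions.dtype)
    return scores / n_tokens_in_sum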

tests/fixtures.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ def unit_test_model():
 @pytest.fixture(scope="session")
 def unit_test_model_output_attention():
     model = AutoModelForCausalLM.from_pretrained(
-        "MaxJeblick/llama2-0b-unit-test", attn_implementation="eager", output_attentions=True
+        "MaxJeblick/llama2-0b-unit-test", attn_implementation="eager"
     ).eval()
     return model.to(get_device())
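A sketch of how a test might exercise this fixture with the refactored press, assuming kvpress's context-manager usage (with press(model): ...); the test name and input construction are illustrative, not part of this commit.

import torch

from kvpress import ObservedAttentionPress


def test_observed_attention_press(unit_test_model_output_attention):
    model = unit_test_model_output_attention  # eager-attention model from the fixture above
    press = ObservedAttentionPress(compression_ratio=0.5)
    input_ids = torch.randint(0, model.config.vocab_size, (1, 32), device=model.device)
    # The press hooks the attention layers and prunes the KV cache during prefill;
    # eager attention makes the attention weights available to its score() method.
    with press(model):
        model(input_ids)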
