# SPDX-License-Identifier: Apache-2.0


-import logging
from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.scorer_press import ScorerPress

-logger = logging.getLogger(__name__)
-

@dataclass
class ObservedAttentionPress(ScorerPress):
@@ -22,27 +19,17 @@ class ObservedAttentionPress(ScorerPress):
    forward pass. Score for each key-value pair is the average attention weight
    it receives from all query tokens.

-    Requires: output_attentions=True and attn_implementation="eager".
+    Requires: attn_implementation="eager".

    Related to H2O (https://arxiv.org/abs/2306.14048).

    Parameters
    ----------
    compression_ratio : float, default=0.0
        Fraction of key-value pairs to remove during compression.
-    output_attentions : bool, default=True
-        Whether to output the attention weights. Must be set True but we keep it for backward compatibility.
    """

    compression_ratio: float = 0.0
-    output_attentions: bool = True
-
-    def __post_init__(self):
-        if not self.output_attentions:
-            # keep for backward compatibility, remove in version 1.0
-            raise ValueError(
-                "With transformers >= 4.54, " "ObservedAttentionPress will only work with output_attentions=True"
-            )

    def score(
        self,
@@ -53,7 +40,7 @@ def score(
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
-        assert attentions is not None, 'Set output_attentions=True and attn_implementation="eager" to use this hook'
+        assert attentions is not None, 'Set attn_implementation="eager" to use this hook'
        scores = attentions.sum(2)
        bsz, num_key_value_heads, n_tokens, _ = keys.shape
        n_tokens_in_sum = torch.arange(n_tokens, 0, -1).to(attentions.device, attentions.dtype)
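After this change, the press only requires the model to be loaded with eager attention so that the attention weights are materialized and handed to the scoring hook; `output_attentions` no longer needs to be set by the user. Below is a minimal usage sketch assuming the context-manager usage from the kvpress README; the model name and prompt are illustrative assumptions, not part of this commit:

```python
# Minimal sketch of using ObservedAttentionPress after this change.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from kvpress import ObservedAttentionPress

model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # assumption: any HF causal LM should work
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="eager",  # required: scores are read from the materialized attention weights
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

press = ObservedAttentionPress(compression_ratio=0.5)  # remove half of the key-value pairs

inputs = tokenizer("A long context to compress ...", return_tensors="pt").to(model.device)
with press(model):  # the press hooks score and prune the KV cache inside this context
    model.generate(**inputs, max_new_tokens=32)
```

On the scoring itself: the descending `arange` built at the end of the snippet counts, for each key position, how many queries can attend to it under the causal mask, so the summed attention weights can be turned into the per-key average described in the docstring.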