@@ -68,6 +68,14 @@ def __init__(self, head_specific, **kwargs) -> None:
         self.kernel_size = 5
         self.observation_len = 16
 
+        self.pool = torch.nn.AvgPool1d(
+            self.kernel_size,
+            stride=1,
+            padding=self.kernel_size // 2,
+            ceil_mode=False,
+            count_include_pad=False,
+        )
+
     def is_compatible(self) -> bool:
         # Can only be used with head-specific KV-caches
         return self.head_specific
@@ -78,19 +86,12 @@ def requires_attn(self) -> bool:
     def __call__(self, input_pos, k_val, v_val, attn):
         assert self.head_specific, "SnapKV can only be used with head-specific KV-caches, e.g., placing the same token in different locations across heads)."
 
-        pool = torch.nn.AvgPool1d(
-            self.kernel_size,
-            stride=1,
-            padding=self.kernel_size // 2,
-            ceil_mode=False,
-            count_include_pad=False,
-        )
         priority = attn[:, :, -self.observation_len :, :].mean(dim=2)
         prev_shape = priority.shape
 
         # We'll be returning the attention history so we need to keep a copy before it's modified
         attn_history = priority.clone()
-        priority = pool(priority)
+        priority = self.pool(priority)
 
         assert (
            priority.shape == prev_shape
        ), f"Pooling operation should not change the dimension: {prev_shape} -> {priority.shape}"