@@ -24,6 +24,37 @@ def is_compatible(self) -> bool:
        pass


+class PromptCompressorRandom(PromptCompressor):
+    def __init__(self, head_specific, **kwargs) -> None:
+        super().__init__(head_specific, **kwargs)
+
+    def is_compatible(self) -> bool:
+        # Can be used with any cache
+        return True
+
+    def requires_attn(self) -> bool:
+        return False
+
+    def __call__(self, input_pos, k_val, v_val):
+        seq_len = input_pos.shape[0]
+        global_idxs = torch.arange(self.global_tokens, device=input_pos.device)
+        rand_idxs = (
+            (
+                self.global_tokens
+                + torch.randperm(seq_len - self.global_tokens, device=input_pos.device)[
+                    : self.max_cache_length - self.global_tokens
+                ]
+            )
+            .sort()
+            .values
+        )
+        keep_idxs = torch.cat([global_idxs, rand_idxs], dim=0)
+        assert len(keep_idxs) == self.max_cache_length
+        k_val = k_val[:, :, keep_idxs]
+        v_val = v_val[:, :, keep_idxs]
+        return keep_idxs, k_val, v_val
+
+
class PromptCompressorRecentGlobal(PromptCompressor):
    def __init__(self, head_specific, **kwargs) -> None:
        super().__init__(head_specific, **kwargs)
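For intuition, here is a minimal standalone sketch of the selection logic this hunk adds, using made-up sizes (`global_tokens`, `max_cache_length`, and `seq_len` are hypothetical values, not defaults from the repo):

```python
import torch

global_tokens = 4       # always-kept prefix ("sink" tokens)
max_cache_length = 16   # total KV budget after compression
seq_len = 64            # prompt length

# Keep the prefix, then fill the remaining budget with a uniform
# random sample of the rest of the prompt, sorted into positional order.
global_idxs = torch.arange(global_tokens)
rand_idxs = (
    global_tokens
    + torch.randperm(seq_len - global_tokens)[: max_cache_length - global_tokens]
).sort().values
keep_idxs = torch.cat([global_idxs, rand_idxs])

assert keep_idxs.shape[0] == max_cache_length  # 4 global + 12 random
```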
@@ -84,9 +115,10 @@ def requires_attn(self) -> bool:
        return True

    def __call__(self, input_pos, k_val, v_val, attn):
-        assert self.head_specific, "SnapKV can only be used with head-specific KV-caches, e.g., placing the same token in different locations across heads)."
+        seq_len = input_pos.shape[0]
+        obs_len = min(self.observation_len, seq_len)

-        priority = attn[:, :, -self.observation_len :, :].mean(dim=2)
+        priority = attn[:, :, -obs_len:, :].mean(dim=2)

        prev_shape = priority.shape

        # We'll be returning the attention history so we need to keep a copy before it's modified
@@ -95,8 +127,9 @@ def __call__(self, input_pos, k_val, v_val, attn):
        assert (
            priority.shape == prev_shape
        ), f"Pooling operation should not change the dimension: {prev_shape} -> {priority.shape}"
-        priority[:, :, -self.observation_len :] = (
-            1.0  # Ensure the observation window is selected
+        priority[:, :, -obs_len:] = 1.0  # Ensure the observation window is selected
+        priority[:, :, : self.global_tokens] = (
+            1.0  # Ensure the global tokens are selected
        )
        keep_idxs = (
            priority.topk(self.max_cache_length, dim=-1).indices.sort(dim=-1).values
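Taken together with the elided pooling lines, the SnapKV path now scores keys by the attention they receive from a clamped observation window, then force-keeps both that window and the global tokens before taking the top-k. A rough standalone sketch, with random numbers standing in for real attention and the pooling step omitted; all shapes here are made up:

```python
import torch

bsz, n_heads, seq_len = 1, 2, 32
observation_len, global_tokens, max_cache_length = 8, 4, 16
attn = torch.rand(bsz, n_heads, seq_len, seq_len).softmax(dim=-1)

# Clamp so prompts shorter than the observation window still work
obs_len = min(observation_len, seq_len)

# Score each key by the mean attention it gets from the observation window
priority = attn[:, :, -obs_len:, :].mean(dim=2)  # (bsz, n_heads, seq_len)

# Force-keep the observation window and the global tokens
priority[:, :, -obs_len:] = 1.0
priority[:, :, : global_tokens] = 1.0

# Top-scoring keys per head, restored to positional order
keep_idxs = priority.topk(max_cache_length, dim=-1).indices.sort(dim=-1).values
print(keep_idxs.shape)  # torch.Size([1, 2, 16])
```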
@@ -152,5 +185,7 @@ def prompt_compressor_constructor(strategy):
        return PromptCompressorSnapKV
    elif strategy == "l2":
        return PromptCompressorL2
+    elif strategy == "random":
+        return PromptCompressorRandom
    else:
        raise ValueError(f"Unknown prompt compression strategy: {strategy}")
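With the registration in place, the new strategy should be constructible like the others. A hypothetical wiring sketch; the exact constructor kwargs (`global_tokens`, `max_cache_length`) aren't visible in this diff and are assumed here:

```python
import torch

compressor_cls = prompt_compressor_constructor("random")
# Assumed kwargs; the base PromptCompressor signature is not shown in this diff
compressor = compressor_cls(head_specific=False, global_tokens=4, max_cache_length=16)

input_pos = torch.arange(64)        # prompt positions
k_val = torch.rand(1, 8, 64, 128)   # (batch, heads, seq_len, head_dim)
v_val = torch.rand(1, 8, 64, 128)

# Unlike SnapKV, no attention scores are needed (requires_attn() is False)
keep_idxs, k_val, v_val = compressor(input_pos, k_val, v_val)
assert k_val.shape[2] == 16
```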