+import regex as re
 from abc import ABC, abstractmethod
 from typing import Tuple, Callable
@@ -19,10 +20,14 @@ def add_cache_arguments(parser: argparse.ArgumentParser):
         help="Cache size per layer. If len < n layers, the values are tiled. Must have len divisible by n layers. \
             If 0 < x <= 1, it is a fraction of |prompt| + max new tokens. Otherwise, if > 1, it's the maximum size.",
     )
+    strategies = ["full", "random", "window", "scissor", "l2", "fastgen"]
+    debug_strategies = [f"debug_{strategy}" for strategy in strategies]
+    strategies.extend(debug_strategies)
+
     group.add_argument(
         "--cache_strategy",
         default="full",
-        choices=["full", "random", "window", "scissor", "l2"],
+        choices=strategies,
    )

     # Dealing with Long Prompts
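
With this change, every base strategy also gains a debug variant (e.g. `debug_scissor`) accepted by `--cache_strategy`. A rough usage sketch, assuming the surrounding generation script wires these arguments up unchanged:

    import argparse

    parser = argparse.ArgumentParser()
    add_cache_arguments(parser)
    # The debug_ prefix pairs the named strategy with a full cache so the
    # attention lost to compression can be measured (see KVCacheAnalysis below).
    args = parser.parse_args(["--cache_strategy", "debug_scissor"])
    assert args.cache_strategy == "debug_scissor"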
@@ -126,7 +131,7 @@ def create_window_attention_mask(seq_len, window_size, device):
 class KVCache(ABC, nn.Module):
     # Define which hyperparameters are relevant for the cache.
     # Override as needed for sub-classes.
-    relevant_kwargs = ["max_cache_length", "global_tokens"]
+    relevant_kwargs = ["max_cache_length", "max_seq_length", "global_tokens"]

     def __init__(
         self,
@@ -208,6 +213,17 @@ def return_attn(self):
         """
         return False

+    def compute_statistics(self, seq_len):
+        """
+        Computes statistics about the cache.
+
+        Returns:
+            dict: Cache statistics, currently just the compression ratio.
+        """
+        return {
+            "compression_ratio": self.compression_ratio(seq_len).item(),
+        }
+
     def compression_ratio(self, seq_len):
         """
         Returns the compression ratio of the cache.
@@ -276,6 +292,24 @@ def compress_prompt(
         # Yet we return the un-compressed KV since during pre-fill we compute full causal attention.
         return k_val, v_val, mask, new_callback

+    def attn_history_callback(self) -> dict | None:
+        """
+        Returns a callback to update the attention history.
+
+        Returns None if attention is not needed.
+        """
+        return (
+            {
+                "func": lambda input_pos,
+                input_ids,
+                k_val,
+                v_val,
+                attn: self.update_attn_history(attn)
+            }
+            if self.return_attn()
+            else None
+        )
+
     def update(self, input_pos, k_val, v_val, input_ids=None):
         """
         Updates the cache with the given input positions, keys, and values.
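
The callback is a dict whose "func" receives the decode-step inputs plus the attention weights. A minimal sketch of the call site (the attention-layer side is assumed here, not part of this diff):

    # After computing attention weights `attn` for a decode step:
    k, v, mask, callback = cache.update(input_pos, k_val, v_val, input_ids=input_ids)
    if callback is not None:
        # Hand the weights back so the cache can update its attention history.
        callback["func"](input_pos, input_ids, k_val, v_val, attn)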
@@ -424,7 +458,7 @@ def mark_global_tokens(self, num_total_insertions: int) -> bool:
         ), "This cache does not have global tokens so we cannot mark them."
         # Give self.pos the highest possible position value for global tokens so that they are not replaced
         num_to_mark = min(self.global_tokens, num_total_insertions)
-        self.pos[:, :, :num_to_mark] = self.max_cache_length
+        self.pos[:, :, :num_to_mark] = self.max_seq_length
         return num_to_mark == self.global_tokens

@@ -448,6 +482,7 @@ def _update(self, input_pos, k_val, v_val, input_ids=None):
 class KVCacheRandom(KVCache):
     relevant_kwargs = [
         "max_cache_length",
+        "max_seq_length",
         "global_tokens",
         "prompt_compression_strategy",
     ]
@@ -475,6 +510,7 @@ def _update(self, input_pos, k_val, v_val, input_ids=None):
 class KVCacheWindow(KVCache):
     relevant_kwargs = [
         "max_cache_length",
+        "max_seq_length",
         "global_tokens",
         "prompt_compression_strategy",
         # NB: "recent_window" is ignored as a relevant kwarg. It is fixed to self.max_cache_length - self.global_tokens.
@@ -520,6 +556,7 @@ def _update(self, input_pos, k_val, v_val, input_ids=None):
 class KVCacheL2(KVCacheWindow):
     relevant_kwargs = [
         "max_cache_length",
+        "max_seq_length",
         "global_tokens",
         "recent_window",
         "prompt_compression_strategy",
@@ -569,6 +606,7 @@ def update_attn_history(self, attn):
 class KVCacheScissorhands(KVCacheWindow):
     relevant_kwargs = [
         "max_cache_length",
+        "max_seq_length",
         "global_tokens",
         "history_window_size",
         "drop_amount",
@@ -752,6 +790,7 @@ def _update(self, input_pos, k_val, v_val, input_ids=None):
 class KVCacheFastGen(KVCacheScissorhands):
     relevant_kwargs = [
         "max_cache_length",
+        "max_seq_length",
         "history_window_size",
         "recent_window",
         "attn_thresholding",
@@ -1116,18 +1155,146 @@ def profile_and_update(self, input_pos, input_ids, k_val, v_val, attn):
         self.update_attn_history(cum_attn)


+class KVCacheAnalysis(KVCache):
+    relevant_kwargs = [
+        "max_cache_length",
+        "history_window_size",
+        "recent_window",
+        "attn_thresholding",
+        "token_ids",
+        "prompt_compression_strategy",
+        "min_recovery_frac",
+        "heavy_hitter_frac",
+        "global_tokens",
+        "drop_amount",
+        "attn_record_freq",
+        "max_seq_length",
+    ]
+
+    def __init__(
+        self,
+        max_batch_size,
+        n_heads,
+        head_dim,
+        dtype=torch.bfloat16,
+        cache_strategy="scissor",
+        **kwargs,
+    ):
+        # Never any prompt compression for the full cache.
+        full_kwargs = {
+            "prompt_compression_strategy": None,
+            "global_tokens": 0,
+            "max_cache_length": kwargs["max_seq_length"],
+            "max_seq_length": kwargs["max_seq_length"],
+        }
+        super().__init__(
+            max_batch_size, n_heads, head_dim, dtype, head_specific=False, **full_kwargs
+        )
+
+        # Initialize the compressed cache we want to analyze.
+        self.compressed = get_cache_constructor(cache_strategy=cache_strategy)[0](
+            max_batch_size,
+            n_heads,
+            head_dim,
+            dtype,
+            **kwargs,
+        )
+
+        self.register_buffer(
+            "attention_losses",
+            torch.full((self.max_seq_length,), fill_value=-1, dtype=dtype),
+        )
+
+    def return_attn(self):
+        return self.compressed.return_attn()
+
+    def update(self, input_pos, k_val, v_val, input_ids=None):
+        k, v, mask, _ = super().update(input_pos, k_val, v_val, input_ids=input_ids)
+        _, _, _, attn_callback = self.compressed.update(
+            input_pos, k_val, v_val, input_ids=input_ids
+        )
+
+        if attn_callback is not None and input_pos.shape[-1] == 1:
+            # This is ugly, but we need to re-write the callback to call this class's
+            # update_attn_history, not the compressed cache's, because we must first
+            # filter the attention weights down to the tokens in the compressed cache.
+            attn_callback = self.attn_history_callback()
+            assert attn_callback is not None
+
+        return k, v, mask, attn_callback
+
+    def _update(self, input_pos, k_val, v_val, input_ids=None):
+        # input_pos: [S], k_val: [B, H, S, D]
+        self.fill_contiguous(input_pos, k_val, v_val)
+        return input_pos.shape[-1]
+
+    def reset(self):
+        super().reset()
+        self.compressed.reset()
+        self.attention_losses.fill_(-1)
+
+    def update_attn_history(self, attn: torch.Tensor):
+        indices = self.compressed.pos.clone().long()
+
+        # Global tokens will have been marked with max_seq_length.
+        # We need to set them back to their actual positions.
+        indices[:, :, : self.compressed.global_tokens] = (
+            torch.arange(self.compressed.global_tokens, device=indices.device)
+            .view(1, 1, -1)
+            .expand(1, indices.shape[1], -1)
+        )
+        indices = indices[:, :, : min(indices.shape[-1], attn.shape[-1])]
+        attn_compressed = attn.squeeze(2).gather(2, indices).unsqueeze(2)
+        self.compressed.update_attn_history(attn_compressed)
+
+        attn_loss = (1 - attn_compressed.sum(dim=-1)).mean()
+        insert_idx = torch.where(self.attention_losses == -1)[0][0]
+        self.attention_losses[insert_idx] = attn_loss
+
+    def compute_statistics(self, seq_len):
+        """
+        Computes statistics about the cache.
+
+        Returns:
+            dict: Cache statistics, including the compression ratio and the average
+                attention loss of the compressed cache relative to the full cache.
+        """
+        stats = super().compute_statistics(seq_len)
+        cutoff = torch.where(self.attention_losses == -1)[0]
+        if len(cutoff) > 0:
+            cutoff = cutoff[0]
+        else:
+            cutoff = len(self.attention_losses)
+        stats["attention_loss"] = (self.attention_losses[:cutoff].sum() / cutoff).item()
+        return stats
+
+
 def get_cache_constructor(cache_strategy):
+    relevant_kwargs = None
     if cache_strategy == "full":
-        return KVCacheFull
+        cls = KVCacheFull
     elif cache_strategy == "l2":
-        return KVCacheL2
+        cls = KVCacheL2
     elif cache_strategy == "random":
-        return KVCacheRandom
+        cls = KVCacheRandom
     elif cache_strategy == "window":
-        return KVCacheWindow
+        cls = KVCacheWindow
     elif cache_strategy == "scissor":
-        return KVCacheScissorhands
+        cls = KVCacheScissorhands
     elif cache_strategy == "fastgen":
-        return KVCacheFastGen
+        cls = KVCacheFastGen
+    elif cache_strategy.startswith("debug"):
+        cache_strategy = re.sub(r"debug_+", "", cache_strategy).strip()
+        relevant_kwargs = get_cache_constructor(cache_strategy)[1]
+        cls = (
+            lambda max_batch_size, n_heads, head_dim, dtype, **kwargs: KVCacheAnalysis(
+                max_batch_size,
+                n_heads,
+                head_dim,
+                dtype,
+                cache_strategy=cache_strategy,
+                **kwargs,
+            )
+        )
     else:
         raise ValueError(f"Invalid cache strategy: {cache_strategy}")
+
+    return cls, relevant_kwargs or cls.relevant_kwargs
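
Taken together, a debug strategy resolves to the `KVCacheAnalysis` wrapper while borrowing the base strategy's `relevant_kwargs`. A minimal end-to-end sketch (the `DEFAULTS` dict and the concrete shapes are illustrative assumptions, not part of this commit):

    import torch

    # "debug_scissor" strips to "scissor": relevant_kwargs comes from
    # KVCacheScissorhands, and cls builds a KVCacheAnalysis that runs a full
    # cache and the scissor cache side by side.
    cls, relevant_kwargs = get_cache_constructor("debug_scissor")

    kwargs = {k: DEFAULTS[k] for k in relevant_kwargs}  # DEFAULTS is hypothetical
    cache = cls(1, 32, 128, torch.bfloat16, **kwargs)

    # ... run prefill + decode, invoking the returned attention callbacks ...

    stats = cache.compute_statistics(seq_len=1024)
    print(stats["compression_ratio"], stats["attention_loss"])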