Commit 008175b

Author: Griffin Adams (committed)

Add KVCacheAnalysis which computes attention loss.

1 parent: 9c9845b

File tree

4 files changed: +216 -27 lines changed

cache.py

Lines changed: 177 additions & 9 deletions
@@ -1,3 +1,4 @@
+import regex as re
 from abc import ABC, abstractmethod
 from typing import Tuple, Callable
 
@@ -19,10 +20,14 @@ def add_cache_arguments(parser: argparse.ArgumentParser):
         help="Cache size per layer. If len < n layers, the values are tiled. Must have len divisible by n layers. \
             If 0 < x <= 1, it is a percent of |prompt| + max new tokens. Otherwise, if > 1, it's the maximum size.",
     )
+    strategies = ["full", "random", "window", "scissor", "l2", "fastgen"]
+    debug_strategies = [f"debug_{strategy}" for strategy in strategies]
+    strategies.extend(debug_strategies)
+
     group.add_argument(
         "--cache_strategy",
         default="full",
-        choices=["full", "random", "window", "scissor", "l2"],
+        choices=strategies,
     )
 
     # Dealing with Long Prompts
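
For illustration, a minimal, runnable sketch of what the expanded choices list allows. Only the flag name and the strategy list come from this diff; the parser wiring is simplified:

    import argparse

    strategies = ["full", "random", "window", "scissor", "l2", "fastgen"]
    strategies.extend([f"debug_{s}" for s in strategies])

    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_strategy", default="full", choices=strategies)

    # `debug_scissor` now parses; downstream code strips the "debug_" prefix
    # to recover the underlying strategy and wraps it in KVCacheAnalysis.
    args = parser.parse_args(["--cache_strategy", "debug_scissor"])
    print(args.cache_strategy)  # debug_scissor
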
@@ -126,7 +131,7 @@ def create_window_attention_mask(seq_len, window_size, device):
 class KVCache(ABC, nn.Module):
     # Define which hyperparameters are relevant for the cache.
     # Override as needed for sub-classes.
-    relevant_kwargs = ["max_cache_length", "global_tokens"]
+    relevant_kwargs = ["max_cache_length", "max_seq_length", "global_tokens"]
 
     def __init__(
         self,
@@ -208,6 +213,17 @@ def return_attn(self):
         """
         return False
 
+    def compute_statistics(self, seq_len):
+        """
+        Computes statistics about the cache.
+
+        Returns:
+            dict: A dictionary of cache statistics (currently just the compression ratio).
+        """
+        return {
+            "compression_ratio": self.compression_ratio(seq_len).item(),
+        }
+
     def compression_ratio(self, seq_len):
         """
         Returns the compression ratio of the cache.
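
The new compute_statistics hook returns a plain dict so that subclasses can add their own entries on top of the base compression ratio, which is how KVCacheAnalysis reports attention loss later in this commit. A runnable toy version of the pattern (the stub classes and the 0.5 value are illustrative, not from the repo):

    import torch

    class _BaseCacheStub:
        # Hypothetical stand-in for KVCache, just to show the statistics-merging pattern.
        def compression_ratio(self, seq_len):
            return torch.tensor(0.5)  # placeholder value

        def compute_statistics(self, seq_len):
            return {"compression_ratio": self.compression_ratio(seq_len).item()}

    class _AnalysisStub(_BaseCacheStub):
        def compute_statistics(self, seq_len):
            stats = super().compute_statistics(seq_len)  # start from the base dict
            stats["attention_loss"] = 0.0                # add subclass-specific metrics
            return stats

    print(_AnalysisStub().compute_statistics(seq_len=128))
    # {'compression_ratio': 0.5, 'attention_loss': 0.0}
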
@@ -276,6 +292,24 @@ def compress_prompt(
         # Yet we return the un-compressed KV since during pre-fill we compute full causal attention.
         return k_val, v_val, mask, new_callback
 
+    def attn_history_callback(self) -> Callable | None:
+        """
+        Returns a callback to update the attention history.
+
+        Returns None if attention is not needed.
+        """
+        return (
+            {
+                "func": lambda input_pos,
+                input_ids,
+                k_val,
+                v_val,
+                attn: self.update_attn_history(attn)
+            }
+            if self.return_attn()
+            else None
+        )
+
     def update(self, input_pos, k_val, v_val, input_ids=None):
         """
         Updates the cache with the given input positions, keys, and values.
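
The callback is packaged as a dict with a single "func" entry taking (input_pos, input_ids, k_val, v_val, attn); the attention code that eventually invokes it is outside this diff. Below is a self-contained mock of that shape, with a stub class standing in for a real cache subclass:

    from typing import Callable

    class _CacheStub:
        # Hypothetical stand-in for a KVCache subclass that wants attention weights.
        def __init__(self):
            self.history = []

        def return_attn(self) -> bool:
            return True

        def update_attn_history(self, attn):
            self.history.append(attn)

        def attn_history_callback(self) -> Callable | None:
            # Same structure as in the diff: a dict holding a function of
            # (input_pos, input_ids, k_val, v_val, attn).
            return (
                {"func": lambda input_pos, input_ids, k_val, v_val, attn: self.update_attn_history(attn)}
                if self.return_attn()
                else None
            )

    cache = _CacheStub()
    cb = cache.attn_history_callback()
    if cb is not None:
        cb["func"](None, None, None, None, attn=[0.25, 0.75])
    print(cache.history)  # [[0.25, 0.75]]
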
@@ -424,7 +458,7 @@ def mark_global_tokens(self, num_total_insertions: int) -> bool:
         ), "This cache does not have global tokens so we cannot mark them."
         # Give self.pos the highest possible position value for global tokens so that they are not replaced
         num_to_mark = min(self.global_tokens, num_total_insertions)
-        self.pos[:, :, :num_to_mark] = self.max_cache_length
+        self.pos[:, :, :num_to_mark] = self.max_seq_length
         return num_to_mark == self.global_tokens
 
 
@@ -448,6 +482,7 @@ def _update(self, input_pos, k_val, v_val, input_ids=None):
 class KVCacheRandom(KVCache):
     relevant_kwargs = [
         "max_cache_length",
+        "max_seq_length",
         "global_tokens",
         "prompt_compression_strategy",
     ]
@@ -475,6 +510,7 @@ def _update(self, input_pos, k_val, v_val, input_ids=None):
 class KVCacheWindow(KVCache):
     relevant_kwargs = [
         "max_cache_length",
+        "max_seq_length",
         "global_tokens",
         "prompt_compression_strategy",
         # NB: "recent_window" is ignored as a relevant kwarg. It is fixed to self.max_cache_length - self.global_tokens.
@@ -520,6 +556,7 @@ def _update(self, input_pos, k_val, v_val, input_ids=None):
 class KVCacheL2(KVCacheWindow):
     relevant_kwargs = [
         "max_cache_length",
+        "max_seq_length",
         "global_tokens",
         "recent_window",
         "prompt_compression_strategy",
@@ -569,6 +606,7 @@ def update_attn_history(self, attn):
 class KVCacheScissorhands(KVCacheWindow):
     relevant_kwargs = [
         "max_cache_length",
+        "max_seq_length",
         "global_tokens",
         "history_window_size",
         "drop_amount",
@@ -752,6 +790,7 @@ def _update(self, input_pos, k_val, v_val, input_ids=None):
 class KVCacheFastGen(KVCacheScissorhands):
     relevant_kwargs = [
         "max_cache_length",
+        "max_seq_length",
         "history_window_size",
         "recent_window",
         "attn_thresholding",
@@ -1116,18 +1155,147 @@ def profile_and_update(self, input_pos, input_ids, k_val, v_val, attn):
         self.update_attn_history(cum_attn)
 
 
+class KVCacheAnalysis(KVCache):
+    relevant_kwargs = [
+        "max_cache_length",
+        "history_window_size",
+        "recent_window",
+        "attn_thresholding",
+        "token_ids",
+        "prompt_compression_strategy",
+        "min_recovery_frac",
+        "heavy_hitter_frac",
+        "global_tokens",
+        "drop_amount",
+        "prompt_compression_strategy",
+        "attn_record_freq",
+        "max_seq_length",
+    ]
+
+    def __init__(
+        self,
+        max_batch_size,
+        n_heads,
+        head_dim,
+        dtype=torch.bfloat16,
+        cache_strategy="scissor",
+        **kwargs,
+    ):
+        # Never any prompt compression for the full cache
+        full_kwargs = {
+            "prompt_compression_strategy": None,
+            "global_tokens": 0,
+            "max_cache_length": kwargs["max_seq_length"],
+            "max_seq_length": kwargs["max_seq_length"],
+        }
+        super().__init__(
+            max_batch_size, n_heads, head_dim, dtype, head_specific=False, **full_kwargs
+        )
+
+        # Initialize the compressed cache we want to analyze.
+        self.compressed = get_cache_constructor(cache_strategy=cache_strategy)[0](
+            max_batch_size,
+            n_heads,
+            head_dim,
+            dtype,
+            **kwargs,
+        )
+
+        self.register_buffer(
+            "attention_losses",
+            torch.full((self.max_seq_length,), fill_value=-1, dtype=dtype),
+        )
+
+    def return_attn(self):
+        return self.compressed.return_attn()
+
+    def update(self, input_pos, k_val, v_val, input_ids=None):
+        k, v, mask, _ = super().update(input_pos, k_val, v_val, input_ids=input_ids)
+        _, _, _, attn_callback = self.compressed.update(
+            input_pos, k_val, v_val, input_ids=input_ids
+        )
+
+        if attn_callback is not None and input_pos.shape[-1] == 1:
+            # This is ugly, but we need to re-write the callback so that it calls this
+            # class's update_attn_history rather than the compressed cache's, because we
+            # must first filter the attention weights down to the tokens held in the
+            # compressed cache.
+            attn_callback = self.attn_history_callback()
+            assert attn_callback is not None
+
+        return k, v, mask, attn_callback
+
+    def _update(self, input_pos, k_val, v_val, input_ids=None):
+        # input_pos: [S], k_val: [B, H, S, D]
+        self.fill_contiguous(input_pos, k_val, v_val)
+        return input_pos.shape[-1]
+
+    def reset(self):
+        super().reset()
+        self.compressed.reset()
+        self.attention_losses.fill_(-1)
+
+    def update_attn_history(self, attn: torch.Tensor):
+        indices = self.compressed.pos.clone().long()
+
+        # Global tokens will have been marked with max_seq_length,
+        # so we need to set them back to their actual positions.
+        indices[:, :, : self.compressed.global_tokens] = (
+            torch.arange(self.compressed.global_tokens, device=indices.device)
+            .view(1, 1, -1)
+            .expand(1, indices.shape[1], -1)
+        )
+        indices = indices[:, :, : min(indices.shape[-1], attn.shape[-1])]
+        attn_compressed = attn.squeeze(2).gather(2, indices).unsqueeze(2)
+        self.compressed.update_attn_history(attn_compressed)
+
+        attn_loss = (1 - attn_compressed.sum(dim=-1)).mean()
+        insert_idx = torch.where(self.attention_losses == -1)[0][0]
+        self.attention_losses[insert_idx] = attn_loss
+
+    def compute_statistics(self, seq_len):
+        """
+        Computes statistics about the cache.
+
+        Returns:
+            dict: The base statistics plus the average attention loss.
+        """
+        stats = super().compute_statistics(seq_len)
+        cutoff = torch.where(self.attention_losses == -1)[0]
+        if len(cutoff) > 0:
+            cutoff = cutoff[0]
+        else:
+            cutoff = len(self.attention_losses)
+        stats["attention_loss"] = (self.attention_losses[:cutoff].sum() / cutoff).item()
+        return stats
+
+
 def get_cache_constructor(cache_strategy):
+    relevant_kwargs = None
     if cache_strategy == "full":
-        return KVCacheFull
+        cls = KVCacheFull
     elif cache_strategy == "l2":
-        return KVCacheL2
+        cls = KVCacheL2
     elif cache_strategy == "random":
-        return KVCacheRandom
+        cls = KVCacheRandom
     elif cache_strategy == "window":
-        return KVCacheWindow
+        cls = KVCacheWindow
     elif cache_strategy == "scissor":
-        return KVCacheScissorhands
+        cls = KVCacheScissorhands
     elif cache_strategy == "fastgen":
-        return KVCacheFastGen
+        cls = KVCacheFastGen
+    elif cache_strategy.startswith("debug"):
+        cache_strategy = re.sub(r"debug_+", "", cache_strategy).strip()
+        relevant_kwargs = get_cache_constructor(cache_strategy)[1]
+        cls = (
+            lambda max_batch_size, n_heads, head_dim, dtype, **kwargs: KVCacheAnalysis(
+                max_batch_size,
+                n_heads,
+                head_dim,
+                dtype,
+                cache_strategy=cache_strategy,
+                **kwargs,
+            )
+        )
     else:
         raise ValueError(f"Invalid cache strategy: {cache_strategy}")
+
+    return cls, relevant_kwargs or cls.relevant_kwargs
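
To make the new metric concrete: on each decoding step, update_attn_history gathers the full-cache attention weights at the positions the compressed cache retained and records how much probability mass falls outside them. A self-contained sketch of that computation (shapes and variable names are illustrative, not from the repo):

    import torch

    def attention_loss(attn: torch.Tensor, kept_pos: torch.Tensor) -> torch.Tensor:
        # attn: [B, H, 1, S] attention over the full (uncompressed) cache for one query token.
        # kept_pos: [B, H, C] positions retained by the compressed cache.
        kept = attn.squeeze(2).gather(2, kept_pos)  # [B, H, C] mass on retained tokens
        return (1 - kept.sum(dim=-1)).mean()        # discarded mass, averaged over batch and heads

    B, H, S, C = 1, 2, 8, 4
    attn = torch.softmax(torch.randn(B, H, 1, S), dim=-1)
    kept_pos = torch.topk(attn.squeeze(2), k=C, dim=-1).indices  # pretend these survive compression
    print(attention_loss(attn, kept_pos))  # near 0 when the kept tokens carry most of the mass

KVCacheAnalysis stores one such value per decoding step in attention_losses and averages them in compute_statistics.
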

eval.py

Lines changed: 23 additions & 10 deletions
@@ -8,6 +8,7 @@
 import yaml
 import argparse
 import json
+import regex as re
 import contextlib
 import shutil
 import pandas as pd
@@ -52,7 +53,16 @@
 
 
 def args_to_str(args):
-    RELEVANT_CACHE_KWARGS = get_cache_constructor(args.cache_strategy).relevant_kwargs
+    if "debug" in args.cache_strategy:
+        debug_suffix = "__debug"
+        cache_strategy = re.sub(r"debug_+", "", args.cache_strategy).strip()
+        RELEVANT_CACHE_KWARGS = get_cache_constructor(
+            args.cache_strategy.replace("debug_", "")
+        )[1]
+    else:
+        cache_strategy = args.cache_strategy
+        debug_suffix = ""
+        RELEVANT_CACHE_KWARGS = get_cache_constructor(cache_strategy)[1]
 
     def process_num(n):
         # Return integer floats as "1" not 1.0
@@ -61,16 +71,19 @@ def process_num(n):
             return int(n)
         return n
 
-    return "__".join(
-        sorted(
-            [
-                f"{k}=" + ",".join([str(process_num(m)) for m in v])
-                if type(v) == list
-                else f"{k}={process_num(v)}"
-                for k, v in vars(args).items()
-                if k in RELEVANT_CACHE_KWARGS
-            ]
+    return (
+        "__".join(
+            sorted(
+                [
+                    f"{k}=" + ",".join([str(process_num(m)) for m in v])
+                    if type(v) == list
+                    else f"{k}={process_num(v)}"
+                    for k, v in vars(args).items()
+                    if k in RELEVANT_CACHE_KWARGS
+                ]
+            )
         )
+        + debug_suffix
     )
 
 
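
To illustrate the naming scheme, here is a toy reproduction of the cache-key string args_to_str builds; the argument values and the two-kwarg relevant set below are made up for the example:

    import argparse

    args = argparse.Namespace(cache_strategy="debug_scissor", max_cache_length=[0.5], global_tokens=4)
    relevant = ["max_cache_length", "global_tokens"]  # stand-in for RELEVANT_CACHE_KWARGS

    def process_num(n):
        # Return integer floats as "1" not 1.0 (same intent as the helper in eval.py).
        return int(n) if isinstance(n, float) and n.is_integer() else n

    parts = sorted(
        f"{k}=" + ",".join(str(process_num(m)) for m in v) if isinstance(v, list) else f"{k}={process_num(v)}"
        for k, v in vars(args).items()
        if k in relevant
    )
    suffix = "__debug" if "debug" in args.cache_strategy else ""
    print("__".join(parts) + suffix)  # global_tokens=4__max_cache_length=0.5__debug
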

generation_utils.py

Lines changed: 1 addition & 0 deletions
@@ -192,6 +192,7 @@ def setup_caches(
     max_seq_length: int,
     cache_kwargs: dict = None,
 ):
+    cache_kwargs["max_seq_length"] = max_seq_length
     # Normalize max_cache_length to an absolute cache length if provided as a fraction of the max sequence length
     cache_kwargs["max_cache_length"] = list(
         map(
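
The context line above refers to a normalization step that this hunk truncates. Based on the help text in cache.py ("If 0 < x <= 1, it is a percent of |prompt| + max new tokens"), a sketch of the intended conversion might look like this (function name and exact rounding are assumptions, not the repo's code):

    def normalize_cache_length(max_cache_length: float, max_seq_length: int) -> int:
        # Fractions in (0, 1] are interpreted relative to the max sequence length;
        # values > 1 are treated as absolute cache sizes.
        if 0 < max_cache_length <= 1:
            return int(max_cache_length * max_seq_length)
        return int(max_cache_length)

    print(normalize_cache_length(0.25, 4096))  # 1024
    print(normalize_cache_length(512, 4096))   # 512
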

model.py

Lines changed: 15 additions & 8 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 from dataclasses import dataclass
+from collections import defaultdict
 from typing import Optional
 
 import torch
@@ -176,12 +177,14 @@ def setup_caches(self, **kwargs):
         elif hasattr(self.output, "scales_and_zeros"):
             dtype = self.output.scales_and_zeros.dtype
         for layer_idx, b in enumerate(self.layers):
-            cache_constructor = get_cache_constructor(cache_strategy=cache_strategy)
+            cache_constructor, relevant_kwargs = get_cache_constructor(
+                cache_strategy=cache_strategy
+            )
             # Only pass in the kwargs we need for the cache we chose (useful especially for debugging)
             layerwise_keys = {"max_cache_length", "drop_amount"}
             layer_kwargs = {
                 k: kwargs[k][layer_idx] if k in layerwise_keys else kwargs[k]
-                for k in cache_constructor.relevant_kwargs
+                for k in relevant_kwargs
             }
             b.attention.kv_cache = cache_constructor(
                 self.max_batch_size,
@@ -205,14 +208,18 @@ def reset_caches(self):
     def get_cache_stats(self, prompt_len, gen_len):
         stats = {}
         final_seq_len = prompt_len + gen_len
-        crs = []
+        avgs = defaultdict(list)
         for layer_idx, layer in enumerate(self.layers):
-            cr = layer.attention.kv_cache.compression_ratio(
+            stat = layer.attention.kv_cache.compute_statistics(
                 seq_len=torch.tensor(final_seq_len)
-            ).item()
-            stats[f"compression_ratio_{layer_idx}"] = cr
-            crs.append(cr)
-        stats["compression_ratio_avg"] = sum(crs) / len(crs)
+            )
+            for k, v in stat.items():
+                stats[f"{k}_{layer_idx}"] = v
+                avgs[k].append(v)
+
+        for k, v in avgs.items():
+            stats[f"{k}_avg"] = sum(v) / len(v)
+
         return stats
 
     def min_cache_length(self):
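
A toy run of the new aggregation in get_cache_stats: per-layer statistic dicts are flattened into "{name}_{layer_idx}" keys plus a "{name}_avg" over layers (the numbers below are invented):

    from collections import defaultdict

    per_layer = [
        {"compression_ratio": 0.75, "attention_loss": 0.02},
        {"compression_ratio": 0.70, "attention_loss": 0.04},
    ]

    stats, avgs = {}, defaultdict(list)
    for layer_idx, stat in enumerate(per_layer):
        for k, v in stat.items():
            stats[f"{k}_{layer_idx}"] = v
            avgs[k].append(v)
    for k, v in avgs.items():
        stats[f"{k}_avg"] = sum(v) / len(v)

    print(round(stats["compression_ratio_avg"], 3))  # 0.725
    print(round(stats["attention_loss_avg"], 3))     # 0.03
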
