
Commit 12be67c

Author: Griffin Adams
Update FastGen to use new attention loss calculation.
1 parent 008175b

File tree

4 files changed: +21 additions, -31 deletions
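In short, this commit changes the FastGen profiling loss in cache.py to score each candidate compression strategy by the attention probability mass it retains (summed over keys, averaged over query positions) instead of the previous epsilon-based recovery comparison over re-softmaxed masked attention. It also allows recent_window to be given as a float, resolving values at or below 1 as a per-layer fraction of max_cache_length (generation_utils.py, model.py), and sets recent_window: 0.3 in cache_configs/fastgen.yaml.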

cache.py

Lines changed: 5 additions & 29 deletions
@@ -57,7 +57,7 @@ def add_cache_arguments(parser: argparse.ArgumentParser):
     group.add_argument(
         "--recent_window",  # NB: for KVCacheWindow, recent_window is implicitly set to self.max_cache_length - self.global_tokens.
         default=10,  # 10 is default specified in ScissorHands paper ("r" in Algorithm 2).
-        type=int,
+        type=float,  # If < 1, it is a fraction of max_cache_length.
         help="The number of recently generated tokens to always spare from eviction.",
     )

@@ -848,14 +848,10 @@ def __init__(
         kv_mask_shape = (max_batch_size, n_heads, 1, self.max_cache_length)
         self.register_buffer("mask", torch.zeros(kv_mask_shape, dtype=torch.bool))

-        self.epsilon = (
-            1e-4  # Max difference between attention probs to be considered equivalent.
-        )
-
         # NB: Kwargs are sdpa attention kwargs, not the kwargs for the "func"
         self.prefill_attn_callback = {
             "func": self.profile_and_update,
-            "kwargs": {"return_attn_logits": True},
+            "kwargs": {"return_attn_logits": False},
         }

     def return_attn(self):
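With the new scoring below, profiling consumes attention probabilities directly, so the callback now requests return_attn_logits=False and the epsilon tolerance previously used to compare re-normalized attention entries is removed.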
@@ -1006,25 +1002,6 @@ def build_punc_ids_mask(self, input_ids):
         punc_ids_mask = torch.isin(input_ids, self.punc_ids)
         return punc_ids_mask

-    def compute_remasked_attn(self, attn, masks):
-        """
-        Compute the attention with the masks applied. Mask should be true for tokens we want to keep.
-        """
-        num_masks = masks.shape[0]
-        attn = attn.expand(num_masks, -1, -1, -1)
-        return torch.softmax(attn.masked_fill(~masks, float("-inf")), dim=-1)
-
-    def recovery_percent(self, attn, compressed_attn):
-        assert (
-            attn.shape[-2] == attn.shape[-1]
-        ), "Attention matrix expected to be square for profiling."
-        num_causal = attn.shape[-1] * (attn.shape[-1] + 1) // 2
-        num_padding = num_causal - attn.shape[-1]  # Subtract the trace
-        return (
-            (torch.abs(attn - compressed_attn) < self.epsilon).sum(dim=-1).sum(dim=-1)
-            - num_padding
-        ) / num_causal
-
     def profile_attn_heads(self, input_pos, input_ids, attn):
         input_ids = input_ids.squeeze(0)
         seq_len = input_ids.shape[-1]
@@ -1078,10 +1055,9 @@ def profile_attn_heads(self, input_pos, input_ids, attn):
             ]
         )

-        compressed_attns = self.compute_remasked_attn(attn, masks)
-        compressed_scores = self.recovery_percent(
-            compressed_attns, compressed_attns[-1]
-        )
+        attn_rep = attn.expand(masks.shape[0], -1, -1, -1)
+
+        compressed_scores = attn_rep.masked_fill(~masks, 0).sum(dim=-1).mean(dim=-1)

         # For each column, return the first row which has cost >= min_recovery_frac
         cache_strategies = (
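The replacement loss skips the re-softmax entirely: the prefill attention probabilities are broadcast across the candidate masks, entries a strategy would evict are zeroed out, and the retained mass is summed over keys and averaged over query positions, yielding one recovery score per (strategy, head). Below is a minimal, self-contained sketch of that calculation on toy tensors; the shapes (attn as [1, n_heads, seq_len, seq_len], masks as [num_strategies, n_heads, seq_len, seq_len] booleans that are True for kept keys) and the recency-based toy masks are assumptions for illustration, not code taken from the repo.

```python
import torch

# Toy sizes for illustration only.
n_heads, seq_len = 2, 6

# Causal attention probabilities: each row sums to 1 over the allowed keys.
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = torch.rand(1, n_heads, seq_len, seq_len)
attn = torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1)


def keep_recent(k: int) -> torch.Tensor:
    # Keep only the k most recent keys for every query (plus the causal constraint).
    idx = torch.arange(seq_len)
    return causal & (idx[None, :] > idx[:, None] - k)


# Hypothetical candidate strategies: recent-2, recent-4, and keep-everything.
masks = torch.stack([keep_recent(k).expand(n_heads, -1, -1) for k in (2, 4, seq_len)])

# New attention loss from this commit: fraction of attention mass each strategy
# retains, summed over keys and averaged over queries -> [num_strategies, n_heads].
attn_rep = attn.expand(masks.shape[0], -1, -1, -1)
compressed_scores = attn_rep.masked_fill(~masks, 0).sum(dim=-1).mean(dim=-1)

# Each head is then assigned the most aggressive strategy whose score clears
# min_recovery_frac (0.85 in cache_configs/fastgen.yaml).
print(compressed_scores)
print(compressed_scores >= 0.85)
```

Because every causal row of attn sums to 1, the keep-everything mask scores exactly 1.0 for each head, so min_recovery_frac reads directly as a fraction of retained attention mass, consistent with the comment on min_recovery_frac in cache_configs/fastgen.yaml.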

cache_configs/fastgen.yaml

Lines changed: 2 additions & 1 deletion
@@ -6,4 +6,5 @@ history_window_size: 400 # How many past steps to consider for attention import
 drop_amount: 0 # How frequently to calculate which tokens to evict (0 means we recalculate every step)
 attn_thresholding: False # Whether to threshold attention scores or record raw probabilities
 min_recovery_frac: 0.85 # Higher is less compression (0.85 means we choose the policy which compresses the most tokens AND recovers 85% of the full attention matrix)
-heavy_hitter_frac: 0.3 # Higher is less compression for the heavy hitter strategy
+heavy_hitter_frac: 0.3 # Higher is less compression for the heavy hitter strategy
+recent_window: 0.3
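Since the value is at or below 1, it is read as a fraction rather than an absolute count: for a hypothetical layer with a max_cache_length of 512, recent_window: 0.3 would spare int(0.3 * 512) = 153 of the most recently generated tokens from eviction (see the resolution logic added to generation_utils.py below).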

generation_utils.py

Lines changed: 13 additions & 0 deletions
@@ -200,6 +200,7 @@ def setup_caches(
                 cache_kwargs["max_cache_length"],
             )
         )
+
     assert (
         model.config.n_layer % len(cache_kwargs["max_cache_length"]) == 0
    ), f'max_cache_length ({len(cache_kwargs["max_cache_length"])}) must be a factor of {model.config.n_layer} layers.'
@@ -209,6 +210,18 @@ def setup_caches(
             item for item in cache_kwargs["max_cache_length"] for _ in range(tile_size)
         ]

+    if type(cache_kwargs["recent_window"]) != list:
+        if cache_kwargs["recent_window"] <= 1:
+            cache_kwargs["recent_window"] = [
+                max(1, int(cache_kwargs["recent_window"] * l))
+                for l in cache_kwargs["max_cache_length"]
+            ]
+        else:
+            cache_kwargs["recent_window"] = [
+                max(1, min(cache_kwargs["recent_window"], l))
+                for l in cache_kwargs["max_cache_length"]
+            ]
+
     # Gets called twice when model is wrapped in torch.compile which causes an error without the if statement
     if type(cache_kwargs["drop_amount"]) != list:
         cache_kwargs["drop_amount"] = [
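The new block normalizes recent_window into a per-layer list: a scalar at or below 1 is treated as a fraction of each layer's max_cache_length, while a larger scalar is used as an absolute token count clamped to the cache length. A small standalone restatement of that branch with made-up cache lengths (the helper name resolve_recent_window is hypothetical, not from the repo):

```python
def resolve_recent_window(recent_window, max_cache_length):
    # Mirrors the branch added to setup_caches above (illustrative only).
    if isinstance(recent_window, list):
        return recent_window
    if recent_window <= 1:
        # Fractional spec: scale by each layer's cache length, keeping at least 1 token.
        return [max(1, int(recent_window * l)) for l in max_cache_length]
    # Absolute spec: clamp to each layer's cache length.
    return [max(1, min(recent_window, l)) for l in max_cache_length]


print(resolve_recent_window(0.3, [2048, 512]))  # -> [614, 153]
print(resolve_recent_window(10, [2048, 4]))     # -> [10, 4]
```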

model.py

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ def setup_caches(self, **kwargs):
                 cache_strategy=cache_strategy
             )
             # Only pass in the kwargs we need for the cache we chose (useful especially for debugging)
-            layerwise_keys = {"max_cache_length", "drop_amount"}
+            layerwise_keys = {"max_cache_length", "drop_amount", "recent_window"}
             layer_kwargs = {
                 k: kwargs[k][layer_idx] if k in layerwise_keys else kwargs[k]
                 for k in relevant_kwargs
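Adding recent_window to layerwise_keys means the per-layer list built in setup_caches is indexed by layer_idx here, so each layer's cache receives a single resolved window size rather than the whole list.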
