
Commit 9c9845b

Author: Griffin Adams (committed)

Changes to FastGen.

1 parent 12a6435 · commit 9c9845b


5 files changed: +40 −11 lines


cache.py

Lines changed: 10 additions & 5 deletions
@@ -214,7 +214,7 @@ def compression_ratio(self, seq_len):
         """
         # Final token isn't passed to cache so must -1 from seq_len
         n = seq_len - 1
-        return ((n - min(self.cache_cts, self.max_cache_length)) / n).mean()
+        return ((n - torch.clamp_max(self.cache_cts, self.max_cache_length)) / n).mean()

     def return_kv_cache(self):
         # Truncate the cache based on number of insertions. It will be at the end since we prefill in-order.
@@ -651,10 +651,11 @@ def reset(self):
         self.attn_history_num.zero_()
         self.attn_history_denom.zero_()
         self.attn_counter.zero_()
-        self.eviction_queue.zero_()
-        # Start with an "empty queue" so that we can fill it up
-        self.eviction_idx.fill_(self.drop_amount)
-        assert self.queue_len() == 0
+        if hasattr(self, "eviction_queue"):
+            self.eviction_queue.zero_()
+            # Start with an "empty queue" so that we can fill it up
+            self.eviction_idx.fill_(self.drop_amount)
+            assert self.queue_len() == 0

     def queue_len(self):
         return self.drop_amount - self.eviction_idx
@@ -778,6 +779,7 @@ def __init__(
         **kwargs,
     ):
         self.global_tokens = 0  # No global tokens for FastGen
+        self.attn_record_freq = 1  # We record attention every step for FastGen
         super().__init__(
             max_batch_size,
             n_heads,
@@ -1072,6 +1074,9 @@ def profile_and_update(self, input_pos, input_ids, k_val, v_val, attn):
             self.profile_attn_heads(input_pos, input_ids, attn)
         )

+        # Show which strategies are selected
+        print([self.strategies[i] for i in self.cache_strategies.tolist()])
+
         # If none of the heads selected a heavy hitter strategy, we don't need to track attention weights
         self.requires_heavy_check = any(
             ["heavy" in KVCacheFastGen.strategies[i] for i in self.cache_strategies]

cache_configs/fastgen.yaml

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+cache_strategy: "fastgen"
+max_cache_length: [1.0]  # [Fixed] Control compression ratio with min_recovery_frac
+prompt_compression_strategy: "snapkv"  # Won't be used. Fastgen profiles attn and inserts directly.
+recent_window: 10  # Local window to consider for local strategies
+history_window_size: 400  # How many past steps to consider for attention importance calculation
+drop_amount: 0  # How frequently to calculate which tokens to evict (0 means we recalculate every step)
+attn_thresholding: False  # Whether to threshold attention scores or record raw probabilities
+min_recovery_frac: 0.85  # Higher is less compression (0.85 means we choose the policy which compresses the most tokens AND recovers 85% of the full attention matrix)
+heavy_hitter_frac: 0.3  # Higher is less compression for the heavy hitter strategy
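
The min_recovery_frac comment describes the selection rule: prefer the policy that keeps the fewest tokens while still recovering at least that fraction of the attention mass. A rough sketch of that rule in isolation (the pick_policy helper, policy names, and fractions are invented for illustration and are not the repository's implementation):

def pick_policy(candidates, min_recovery_frac=0.85):
    """Each candidate is (name, kept_frac, recovered_attn_frac)."""
    admissible = [c for c in candidates if c[2] >= min_recovery_frac]
    if not admissible:
        return ("full", 1.0, 1.0)  # fall back to keeping every token
    return min(admissible, key=lambda c: c[1])  # fewest kept tokens wins

print(pick_policy([
    ("special", 0.02, 0.40),
    ("special+local", 0.10, 0.80),
    ("special+local+heavy", 0.35, 0.90),
]))  # -> ('special+local+heavy', 0.35, 0.9)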

eval_multi.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@

 HPARAMS = {
     "max_cache_length": [[8192], [4096], [2048], [1024], [512], [256], [128]],
+    "min_recovery_frac": [0.5, 0.6, 0.7, 0.8, 0.9, 0.95],
 }

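The added min_recovery_frac list extends the sweep so that each cache-length setting is paired with every recovery threshold. A small sketch of how such a grid expands, assuming the harness takes the Cartesian product of the HPARAMS values (the actual expansion logic in eval_multi.py may differ):

import itertools

HPARAMS = {
    "max_cache_length": [[8192], [4096], [2048], [1024], [512], [256], [128]],
    "min_recovery_frac": [0.5, 0.6, 0.7, 0.8, 0.9, 0.95],
}

configs = [dict(zip(HPARAMS, combo)) for combo in itertools.product(*HPARAMS.values())]
print(len(configs))  # 7 * 6 = 42 runs
print(configs[0])    # {'max_cache_length': [8192], 'min_recovery_frac': 0.5}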

model.py

Lines changed: 7 additions & 5 deletions
@@ -205,12 +205,14 @@ def reset_caches(self):
     def get_cache_stats(self, prompt_len, gen_len):
         stats = {}
         final_seq_len = prompt_len + gen_len
+        crs = []
         for layer_idx, layer in enumerate(self.layers):
-            stats[f"compression_ratio_{layer_idx}"] = (
-                layer.attention.kv_cache.compression_ratio(
-                    seq_len=torch.tensor(final_seq_len)
-                ).item()
-            )
+            cr = layer.attention.kv_cache.compression_ratio(
+                seq_len=torch.tensor(final_seq_len)
+            ).item()
+            stats[f"compression_ratio_{layer_idx}"] = cr
+            crs.append(cr)
+        stats["compression_ratio_avg"] = sum(crs) / len(crs)
         return stats

     def min_cache_length(self):
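
With this change each per-layer ratio is kept and their mean is reported under a single key, which gives one summary number per run for logging. A toy illustration of the resulting dictionary (the layer count and ratio values are made up):

crs = [0.55, 0.72]  # hypothetical per-layer compression ratios
stats = {f"compression_ratio_{i}": cr for i, cr in enumerate(crs)}
stats["compression_ratio_avg"] = sum(crs) / len(crs)
print(stats)
# {'compression_ratio_0': 0.55, 'compression_ratio_1': 0.72, 'compression_ratio_avg': 0.635}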

tokenizer.py

Lines changed: 13 additions & 1 deletion
@@ -201,7 +201,19 @@ def __init__(self, model_path):
         ]

     def special_ids(self) -> List[List[int]]:
-        return [[x] for x in self.tokenizer.special_token_ids]
+        if hasattr(self.tokenizer, "special_token_ids"):
+            return [[x] for x in self.tokenizer.special_token_ids]
+
+        # Its likely a tokenizer that has a special_tokens_map attribute
+        special_tokens_ = list(self.tokenizer.special_tokens_map.values())
+        special_tokens = []
+        for t in special_tokens_:
+            if type(t) == list:
+                special_tokens.extend(t)
+            else:
+                special_tokens.append(t)
+        special_tokens = list(set(special_tokens))
+        return [[self.tokenizer.convert_tokens_to_ids(t)] for t in special_tokens]

     def encode(self, text):
         return self.tokenizer.encode(text, add_special_tokens=False)
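
The new branch covers tokenizers without a special_token_ids attribute by flattening special_tokens_map, whose values may be single tokens or lists, and converting the unique tokens to ids. A quick usage sketch assuming a Hugging Face tokenizer (the model name is only an example):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
print(tok.special_tokens_map)  # {'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}

tokens = []
for t in tok.special_tokens_map.values():
    if isinstance(t, list):   # e.g. 'additional_special_tokens'
        tokens.extend(t)
    else:
        tokens.append(t)
ids = [[tok.convert_tokens_to_ids(t)] for t in set(tokens)]
print(ids)  # [[50256]] for GPT-2, since all three special tokens map to the same id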
