Skip to content

Commit 49fdcf1

Browse files
committed
Adds cache smash code from the Project M codebase.
1 parent 03d93b4 commit 49fdcf1

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Utilities for KV smashing."""
2+
3+
from collections.abc import Iterable
4+
from functools import reduce
5+
from typing import Any
6+
7+
import torch
8+
from transformers import BatchEncoding, DynamicCache
9+
10+
# A mixed, ordered sequence of tokenizer outputs and prefilled KV caches.
# NOTE(review): "Iterleaving" looks like a typo for "Interleaving"; kept
# as-is because the alias is part of the module's public interface.
TokenizedCacheIterleaving = Iterable[BatchEncoding | DynamicCache]
# transformers' legacy cache format. Typed as Any; from how
# legacy_cache_smash indexes it, it is assumed to be a tuple of per-layer
# (key, value) tensor pairs — TODO confirm against transformers docs.
LegacyCache = Any
12+
13+
14+
def legacy_cache_smash(a: LegacyCache, b: LegacyCache) -> LegacyCache:
    """Concatenates two LegacyCache Ks and Vs along the time axis.

    A legacy cache is a tuple of per-layer (key, value) tensor pairs;
    dim=2 is the time (sequence-length) axis of each tensor.

    Args:
        a: Legacy cache whose entries come first in time.
        b: Legacy cache appended after ``a``; must have the same number
            of layers as ``a``.

    Returns:
        A new legacy cache with each layer's K and V concatenated in time.

    Raises:
        ValueError: If the two caches have a different number of layers.
    """
    # Validate explicitly: the previous range(len(a)) indexing silently
    # dropped extra layers of a longer ``b`` and raised a bare IndexError
    # for a shorter one.
    if len(a) != len(b):
        raise ValueError(
            f"Layer count mismatch: len(a)={len(a)} vs len(b)={len(b)}"
        )
    return tuple(
        (torch.cat([k_a, k_b], dim=2), torch.cat([v_a, v_b], dim=2))
        for (k_a, v_a), (k_b, v_b) in zip(a, b)
    )
21+
22+
23+
def merge_dynamic_caches(caches: Iterable[DynamicCache]) -> DynamicCache:
    """Merges DynamicCaches' Ks and Vs along the time axis.

    Args:
        caches: One or more DynamicCache objects, in the temporal order in
            which their keys and values should be concatenated.

    Returns:
        A single DynamicCache holding the concatenated keys and values.

    Raises:
        ValueError: If ``caches`` yields no elements.
    """
    legacies = [cache.to_legacy_cache() for cache in caches]
    # ``assert`` is stripped under ``python -O``; validate explicitly so an
    # empty input fails loudly instead of crashing inside reduce().
    if not legacies:
        raise ValueError("merge_dynamic_caches requires at least one cache")
    merged = reduce(legacy_cache_smash, legacies)
    return DynamicCache.from_legacy_cache(merged)  # type: ignore
29+
30+
31+
def combine_representations(
    tokenizer, reps: Iterable[str | DynamicCache]
) -> TokenizedCacheIterleaving:
    """Tokenizes the string entries of ``reps``, passing caches through.

    Args:
        tokenizer: Callable mapping a string to a BatchEncoding
            (presumably a Hugging Face tokenizer — confirm with callers).
        reps: Strings and/or prefilled DynamicCaches, in order.

    Returns:
        A list mirroring ``reps`` in which every string has been tokenized
        and every DynamicCache is kept as-is.
    """
    # isinstance (rather than ``type(...) is``) also accepts DynamicCache
    # subclasses, which the exact-type check would wrongly tokenize.
    return [
        rep if isinstance(rep, DynamicCache) else tokenizer(rep)
        for rep in reps
    ]
41+
42+
43+
def tokens_to_legacy_cache(
    model, device: str, tokens_or_cache: BatchEncoding | DynamicCache
) -> LegacyCache:
    """Prefills and returns Ks and Vs as a LegacyCache.

    Args:
        model: Causal LM whose forward pass accepts ``past_key_values`` and
            returns an output exposing a ``past_key_values`` attribute.
        device: Device string the input tensors are moved to (e.g. "cuda").
        tokens_or_cache: Either tokenizer output to prefill with, or an
            already-prefilled DynamicCache to convert directly.

    Returns:
        The (possibly freshly prefilled) cache in legacy tuple form.
        (The previous ``Iterable[LegacyCache]`` annotation was misleading:
        a single legacy cache is returned, not an iterable of them.)
    """
    # An existing cache needs no forward pass — just convert it.
    # isinstance also accepts DynamicCache subclasses.
    if isinstance(tokens_or_cache, DynamicCache):
        return tokens_or_cache.to_legacy_cache()
    tokens = tokens_or_cache
    # Single prefill pass; no gradients are needed to capture the Ks/Vs.
    with torch.no_grad():
        outputs = model(
            tokens["input_ids"].to(device),  # type: ignore
            attention_mask=tokens["attention_mask"].to(device),  # type: ignore
            past_key_values=DynamicCache(),
        )
    return outputs.past_key_values.to_legacy_cache()

0 commit comments

Comments
 (0)