"""Utilities for KV smashing."""

from collections.abc import Iterable
from functools import reduce
from typing import Any

import torch
from transformers import BatchEncoding, DynamicCache

# A mixed sequence of tokenized text and precomputed caches, kept in prompt order.
TokenizedCacheIterleaving = Iterable[BatchEncoding | DynamicCache]
# The "legacy" Hugging Face cache format: a tuple of per-layer (key, value) tensor pairs.
LegacyCache = Any


def legacy_cache_smash(a: LegacyCache, b: LegacyCache) -> LegacyCache:
    """Concatenates two LegacyCaches' Ks and Vs along the time axis."""
    # Each layer entry is a (key, value) pair shaped (batch, heads, seq_len, head_dim),
    # so dim=2 is the sequence (time) axis.
    legacy_merged = tuple(
        (torch.cat([a[i][0], b[i][0]], dim=2), torch.cat([a[i][1], b[i][1]], dim=2))
        for i in range(len(a))
    )
    return legacy_merged
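

# Illustrative sketch (not called anywhere in this module): the concatenation above
# assumes legacy caches are per-layer (key, value) pairs shaped
# (batch, heads, seq_len, head_dim), so dim=2 stitches segments along the sequence axis.
def _legacy_cache_smash_example() -> None:
    a = tuple((torch.zeros(1, 2, 3, 8), torch.zeros(1, 2, 3, 8)) for _ in range(2))
    b = tuple((torch.zeros(1, 2, 5, 8), torch.zeros(1, 2, 5, 8)) for _ in range(2))
    merged = legacy_cache_smash(a, b)
    assert merged[0][0].shape == (1, 2, 8, 8)  # 3 + 5 tokens along the time axis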


def merge_dynamic_caches(caches: Iterable[DynamicCache]) -> DynamicCache:
    """Merges the given DynamicCaches' Ks and Vs along the time axis."""
    # Round-trip through the legacy tuple format, pairwise-concatenating along dim=2.
    legacies = [c.to_legacy_cache() for c in caches]
    assert len(legacies) >= 1
    rv = DynamicCache.from_legacy_cache(reduce(legacy_cache_smash, legacies))  # type: ignore
    return rv  # type: ignore
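

# Illustrative sketch (not called anywhere in this module): DynamicCache round-trips
# through the legacy tuple format, so merging single-layer caches of 3 and 5 tokens
# yields one cache holding 8 tokens.
def _merge_dynamic_caches_example() -> None:
    def layer(n: int) -> tuple[torch.Tensor, torch.Tensor]:
        return (torch.zeros(1, 2, n, 8), torch.zeros(1, 2, n, 8))

    first = DynamicCache.from_legacy_cache((layer(3),))
    second = DynamicCache.from_legacy_cache((layer(5),))
    merged = merge_dynamic_caches([first, second])
    assert merged.get_seq_length() == 8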


def combine_representations(
    tokenizer, reps: Iterable[str | DynamicCache]
) -> TokenizedCacheIterleaving:
    """Tokenizes string reps and passes precomputed DynamicCaches through unchanged."""
    rv = []
    for rep in reps:
        if isinstance(rep, DynamicCache):
            rv.append(rep)
        else:
            # Return PyTorch tensors so tokens_to_legacy_cache can move them to `device`.
            rv.append(tokenizer(rep, return_tensors="pt"))
    return rv
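

# Illustrative sketch (assumed usage, not called in this module): `tokenizer` is a
# Hugging Face tokenizer; a precomputed prefix cache passes through untouched while
# the plain string is tokenized, so cached and fresh pieces can be interleaved.
def _combine_representations_example(
    tokenizer, prefix_cache: DynamicCache
) -> TokenizedCacheIterleaving:
    return combine_representations(
        tokenizer, [prefix_cache, "Continue from the cached prefix."]
    )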


def tokens_to_legacy_cache(
    model, device: str, tokens_or_cache: BatchEncoding | DynamicCache
) -> LegacyCache:
    """Prefills and returns Ks and Vs as a LegacyCache."""
    if isinstance(tokens_or_cache, DynamicCache):
        # Already prefilled; just convert to the legacy tuple format.
        return tokens_or_cache.to_legacy_cache()
    else:
        tokens = tokens_or_cache
        dc = DynamicCache()
        with torch.no_grad():
            # Prefill forward pass: the model populates `dc` and hands it back
            # as `past_key_values`.
            dc = model(
                tokens["input_ids"].to(device),  # type: ignore
                attention_mask=tokens["attention_mask"].to(device),  # type: ignore
                past_key_values=dc,
            ).past_key_values
        return dc.to_legacy_cache()
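

# A minimal end-to-end sketch of how these helpers are presumably meant to compose
# (an assumption, not called anywhere in this module): `model` and `tokenizer` are a
# Hugging Face causal LM (already on `device`) and its tokenizer, supplied by the caller.
def _smash_example(
    model, tokenizer, device: str, reps: Iterable[str | DynamicCache]
) -> DynamicCache:
    # Tokenize the string pieces; pre-built DynamicCaches pass through unchanged.
    interleaved = combine_representations(tokenizer, reps)
    # Prefill each piece independently to get its per-layer K/V tensors ...
    legacies = [tokens_to_legacy_cache(model, device, piece) for piece in interleaved]
    # ... then smash them together along the time axis into one cache.
    return DynamicCache.from_legacy_cache(reduce(legacy_cache_smash, legacies))  # type: ignore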