Commit a648405

Adds a crystallization of the kv smash code
And the way it fits into a model that uses apply_chat_template or any other parser/renderer. Note that there's still a bug: other substrings of the rendered prompt can also "hit" on the cached contents. We don't anticipate this happening often in practice because of how KV cache smashing should typically be used, but it's something we need to address by introducing sentinel values, indexing string machines, or something else along those lines. no-verify commit because the point of this code is documentation.
1 parent fab35d9 commit a648405
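
As a rough illustration of the sentinel-value direction mentioned in the commit message (a hypothetical sketch only; the wrap_with_sentinels helper and the marker format are invented here and are not part of this commit), each cached block could be wrapped in a unique marker before rendering, so the split happens on the markers rather than on the block contents and an accidental substring match elsewhere in the prompt cannot hit the cache:

import uuid


def wrap_with_sentinels(blocks: list[str]) -> tuple[list[str], dict[str, str]]:
    # Hypothetical sketch: wrap each cached block in a unique marker and
    # remember which marker maps back to which original block.
    wrapped, markers = [], {}
    for block in blocks:
        token = f"<<KV:{uuid.uuid4().hex}>>"
        markers[token] = block
        wrapped.append(f"{token}{block}{token}")
    return wrapped, markers

Splitting the rendered template on a marker pair then isolates exactly the cached block, provided the markers are stripped back out of every segment before tokenization.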

2 files changed: +165 -0 lines changed


docs/kv_smash/kv_with_chat.py

Lines changed: 110 additions & 0 deletions
import torch

from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
from mellea.stdlib.base import CBlock, LinearContext
from mellea.stdlib.chat import Message

backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B)

model = backend._model
tokenizer = backend._tokenizer
device = backend._device


KV_CACHE: dict[str, DynamicCache] = dict()


def cache(s: str, store=True) -> DynamicCache:
    # Run a forward pass over s to populate a fresh DynamicCache.
    toks = tokenizer(s, return_tensors="pt")
    dc = DynamicCache()
    with torch.no_grad():
        rv = model(
            toks["input_ids"].to(device),
            attention_mask=toks["attention_mask"].to(device),
            past_key_values=dc,
        ).past_key_values
    if store:
        KV_CACHE[s] = rv
    return rv


def merge(toks, dcs):
    # Concatenate token ids, attention masks, and KV caches along the sequence dimension.
    merged_toks = torch.cat([t["input_ids"] for t in toks], dim=1)
    merged_masks = torch.cat([t["attention_mask"] for t in toks], dim=1)
    merged_dcs = merge_dynamic_caches(dcs)

    return merged_toks, merged_masks, merged_dcs


c_blocks = ["this is a test", "this is another test"]

# Pretend this stuff already existed in the cache.
for cb in c_blocks:
    cache(cb)


# Apply the chat template to a conversation that contains these strings, but without tokenization.
messages = [
    {"role": "user", "content": c_blocks[0]},
    {"role": "user", "content": "Not cached"},
    {"role": "user", "content": c_blocks[1]},
    {"role": "user", "content": "Also not cached"},
]
templatized_input = tokenizer.apply_chat_template(conversation=messages, tokenize=False)

str_parts = []
tok_parts = []
dc_parts = []

current_suffix = templatized_input
partially_cached_templatized_input: list[str | DynamicCache] = []
for cb in c_blocks:
    parts = current_suffix.split(cb)
    assert len(parts) == 2
    prefix, next_suffix = parts

    if prefix != "":
        # Add the prefix.
        str_parts.append(prefix)
        # Add the tokens and attention mask for the prefix.
        tok_parts.append(tokenizer(prefix, return_tensors="pt"))
        # Add the dynamic cache for the prefix.
        dc_parts.append(cache(prefix, store=False))

    # Add cb itself.
    str_parts.append(cb)
    tok_parts.append(tokenizer(cb, return_tensors="pt"))
    dc_parts.append(KV_CACHE[cb])

    # Set the current suffix.
    current_suffix = next_suffix

# REMEMBER: add the final suffix.
if current_suffix != "":
    str_parts.append(current_suffix)
    tok_parts.append(tokenizer(current_suffix, return_tensors="pt"))
    dc_parts.append(cache(current_suffix, store=False))

# Merge everything together.
merged_toks = torch.cat([toks["input_ids"] for toks in tok_parts], dim=1)
merged_masks = torch.cat([toks["attention_mask"] for toks in tok_parts], dim=1)
merged_dcs = merge_dynamic_caches(dc_parts)

# Crop the last KV entry for safety, so generate() recomputes logits for the final position.
merged_dcs.crop(-1)

# Generate and print the result.
result = model.generate(
    merged_toks.to(device),
    attention_mask=merged_masks.to(device),
    past_key_values=merged_dcs,
    use_cache=True,
    return_dict_in_generate=True,
    output_scores=True,
)

result_decoded = tokenizer.decode(
    result.sequences[0, merged_toks.shape[1] :], skip_special_tokens=True
)
print(result_decoded)

docs/kv_smash/kvcache.py

Lines changed: 55 additions & 0 deletions
import torch

from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
from mellea.stdlib.base import CBlock, LinearContext
from mellea.stdlib.chat import Message

backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B)

model = backend._model
tokenizer = backend._tokenizer
device = backend._device


def cache(toks) -> DynamicCache:
    # Run a forward pass over the tokens to populate a fresh DynamicCache.
    dc = DynamicCache()
    with torch.no_grad():
        rv = model(
            toks["input_ids"].to(device),
            attention_mask=toks["attention_mask"].to(device),
            past_key_values=dc,
        ).past_key_values
    return rv


def merge(strs: list[str]):
    # Tokenize and cache each string independently, then concatenate the
    # token ids, attention masks, and KV caches along the sequence dimension.
    strs_toks = [tokenizer(x, return_tensors="pt") for x in strs]
    strs_dcs = [cache(toks) for toks in strs_toks]

    merged_toks = torch.cat([toks["input_ids"] for toks in strs_toks], dim=1)
    merged_masks = torch.cat([toks["attention_mask"] for toks in strs_toks], dim=1)
    merged_dcs = merge_dynamic_caches(strs_dcs)

    return merged_toks, merged_masks, merged_dcs


strs = ["this is a test", "this is another test"]

merged_toks, merged_masks, merged_dcs = merge(strs)
# Crop the last KV entry for safety, so generate() recomputes logits for the final position.
merged_dcs.crop(-1)

result = model.generate(
    merged_toks.to(device),
    attention_mask=merged_masks.to(device),
    past_key_values=merged_dcs,
    use_cache=True,
    return_dict_in_generate=True,
    output_scores=True,
)

result_decoded = tokenizer.decode(
    result.sequences[0, merged_toks.shape[1] :], skip_special_tokens=True
)
print(result_decoded)
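
A minimal follow-on check (not part of the commit; it assumes the model, tokenizer, device, strs, and result_decoded names from kvcache.py above are still in scope) that generates from a single forward pass over the plain concatenated string, so the merged-cache output can be compared side by side. Because the per-block caches are built without attention across blocks, the two outputs are not expected to match exactly:

baseline_toks = tokenizer("".join(strs), return_tensors="pt")
baseline = model.generate(
    baseline_toks["input_ids"].to(device),
    attention_mask=baseline_toks["attention_mask"].to(device),
    use_cache=True,
    return_dict_in_generate=True,
)
baseline_decoded = tokenizer.decode(
    baseline.sequences[0, baseline_toks["input_ids"].shape[1] :],
    skip_special_tokens=True,
)
print("baseline:", baseline_decoded)
print("kv smash:", result_decoded)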
