import torch

from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
from mellea.stdlib.base import CBlock, LinearContext
from mellea.stdlib.chat import Message

backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B)

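# Pull the raw HF model, tokenizer, and device out of the backend. These are
# private attributes of LocalHFBackend, used directly here because the KV-cache
# surgery below needs low-level access to the model's forward pass.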
model = backend._model
tokenizer = backend._tokenizer
device = backend._device


# Maps a raw string to the DynamicCache produced by running it through the model.
KV_CACHE: dict[str, DynamicCache] = dict()


def cache(s: str, store: bool = True) -> DynamicCache:
    """Run `s` through the model and return the resulting KV cache.

    If `store` is True, the cache is also recorded in KV_CACHE for later reuse.
    """
    toks = tokenizer(s, return_tensors="pt")
    dc = DynamicCache()
    with torch.no_grad():
        rv = model(
            toks["input_ids"].to(device),
            attention_mask=toks["attention_mask"].to(device),
            past_key_values=dc,
        ).past_key_values
    if store:
        KV_CACHE[s] = rv
    return rv


def merge(toks, dcs):
    """Concatenate tokenized segments and their KV caches along the sequence dimension."""
    merged_toks = torch.cat([t["input_ids"] for t in toks], dim=1)
    merged_masks = torch.cat([t["attention_mask"] for t in toks], dim=1)
    merged_dcs = merge_dynamic_caches(dcs)

    return merged_toks, merged_masks, merged_dcs


c_blocks = ["this is a test", "this is another test"]

# Pretend these strings already existed in the cache.
for cb in c_blocks:
    cache(cb)

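# Sanity check (a sketch, assuming mellea's DynamicCache is the transformers
# DynamicCache and exposes get_seq_length()): each cached entry should cover
# exactly one KV position per input token of its source string.
for cb in c_blocks:
    n_toks = tokenizer(cb, return_tensors="pt")["input_ids"].shape[1]
    assert KV_CACHE[cb].get_seq_length() == n_toks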

# Apply the chat template to a conversation that contains these strings, but without tokenizing.
messages = [
    {"role": "user", "content": c_blocks[0]},
    {"role": "user", "content": "Not cached"},
    {"role": "user", "content": c_blocks[1]},
    {"role": "user", "content": "Also not cached"},
]
templatized_input = tokenizer.apply_chat_template(conversation=messages, tokenize=False)

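# Because tokenize=False, the rendered prompt is a plain string in which each
# cached c_block appears verbatim. Walk that string, splitting on the cached
# substrings, so the prompt can be rebuilt as alternating uncached segments
# (tokenized and cached on the fly) and pre-cached blocks (looked up in KV_CACHE).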
str_parts = []
tok_parts = []
dc_parts = []

current_suffix = templatized_input
for cb in c_blocks:
    # Each cached block is assumed to appear exactly once, in order, in the
    # rendered prompt; split() therefore yields a prefix and the remaining suffix.
    parts = current_suffix.split(cb)
    assert len(parts) == 2
    prefix, next_suffix = parts

    if prefix != "":
        # Add the prefix.
        str_parts.append(prefix)
        # Add the tokens and attention mask for the prefix.
        tok_parts.append(tokenizer(prefix, return_tensors="pt"))
        # Add the dynamic cache for the prefix (computed on the fly, not stored).
        dc_parts.append(cache(prefix, store=False))

    # Add cb itself, reusing its precomputed cache.
    str_parts.append(cb)
    tok_parts.append(tokenizer(cb, return_tensors="pt"))
    dc_parts.append(KV_CACHE[cb])

    # Set the current suffix.
    current_suffix = next_suffix

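# At this point str_parts / tok_parts / dc_parts describe the prompt as an
# ordered sequence of segments, each with its token ids, attention mask, and a
# KV cache (precomputed or freshly computed); current_suffix is whatever
# template text follows the last cached block.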
# REMEMBER: add the final suffix.
if current_suffix != "":
    str_parts.append(current_suffix)
    tok_parts.append(tokenizer(current_suffix, return_tensors="pt"))
    dc_parts.append(cache(current_suffix, store=False))

# Merge everything together.
merged_toks, merged_masks, merged_dcs = merge(tok_parts, dc_parts)

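# The merged cache should line up position-for-position with the merged tokens
# (same get_seq_length() assumption as in the sanity check above).
assert merged_dcs.get_seq_length() == merged_toks.shape[1]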
# Crop off the last cached position so generate() still has at least one
# uncached input token to run a forward pass on; that pass produces the logits
# for the first newly generated token.
merged_dcs.crop(-1)

# Generate. Because a populated past_key_values is passed in, the model only
# runs forward on the input positions not already covered by the cache.
result = model.generate(
    merged_toks.to(device),
    attention_mask=merged_masks.to(device),
    past_key_values=merged_dcs,
    use_cache=True,
    return_dict_in_generate=True,
    output_scores=True,
)

# Decode only the newly generated tokens (everything past the prompt) and print.
result_decoded = tokenizer.decode(
    result.sequences[0, merged_toks.shape[1] :], skip_special_tokens=True
)
print(result_decoded)