From 49fdcf14b2d924355c80dc2c1d6c88c000abbcb9 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Wed, 27 Aug 2025 19:31:01 -0400 Subject: [PATCH 1/7] Adds cache smash code from the Project M codebase. --- mellea/backends/cache/kv_block_helpers.py | 58 +++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 mellea/backends/cache/kv_block_helpers.py diff --git a/mellea/backends/cache/kv_block_helpers.py b/mellea/backends/cache/kv_block_helpers.py new file mode 100644 index 00000000..f729ade3 --- /dev/null +++ b/mellea/backends/cache/kv_block_helpers.py @@ -0,0 +1,58 @@ +"""Utilities for KV smashing.""" + +from collections.abc import Iterable +from functools import reduce +from typing import Any + +import torch +from transformers import BatchEncoding, DynamicCache + +TokenizedCacheIterleaving = Iterable[BatchEncoding | DynamicCache] +LegacyCache = Any + + +def legacy_cache_smash(a: LegacyCache, b: LegacyCache) -> LegacyCache: + """Concatenates two LegacyCache Ks and Vs along the time axis.""" + legacy_merged = tuple( + (torch.cat([a[i][0], b[i][0]], dim=2), torch.cat([a[i][1], b[i][1]], dim=2)) + for i in range(len(a)) + ) + return legacy_merged + + +def merge_dynamic_caches(caches: Iterable[DynamicCache]) -> DynamicCache: + """Merges two DynamicCache Ks and Vs along the time axis.""" + legacies = [c.to_legacy_cache() for c in caches] + assert len(legacies) >= 1 + rv = DynamicCache.from_legacy_cache(reduce(legacy_cache_smash, legacies)) # type: ignore + return rv # type: ignore + + +def combine_representations( + tokenizer, reps: Iterable[str | DynamicCache] +) -> TokenizedCacheIterleaving: + rv = [] + for rep in reps: + if type(rep) is DynamicCache: + rv.append(rep) + else: + rv.append(tokenizer(rep)) + return rv + + +def tokens_to_legacy_cache( + model, device: str, tokens_or_cache: BatchEncoding | DynamicCache +) -> Iterable[LegacyCache]: + """Prefills and returns Ks and Vs as a LegacyCache.""" + if type(tokens_or_cache) is DynamicCache: + return tokens_or_cache.to_legacy_cache() + else: + tokens = tokens_or_cache + dc = DynamicCache() + with torch.no_grad(): + dc = model( + tokens["input_ids"].to(device), # type: ignore + attention_mask=tokens["attention_mask"].to(device), # type: ignore + past_key_values=dc, + ).past_key_values + return dc.to_legacy_cache() From a1a4eb70b970dbaecb87cb6c0cd77512710736c5 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Fri, 29 Aug 2025 09:07:30 -0400 Subject: [PATCH 2/7] rename to avoid clash b/w cache/ and cache.py --- mellea/backends/{cache => }/kv_block_helpers.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename mellea/backends/{cache => }/kv_block_helpers.py (100%) diff --git a/mellea/backends/cache/kv_block_helpers.py b/mellea/backends/kv_block_helpers.py similarity index 100% rename from mellea/backends/cache/kv_block_helpers.py rename to mellea/backends/kv_block_helpers.py From 5989664b6ca7cc7bf596fe48d5bff6b679c90085 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Fri, 29 Aug 2025 09:21:12 -0400 Subject: [PATCH 3/7] Adds cache flag to CBlock. 
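Marking a CBlock with cache=True signals that a backend may prefill that block once and splice the
stored keys/values into later prompts using the kv_block_helpers introduced in PATCH 1. For
intuition: a legacy-format cache is a tuple of per-layer (K, V) tensors of shape
(batch, num_heads, seq_len, head_dim), so concatenating along dim=2 splices two prefills along the
sequence axis. A minimal sketch with toy tensors (shapes are illustrative; no model is involved):

    import torch

    def toy_legacy_cache(seq_len, layers=2, heads=4, head_dim=8):
        # Legacy cache layout: tuple of per-layer (K, V) pairs, each (batch, heads, seq_len, head_dim).
        return tuple(
            (torch.randn(1, heads, seq_len, head_dim), torch.randn(1, heads, seq_len, head_dim))
            for _ in range(layers)
        )

    a = toy_legacy_cache(seq_len=5)
    b = toy_legacy_cache(seq_len=3)

    # The same concatenation legacy_cache_smash performs: splice along the time axis (dim=2).
    merged = tuple(
        (torch.cat([a[i][0], b[i][0]], dim=2), torch.cat([a[i][1], b[i][1]], dim=2))
        for i in range(len(a))
    )
    assert merged[0][0].shape == (1, 4, 8, 8)  # 5 + 3 tokens along the sequence dimension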
--- mellea/stdlib/base.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py index 8f9fd817..30bd6fb3 100644 --- a/mellea/stdlib/base.py +++ b/mellea/stdlib/base.py @@ -15,11 +15,23 @@ class CBlock: """A `CBlock` is a block of content that can serve as input to or output from an LLM.""" - def __init__(self, value: str | None, meta: dict[str, Any] | None = None): - """Initializes the CBlock with a string and some metadata.""" + def __init__( + self, + value: str | None, + meta: dict[str, Any] | None = None, + *, + cache: bool = False, + ): + """Initializes the CBlock with a string and some metadata. + + Args: + value: the underlying value stored in this CBlock + meta: Any meta-information about this CBlock (e.g., the inference engine's Completion object). + cache: If set to `True` then this CBlock's KV cache might be stored by the inference engine. Experimental.""" if value is not None and not isinstance(value, str): raise TypeError("value to a Cblock should always be a string or None") self._underlying_value = value + self.cache = cache if meta is None: meta = {} self._meta = meta From fab35d92a739db299ef29d32ee8bc781137fae49 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Fri, 29 Aug 2025 09:21:49 -0400 Subject: [PATCH 4/7] Initial work on re-introducing span-ish KV caching. no-verify. --- mellea/backends/huggingface.py | 260 ++++++++++++++++++++++++++++++++- 1 file changed, 259 insertions(+), 1 deletion(-) diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index b3c4d09a..6a7204c0 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -26,7 +26,7 @@ set_seed, ) -from mellea.backends import BaseModelSubclass +from mellea.backends import BaseModelSubclass, kv_block_helpers from mellea.backends.aloras import Alora, AloraBackendMixin from mellea.backends.cache import Cache, SimpleLRUCache from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter @@ -272,6 +272,264 @@ def _generate_from_context_alora( ), ) + _cached_blocks = {} + _cached_toks = {} + + def _generate_from_context_with_kv_cache( + self, + action: Component | CBlock, + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict[str, Any] = {}, + generate_logs: list[GenerateLog] | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk: + # Construct input. + # If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template. + # Otherwise, we will linearize the context and treat it as a raw input. + decoded_result: str | None = None + if ctx.is_chat_context: + linearized_ctx = ctx.render_for_generation() + + assert linearized_ctx is not None, ( + "If ctx.is_chat_context, then the context should be linearizable." + ) + ctx_as_message_list: list[Message] = self.formatter.to_chat_messages( + linearized_ctx + ) + # add action + ctx_as_message_list.extend(self.formatter.to_chat_messages([action])) + + ctx_as_conversation = [ + {"role": m.role, "content": m.content} for m in ctx_as_message_list + ] + + # Check that we ddin't accidentally end up with CBlocks. + for msg in ctx_as_conversation: + for v in msg.values(): + if "CBlock" in v: + FancyLogger.get_logger().error( + f"Found the string `CBlock` in what should've been a stringified context: {ctx_as_conversation}" + ) + + # handle custom system prompts. 
It's important that we do this before the _parse_and_**clean**_model_options step. + system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, None) + if system_prompt is not None: + system_msg: dict[str, str] = { + "role": "system", + "content": system_prompt, + } + ctx_as_conversation.insert(0, system_msg) + + # Append tool call information if applicable. + tools: dict[str, Callable] = dict() + if tool_calls: + if format: + FancyLogger.get_logger().warning( + f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" + ) + else: + if isinstance(action, Component) and isinstance( + action.format_for_llm(), TemplateRepresentation + ): + tools = get_tools_from_action(action) + + model_options_tools = model_options.get(ModelOption.TOOLS, None) + if model_options_tools is not None: + assert isinstance(model_options_tools, dict) + for fn_name in model_options_tools: + # invariant re: relationship between the model_options set of tools and the TemplateRepresentation set of tools + assert fn_name not in tools.keys(), ( + f"Cannot add tool {fn_name} because that tool was already defined in the TemplateRepresentation for the action." + ) + # type checking because ModelOptions is an untyped dict and the calling convention for tools isn't clearly documented at our abstraction boundaries. + assert type(fn_name) is str, ( + "When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function." + ) + assert callable(model_options_tools[fn_name]), ( + "When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function." + ) + # Add the model_options tool to the existing set of tools. + tools[fn_name] = model_options_tools[fn_name] + + seed = model_options.get(ModelOption.SEED, None) + if seed is not None: + set_seed(seed) + + # Explanation for code blocks inside of use_kv_cache checks: + # 1. cache every CBlock that is marked with `cache=True` and store in _cached_blocks. + # 2. Mark each "hit" by adding the string (tokenized?) value to `cached_block_keys`. + # 3. apply the chat template (without?) tokenization + # 4. split on cache hits + # 5. prefill + smash together everything. + # 6. generate + + # 1. cache every CBlock that is marked with `cache=True` and store in _cached_blocks. + # AND + # 2. Mark each "hit" by adding the string (tokenized?) value to `cached_block_keys`. + cached_block_keys = [] + for c in linearized_ctx: + match c: + case CBlock() if c.cache: + if c.value not in self._cached_blocks: + FancyLogger.get_logger().info(f"Caching {hash(c.value)}") + tokens = self._tokenizer(c.value) + dc = DynamicCache() + with torch.no_grad(): + dc = self._model( + tokens["input_ids"].to(self._device), # type: ignore + attention_mask=tokens["attention_mask"].to( + self._device + ), # type: ignore + past_key_values=dc, + ).past_key_values + legacy_cache = dc.to_legacy_cache() + self._cached_blocks[c.value] = legacy_cache + self._cached_toks[c.value] = tokens + cached_block_keys.append(c.value) + case _: + continue + + # 3. apply the chat template without tokenization. 
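+            # Rendering to a plain string here (tokenize=False) and tokenizing the pieces ourselves is
+            # deliberate: a cached block must be tokenized exactly as it was when its KV entries were
+            # prefilled, so we split the rendered string on the cached values and only tokenize the
+            # text in between, reusing the stored tokens for the cached parts.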
+ input_text = self._tokenizer.apply_chat_template( # type: ignore + ctx_as_conversation, + tools=convert_tools_to_json(tools), # type: ignore + **self._make_backend_specific_and_remove(model_options), + tokenize=False, + ) + + # 4. split on cache hits + parts: list[str | tuple[DynamicCache, Any]] = [input_text] + for key in cached_block_keys: + next_split = parts.pop() + parts_split = next_split.split(key) + assert len(parts_split) == 2, ( + "Known issue: cached substring might occur more than once. We need to handle this situation earlier. Notice if this happens and keep a count." + ) + parts.append(parts_split[0]) + parts.append((self._cached_blocks[key], self._cached_toks[key])) + parts.append(parts_split[1]) + + # 5. prefill + smash together everything. + prefilled: Any | None = None + parts_tokens: list[Any] = [] + for part in parts: + if type(part) is str: + part_toks = self._tokenizer( + part, + return_tensors="pt", + **self._make_backend_specific_and_remove(model_options), + ) + parts_tokens.append(part_toks) + part_legacy_cache = kv_block_helpers.tokens_to_legacy_cache( + self._model, self._device, part_toks + ) + prefilled = ( + part_legacy_cache + if prefilled is None + else kv_block_helpers.legacy_cache_smash( + prefilled, part_legacy_cache + ) + ) + else: + parts_tokens.append(part[1]) + prefilled = ( + part[0] + if prefilled is None + else kv_block_helpers.legacy_cache_smash( + prefilled, part_legacy_cache + ) + ) + + # also smash together the tokens. + input_ids = torch.cat([toks["input_ids"] for toks in parts_tokens], dim=1) + + if format is None: + chat_output = self._model.generate( # type: ignore + input_ids, + return_dict_in_generate=True, + output_scores=True, + **self._make_backend_specific_and_remove(model_options), + ) # type: ignore + + else: + # outlines.generate.json always parses the resulting json into a python dict. + # We however want to keep it as a json string for later storing it in ModelOutputThunk + schema: dict[str, Any] = format.model_json_schema() + schema_json: str = json.dumps(schema) + regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( + schema_json + ) + + from outlines.models.transformers import TransformerTokenizer + from outlines.processors import RegexLogitsProcessor + from transformers import LogitsProcessorList + + chat_output = self._model.generate( # type: ignore + input_ids, + return_dict_in_generate=True, + output_scores=True, + logits_processor=LogitsProcessorList( + [ + RegexLogitsProcessor( + regex_str, + tokenizer=TransformerTokenizer(self._tokenizer), + ) + ] + ), + **self._make_backend_specific_and_remove(model_options), + ) + + decoded_result = self._tokenizer.decode( + chat_output.sequences[0, input_ids.shape[1] :], skip_special_tokens=True + ) + + # Add an entry to the cache for ALora reuse. + if self._use_caches: + output_complete = chat_output.sequences[0] + cache: DynamicCache = chat_output.past_key_values + + cache_info = HFAloraCacheInfo( + kv_cache=cache, + merged_token_ids=output_complete, + merged_attention=torch.ones_like(output_complete).to(self._device), + q_end=len(input_ids[0]), + ) + + assert decoded_result is not None + self.cache_put(decoded_result, cache_info) + else: + raise Exception("Does not yet support non-chat contexts.") + + assert decoded_result is not None + + result = ModelOutputThunk(value=decoded_result) + + # Only scan for tools if we are not doing structured decoding and tool calls were provided to the model. 
+        if format is None and tool_calls:
+            result.tool_calls = self._extract_model_tool_requests(tools, decoded_result)
+
+        parsed_result = self.formatter.parse(action, result)
+        if generate_logs is not None:
+            assert isinstance(generate_logs, list)
+            generate_log = GenerateLog()
+            generate_log.prompt = ctx_as_conversation
+            generate_log.backend = f"hf::{self.model_id!s}"
+            generate_log.model_options = model_options
+            generate_log.date = datetime.datetime.now()
+            generate_log.model_output = decoded_result
+            generate_log.extra = {
+                "format": format,
+                "tools_available": tools,
+                "tools_called": result.tool_calls,
+                "seed": seed,
+            }
+            generate_log.action = action
+            generate_log.result = parsed_result
+            generate_logs.append(generate_log)
+        return parsed_result
+
     def _generate_from_context_standard(
         self,
         action: Component | CBlock,

From a648405d1ae9b214d88401cad1f80cdd253feea2 Mon Sep 17 00:00:00 2001
From: Nathan Fulton
Date: Fri, 29 Aug 2025 13:34:33 -0400
Subject: [PATCH 5/7] Adds a crystallization of the kv smash code

And the way it fits into a model that uses apply_chat_template or any other
parser/renderer.

Note that there is still a known bug: other substrings of the rendered prompt
may also "hit" on the cached contents. We don't anticipate this happening
often in practice because of how KV cache smashing should typically be used,
but it is something we need to address by introducing sentinel values,
string-indexing machinery, or something else along those lines.

no-verify commit because the point of this code is documentation.
---
 docs/kv_smash/kv_with_chat.py | 110 ++++++++++++++++++++++++++++++++++
 docs/kv_smash/kvcache.py      |  55 +++++++++++++++++
 2 files changed, 165 insertions(+)
 create mode 100644 docs/kv_smash/kv_with_chat.py
 create mode 100644 docs/kv_smash/kvcache.py

diff --git a/docs/kv_smash/kv_with_chat.py b/docs/kv_smash/kv_with_chat.py
new file mode 100644
index 00000000..f5a249c8
--- /dev/null
+++ b/docs/kv_smash/kv_with_chat.py
@@ -0,0 +1,110 @@
+import torch
+
+from mellea.backends.huggingface import LocalHFBackend
+from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches
+from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
+from mellea.stdlib.base import CBlock, LinearContext
+from mellea.stdlib.chat import Message
+
+backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B)
+
+model = backend._model
+tokenizer = backend._tokenizer
+device = backend._device
+
+
+KV_CACHE: dict[str, DynamicCache] = dict()
+
+
+def cache(s: str, store=True) -> DynamicCache:
+    toks = tokenizer(s, return_tensors="pt")
+    dc = DynamicCache()
+    with torch.no_grad():
+        rv = model(
+            toks["input_ids"].to(device),
+            attention_mask=toks["attention_mask"].to(device),
+            past_key_values=dc,
+        ).past_key_values
+    # Only store in the global cache when asked to; `store=False` is used for
+    # throwaway prefills of uncached text between cache hits.
+    if store:
+        KV_CACHE[s] = rv
+    return rv
+
+
+def merge(toks, dcs):
+    merged_toks = torch.cat([t["input_ids"] for t in toks], dim=1)
+    merged_masks = torch.cat([t["attention_mask"] for t in toks], dim=1)
+    merged_dcs = merge_dynamic_caches(dcs)
+
+    return merged_toks, merged_masks, merged_dcs
+
+
+c_blocks = ["this is a test", "this is another test"]
+
+# pretend this stuff already existed in the cache.
+for cb in c_blocks:
+    cache(cb)
+
+
+# apply the chat template to a conversation that contains these strings, but without tokenization.
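+# After rendering, the code below walks the rendered string in order, splitting on each cached
+# content: the text between cache hits (chat-template markup plus the uncached messages) is
+# tokenized and prefilled on the fly, while each cached content reuses the DynamicCache computed
+# above.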
+messages = [
+    {"role": "user", "content": c_blocks[0]},
+    {"role": "user", "content": "Not cached"},
+    {"role": "user", "content": c_blocks[1]},
+    {"role": "user", "content": "Also no cash"},
+]
+templatized_input = tokenizer.apply_chat_template(conversation=messages, tokenize=False)
+
+str_parts = []
+tok_parts = []
+dc_parts = []
+
+current_suffix = templatized_input
+partially_cached_templatized_input: list[str | DynamicCache] = []
+for cb in c_blocks:
+    parts = current_suffix.split(cb)
+    assert len(parts) == 2
+    prefix, next_suffix = parts
+
+    if prefix != "":
+        # Add the prefix.
+        str_parts.append(prefix)
+        # Add the tokens and attention mask for the prefix.
+        tok_parts.append(tokenizer(prefix, return_tensors="pt"))
+        # Add the dynamic cache for the prefix.
+        dc_parts.append(cache(prefix, store=False))
+
+    # Add cb itself.
+    str_parts.append(cb)
+    tok_parts.append(tokenizer(cb, return_tensors="pt"))
+    dc_parts.append(KV_CACHE[cb])
+
+    # set the current suffix.
+    current_suffix = next_suffix
+
+# REMEMBER: add the final suffix.
+if current_suffix != "":
+    str_parts.append(current_suffix)
+    tok_parts.append(tokenizer(current_suffix, return_tensors="pt"))
+    dc_parts.append(cache(current_suffix, store=False))
+
+# Merge everything together.
+merged_toks = torch.cat([toks["input_ids"] for toks in tok_parts], dim=1)
+merged_masks = torch.cat([toks["attention_mask"] for toks in tok_parts], dim=1)
+merged_dcs = merge_dynamic_caches(dc_parts)
+
+# crop the last KV for safety.
+merged_dcs.crop(-1)
+
+# generate and print result.
+result = model.generate(
+    merged_toks.to(device),
+    attention_mask=merged_masks.to(device),
+    past_key_values=merged_dcs,
+    use_cache=True,
+    return_dict_in_generate=True,
+    output_scores=True,
+)
+
+result_decoded = tokenizer.decode(
+    result.sequences[0, merged_toks.shape[1] :], skip_special_tokens=True
+)
+print(result_decoded)
diff --git a/docs/kv_smash/kvcache.py b/docs/kv_smash/kvcache.py
new file mode 100644
index 00000000..15b4d9b7
--- /dev/null
+++ b/docs/kv_smash/kvcache.py
@@ -0,0 +1,55 @@
+import torch
+
+from mellea.backends.huggingface import LocalHFBackend
+from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches
+from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
+from mellea.stdlib.base import CBlock, LinearContext
+from mellea.stdlib.chat import Message
+
+backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B)
+
+model = backend._model
+tokenizer = backend._tokenizer
+device = backend._device
+
+
+def cache(toks) -> DynamicCache:
+    dc = DynamicCache()
+    with torch.no_grad():
+        rv = model(
+            toks["input_ids"].to(device),
+            attention_mask=toks["attention_mask"].to(device),
+            past_key_values=dc,
+        ).past_key_values
+    return rv
+
+
+def merge(strs: list[str]):
+    strs_toks = [tokenizer(x, return_tensors="pt") for x in strs]
+    strs_dcs = [cache(toks) for toks in strs_toks]
+
+    merged_toks = torch.cat([toks["input_ids"] for toks in strs_toks], dim=1)
+    merged_masks = torch.cat([toks["attention_mask"] for toks in strs_toks], dim=1)
+    merged_dcs = merge_dynamic_caches(strs_dcs)
+
+    return merged_toks, merged_masks, merged_dcs
+
+
+strs = ["this is a test", "this is another test"]
+
+merged_toks, merged_masks, merged_dcs = merge(strs)
+merged_dcs.crop(-1)
+
+result = model.generate(
+    merged_toks.to(device),
+    attention_mask=merged_masks.to(device),
+    past_key_values=merged_dcs,
+    use_cache=True,
+    return_dict_in_generate=True,
+    output_scores=True,
+)
+
+result_decoded = tokenizer.decode(
+    result.sequences[0, merged_toks.shape[1]
:], skip_special_tokens=True +) +print(result_decoded) From ead3fe837cba475d43c8cb846dd1addc6523ac07 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Fri, 29 Aug 2025 14:22:13 -0400 Subject: [PATCH 6/7] Adds KV cache smash. --- mellea/backends/huggingface.py | 141 +++++++++++++++++++++------------ 1 file changed, 89 insertions(+), 52 deletions(-) diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 6a7204c0..9c81142d 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -272,10 +272,20 @@ def _generate_from_context_alora( ), ) - _cached_blocks = {} - _cached_toks = {} - - def _generate_from_context_with_kv_cache( + _cached_blocks: dict[str, DynamicCache] = dict() + + def _make_dc_cache(self, toks, **model_options): + dc = DynamicCache() + with torch.no_grad(): + dc = self._model( + toks["input_ids"].to(self._device), + attention_mask=toks["attention_mask"].to(self._device), + past_key_values=dc, + **model_options, + ).past_key_values + return dc + + def _generate_from_context_with_kv_cache( # noqa: C901 self, action: Component | CBlock, ctx: Context, @@ -372,9 +382,16 @@ def _generate_from_context_with_kv_cache( for c in linearized_ctx: match c: case CBlock() if c.cache: - if c.value not in self._cached_blocks: - FancyLogger.get_logger().info(f"Caching {hash(c.value)}") - tokens = self._tokenizer(c.value) + assert c.value is not None + if c.value in self._cached_blocks: + FancyLogger.get_logger().info( + f"KV CACHE HIT for: {hash(c.value)} ({c.value[:3]}..{c.value[-3:]})" # type: ignore + ) + else: + FancyLogger.get_logger().debug( + f"HF backend is caching a CBlock with hashed contents: {hash(c.value)} ({c.value[:3]}..{c.value[-3:]})" + ) + tokens = self._tokenizer(c.value, return_tensors="pt") dc = DynamicCache() with torch.no_grad(): dc = self._model( @@ -383,15 +400,16 @@ def _generate_from_context_with_kv_cache( self._device ), # type: ignore past_key_values=dc, + use_cache=True, ).past_key_values - legacy_cache = dc.to_legacy_cache() - self._cached_blocks[c.value] = legacy_cache - self._cached_toks[c.value] = tokens + self._cached_blocks[c.value] = dc cached_block_keys.append(c.value) case _: continue - # 3. apply the chat template without tokenization. + # 3. apply the chat template WITHOUT tokenization. + # Doing this without tokenization and then gluing together the tokens is necessary because + # things that KV cache together must tokenize together. input_text = self._tokenizer.apply_chat_template( # type: ignore ctx_as_conversation, tools=convert_tools_to_json(tools), # type: ignore @@ -399,61 +417,80 @@ def _generate_from_context_with_kv_cache( tokenize=False, ) - # 4. split on cache hits - parts: list[str | tuple[DynamicCache, Any]] = [input_text] + # 4. split the input_text back up again, re-using DC where it exists. + str_parts = [] + tok_parts = [] + dc_parts = [] + current_suffix = input_text for key in cached_block_keys: - next_split = parts.pop() - parts_split = next_split.split(key) - assert len(parts_split) == 2, ( + assert key is not None, ( + "Some input CBlock must not have bee ncomputed yet? The error comes far before this line." + ) + assert key in current_suffix, ( + "Could happen but would be rare. related to the other assert in this block." + ) + parts = current_suffix.split(key) # type: ignore + assert len(parts) == 2, ( "Known issue: cached substring might occur more than once. We need to handle this situation earlier. Notice if this happens and keep a count." 
) - parts.append(parts_split[0]) - parts.append((self._cached_blocks[key], self._cached_toks[key])) - parts.append(parts_split[1]) - - # 5. prefill + smash together everything. - prefilled: Any | None = None - parts_tokens: list[Any] = [] - for part in parts: - if type(part) is str: - part_toks = self._tokenizer( - part, - return_tensors="pt", - **self._make_backend_specific_and_remove(model_options), - ) - parts_tokens.append(part_toks) - part_legacy_cache = kv_block_helpers.tokens_to_legacy_cache( - self._model, self._device, part_toks - ) - prefilled = ( - part_legacy_cache - if prefilled is None - else kv_block_helpers.legacy_cache_smash( - prefilled, part_legacy_cache - ) - ) - else: - parts_tokens.append(part[1]) - prefilled = ( - part[0] - if prefilled is None - else kv_block_helpers.legacy_cache_smash( - prefilled, part_legacy_cache - ) + prefix, suffix = parts + # Add the prefix, if any, to str+tok+dc parts. + if prefix != "": + FancyLogger.get_logger().debug( + f"Doing a forward pass on uncached block which is prefix to a cached CBlock: {prefix[:3]}.{len(prefix)}.{prefix[-3:]}" ) + str_parts.append(prefix) + tok_parts.append(self._tokenizer(prefix, return_tensors="pt")) + dc_parts.append(self._make_dc_cache(tok_parts[-1])) + # Add the cached CBlock to str+tok+dc parts. + FancyLogger.get_logger().debug( + f"Replacing a substring with previously computed/retrieved cache with hahs value {hash(key)} ({key[:3]}..{key[-3:]})" + ) + # str_parts.append(key) + # tok_parts.append(self._tokenizer(key, return_tensors="pt")) + # dc_parts.append(self._make_dc_cache(tok_parts[-1])) # TODO this is wrong. + str_parts.append(key) + tok_parts.append(self._tokenizer(key, return_tensors="pt")) + dc_parts.append(self._cached_blocks[key]) + # set the suffix for the next loop iteration. + current_suffix = suffix + # "base" case: the final suffix. + if current_suffix != "": + FancyLogger.get_logger().debug( # type: ignore + f"Doing a forward pass on final suffix, an uncached block: {current_suffix[:3]}.{len(current_suffix)}.{current_suffix[-3:]}" # type: ignore + ) # type: ignore + str_parts.append(current_suffix) + tok_parts.append(self._tokenizer(current_suffix, return_tensors="pt")) + dc_parts.append(self._make_dc_cache(tok_parts[-1])) - # also smash together the tokens. - input_ids = torch.cat([toks["input_ids"] for toks in parts_tokens], dim=1) + # Smash together the caches, the input_ids, and the attention masks. + assert "".join(str_parts) == input_text, ( + "Should've ended up with the same input text!" + ) + input_ids = torch.cat([toks["input_ids"] for toks in tok_parts], dim=1) + attention_mask = torch.cat( + [toks["attention_mask"] for toks in tok_parts], dim=1 + ) + assert input_ids.shape == attention_mask.shape + merged_cache: DynamicCache = kv_block_helpers.merge_dynamic_caches(dc_parts) + # TODO: also assert that the merged cached is the correct shape given the input_ids and attention_mask shapes. + + # rewind merged cache by 1 for safety. + merged_cache.crop(-1) if format is None: chat_output = self._model.generate( # type: ignore - input_ids, + input_ids.to(self._device), + attention_mask=attention_mask.to(self._device), + use_cache=True, + past_key_values=merged_cache, return_dict_in_generate=True, output_scores=True, **self._make_backend_specific_and_remove(model_options), ) # type: ignore else: + raise NotImplementedError("Copy implementation from above.") # outlines.generate.json always parses the resulting json into a python dict. 
# We however want to keep it as a json string for later storing it in ModelOutputThunk schema: dict[str, Any] = format.model_json_schema() From 1cd08ae523bd7367adcb93232392885f000d8af9 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Fri, 29 Aug 2025 14:22:57 -0400 Subject: [PATCH 7/7] Adds example of kv cache smash. --- docs/kv_smash/hf_example.py | 39 +++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 docs/kv_smash/hf_example.py diff --git a/docs/kv_smash/hf_example.py b/docs/kv_smash/hf_example.py new file mode 100644 index 00000000..8b9706b1 --- /dev/null +++ b/docs/kv_smash/hf_example.py @@ -0,0 +1,39 @@ +from mellea.backends.huggingface import LocalHFBackend +from mellea.backends.model_ids import IBM_GRANITE_3_3_8B +from mellea.backends.types import ModelOption +from mellea.stdlib.base import CBlock, LinearContext +from mellea.stdlib.chat import Message + +ctx = LinearContext(window_size=100) +ctx.insert( + CBlock( + "Nathan Fulton is a Senior Research Scientist at the MIT-IBM Watson AI Lab, a joint venture between MIT and IBM.", + cache=True, + ) +) +ctx.insert( + CBlock( + "The MIT-IBM Watson AI Lab is located at 314 Main St, Cambridge, Massachusetts.", + cache=True, + ) +) +ctx.insert(CBlock("The ZIP code for 314 Main St, Cambridge, Massachusetts is 02142")) + + +msg = Message( + role="user", content="What is the likely ZIP code of Nathan Fulton's work address." +) +backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B) +result = backend._generate_from_context_with_kv_cache( + action=msg, ctx=ctx, model_options={ModelOption.MAX_NEW_TOKENS: 1000} +) +print(f".{result}.") + +msg2 = Message( + role="user", + content="We know that Nathan does not work for a university. What is the likely name of Nathan's employer?", +) +result = backend._generate_from_context_with_kv_cache( + action=msg2, ctx=ctx, model_options={ModelOption.MAX_NEW_TOKENS: 1000} +) +print(f".{result}.")
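The repeated-substring hazard called out in PATCH 5 (a cached value that also occurs elsewhere in
the rendered prompt) could be sidestepped by rendering the chat template over unique sentinels and
splitting on those instead of on the raw strings. The sketch below only illustrates that idea and
is not part of the patches above: split_with_sentinels and the toy render callable are
hypothetical, and the recovered pieces must still concatenate to exactly the prompt text that would
have been rendered directly.

import uuid


def split_with_sentinels(render, cached_values):
    # Give every cached value a unique, collision-proof placeholder.
    sentinels = {value: f"<<KV:{uuid.uuid4().hex}>>" for value in cached_values}
    # Render the prompt with placeholders standing in for the cached strings.
    rendered = render(sentinels)
    # Walk the rendered text, emitting ("text", ...) and ("cached", ...) parts in order.
    parts: list[tuple[str, str]] = []
    rest = rendered
    while rest:
        hits = [(rest.find(s), v, s) for v, s in sentinels.items() if s in rest]
        if not hits:
            parts.append(("text", rest))
            break
        idx, value, sentinel = min(hits)
        if idx > 0:
            parts.append(("text", rest[:idx]))
        parts.append(("cached", value))
        rest = rest[idx + len(sentinel) :]
    return parts


# Toy usage: plain string substitution stands in for apply_chat_template here.
cached = ["this is a test", "this is another test"]


def render(sub):
    return f"<user> {sub[cached[0]]} </user><user> not cached </user><user> {sub[cached[1]]} </user>"


print(split_with_sentinels(render, cached))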