
Commit fab35d9

Initial work on re-introducing span-ish KV caching.
no-verify.
1 parent 5989664 commit fab35d9

File tree

1 file changed: +259 -1 lines changed


mellea/backends/huggingface.py

Lines changed: 259 additions & 1 deletion
@@ -26,7 +26,7 @@
     set_seed,
 )
 
-from mellea.backends import BaseModelSubclass
+from mellea.backends import BaseModelSubclass, kv_block_helpers
 from mellea.backends.aloras import Alora, AloraBackendMixin
 from mellea.backends.cache import Cache, SimpleLRUCache
 from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
@@ -272,6 +272,264 @@ def _generate_from_context_alora(
             ),
         )
 
+    _cached_blocks = {}
+    _cached_toks = {}
+
+    def _generate_from_context_with_kv_cache(
+        self,
+        action: Component | CBlock,
+        ctx: Context,
+        *,
+        format: type[BaseModelSubclass] | None = None,
+        model_options: dict[str, Any] = {},
+        generate_logs: list[GenerateLog] | None = None,
+        tool_calls: bool = False,
+    ) -> ModelOutputThunk:
+        # Construct the input.
+        # If the Context is a ChatHistory, we pretty-print each content as a message and then use apply_chat_template.
+        # Otherwise, we linearize the context and treat it as a raw input.
+        decoded_result: str | None = None
+        if ctx.is_chat_context:
+            linearized_ctx = ctx.render_for_generation()
+            assert linearized_ctx is not None, (
+                "If ctx.is_chat_context, then the context should be linearizable."
+            )
+            ctx_as_message_list: list[Message] = self.formatter.to_chat_messages(
+                linearized_ctx
+            )
+            # Add the action.
+            ctx_as_message_list.extend(self.formatter.to_chat_messages([action]))
+
+            ctx_as_conversation = [
+                {"role": m.role, "content": m.content} for m in ctx_as_message_list
+            ]
+
+            # Check that we didn't accidentally end up with CBlocks.
+            for msg in ctx_as_conversation:
+                for v in msg.values():
+                    if "CBlock" in v:
+                        FancyLogger.get_logger().error(
+                            f"Found the string `CBlock` in what should've been a stringified context: {ctx_as_conversation}"
+                        )
+
+            # Handle custom system prompts. It's important that we do this before the _parse_and_**clean**_model_options step.
+            system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, None)
+            if system_prompt is not None:
+                system_msg: dict[str, str] = {
+                    "role": "system",
+                    "content": system_prompt,
+                }
+                ctx_as_conversation.insert(0, system_msg)
+
+            # Append tool call information if applicable.
+            tools: dict[str, Callable] = dict()
+            if tool_calls:
+                if format:
+                    FancyLogger.get_logger().warning(
+                        f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
+                    )
+                else:
+                    if isinstance(action, Component) and isinstance(
+                        action.format_for_llm(), TemplateRepresentation
+                    ):
+                        tools = get_tools_from_action(action)
+
+                    model_options_tools = model_options.get(ModelOption.TOOLS, None)
+                    if model_options_tools is not None:
+                        assert isinstance(model_options_tools, dict)
+                        for fn_name in model_options_tools:
+                            # Invariant re: the relationship between the model_options set of tools and the TemplateRepresentation set of tools.
+                            assert fn_name not in tools.keys(), (
+                                f"Cannot add tool {fn_name} because that tool was already defined in the TemplateRepresentation for the action."
+                            )
+                            # Type checks, because ModelOptions is an untyped dict and the calling convention for tools isn't clearly documented at our abstraction boundaries.
+                            assert type(fn_name) is str, (
+                                "When providing a `ModelOption.TOOLS` parameter in `model_options`, always use the type Dict[str, Callable], where `str` is the function name and the callable is the function."
+                            )
+                            assert callable(model_options_tools[fn_name]), (
+                                "When providing a `ModelOption.TOOLS` parameter in `model_options`, always use the type Dict[str, Callable], where `str` is the function name and the callable is the function."
+                            )
+                            # Add the model_options tool to the existing set of tools.
+                            tools[fn_name] = model_options_tools[fn_name]
+
+            seed = model_options.get(ModelOption.SEED, None)
+            if seed is not None:
+                set_seed(seed)
+
+            # Overview of the KV-cache steps below:
+            # 1. Cache every CBlock that is marked with `cache=True` and store it in _cached_blocks.
+            # 2. Mark each "hit" by adding the block's string value to `cached_block_keys`.
+            # 3. Apply the chat template without tokenization.
+            # 4. Split on cache hits.
+            # 5. Prefill and smash everything together.
+            # 6. Generate.
+
+            # Steps 1 and 2: cache every CBlock that is marked with `cache=True`,
+            # and record each hit in `cached_block_keys`.
+            cached_block_keys = []
+            for c in linearized_ctx:
+                match c:
+                    case CBlock() if c.cache:
+                        if c.value not in self._cached_blocks:
+                            FancyLogger.get_logger().info(f"Caching {hash(c.value)}")
+                            tokens = self._tokenizer(c.value, return_tensors="pt")
+                            dc = DynamicCache()
+                            with torch.no_grad():
+                                dc = self._model(
+                                    tokens["input_ids"].to(self._device),  # type: ignore
+                                    attention_mask=tokens["attention_mask"].to(
+                                        self._device
+                                    ),  # type: ignore
+                                    past_key_values=dc,
+                                ).past_key_values
+                            legacy_cache = dc.to_legacy_cache()
+                            self._cached_blocks[c.value] = legacy_cache
+                            self._cached_toks[c.value] = tokens
+                        cached_block_keys.append(c.value)
+                    case _:
+                        continue
+
+            # 3. Apply the chat template without tokenization.
+            input_text = self._tokenizer.apply_chat_template(  # type: ignore
+                ctx_as_conversation,
+                tools=convert_tools_to_json(tools),  # type: ignore
+                **self._make_backend_specific_and_remove(model_options),
+                tokenize=False,
+            )
+
+            # 4. Split on cache hits.
+            parts: list[str | tuple[Any, Any]] = [input_text]
+            for key in cached_block_keys:
+                next_split = parts.pop()
+                parts_split = next_split.split(key)
+                assert len(parts_split) == 2, (
+                    "Known issue: a cached substring might occur more than once. We need to handle this situation earlier. Notice if this happens and keep a count."
+                )
+                parts.append(parts_split[0])
+                parts.append((self._cached_blocks[key], self._cached_toks[key]))
+                parts.append(parts_split[1])
+
+            # 5. Prefill and smash everything together.
+            prefilled: Any | None = None
+            parts_tokens: list[Any] = []
+            for part in parts:
+                if type(part) is str:
+                    # Uncached span: tokenize and prefill it now.
+                    part_toks = self._tokenizer(
+                        part,
+                        return_tensors="pt",
+                        **self._make_backend_specific_and_remove(model_options),
+                    )
+                    parts_tokens.append(part_toks)
+                    part_legacy_cache = kv_block_helpers.tokens_to_legacy_cache(
+                        self._model, self._device, part_toks
+                    )
+                    prefilled = (
+                        part_legacy_cache
+                        if prefilled is None
+                        else kv_block_helpers.legacy_cache_smash(
+                            prefilled, part_legacy_cache
+                        )
+                    )
+                else:
+                    # Cached span: reuse its stored (cache, tokens) pair.
+                    parts_tokens.append(part[1])
+                    prefilled = (
+                        part[0]
+                        if prefilled is None
+                        else kv_block_helpers.legacy_cache_smash(prefilled, part[0])
+                    )
+
+            # Also smash together the tokens.
+            input_ids = torch.cat([toks["input_ids"] for toks in parts_tokens], dim=1)
+
+            # 6. Generate.
+            if format is None:
+                chat_output = self._model.generate(  # type: ignore
+                    input_ids,
+                    return_dict_in_generate=True,
+                    output_scores=True,
+                    **self._make_backend_specific_and_remove(model_options),
+                )
+            else:
+                # outlines.generate.json always parses the resulting json into a python dict.
+                # We, however, want to keep it as a json string for later storage in the ModelOutputThunk.
+                schema: dict[str, Any] = format.model_json_schema()
+                schema_json: str = json.dumps(schema)
+                regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
+                    schema_json
+                )
+
+                from outlines.models.transformers import TransformerTokenizer
+                from outlines.processors import RegexLogitsProcessor
+                from transformers import LogitsProcessorList
+
+                chat_output = self._model.generate(  # type: ignore
+                    input_ids,
+                    return_dict_in_generate=True,
+                    output_scores=True,
+                    logits_processor=LogitsProcessorList(
+                        [
+                            RegexLogitsProcessor(
+                                regex_str,
+                                tokenizer=TransformerTokenizer(self._tokenizer),
+                            )
+                        ]
+                    ),
+                    **self._make_backend_specific_and_remove(model_options),
+                )
+
+            decoded_result = self._tokenizer.decode(
+                chat_output.sequences[0, input_ids.shape[1] :], skip_special_tokens=True
+            )
+
+            # Add an entry to the cache for ALora reuse.
+            if self._use_caches:
+                output_complete = chat_output.sequences[0]
+                cache: DynamicCache = chat_output.past_key_values
+
+                cache_info = HFAloraCacheInfo(
+                    kv_cache=cache,
+                    merged_token_ids=output_complete,
+                    merged_attention=torch.ones_like(output_complete).to(self._device),
+                    q_end=len(input_ids[0]),
+                )
+
+                assert decoded_result is not None
+                self.cache_put(decoded_result, cache_info)
+        else:
+            raise Exception("Does not yet support non-chat contexts.")
+
+        assert decoded_result is not None
+
+        result = ModelOutputThunk(value=decoded_result)
+
+        # Only scan for tools if we are not doing structured decoding and tool calls were provided to the model.
+        if format is None and tool_calls:
+            result.tool_calls = self._extract_model_tool_requests(tools, decoded_result)
+
+        parsed_result = self.formatter.parse(action, result)
+        if generate_logs is not None:
+            assert isinstance(generate_logs, list)
+            generate_log = GenerateLog()
+            generate_log.prompt = ctx_as_conversation
+            generate_log.backend = f"hf::{self.model_id!s}"
+            generate_log.model_options = model_options
+            generate_log.date = datetime.datetime.now()
+            generate_log.model_output = decoded_result
+            generate_log.extra = {
+                "format": format,
+                "tools_available": tools,
+                "tools_called": result.tool_calls,
+                "seed": seed,
+            }
+            generate_log.action = action
+            generate_log.result = parsed_result
+            generate_logs.append(generate_log)
+        return parsed_result
+
     def _generate_from_context_standard(
         self,
         action: Component | CBlock,
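
Note: `kv_block_helpers` is imported above but its source is not part of this diff. The following is a minimal sketch of what the two helpers used here plausibly look like, assuming Transformers' legacy cache layout (a tuple over layers of (key, value) tensors, each shaped [batch, n_heads, seq_len, head_dim]); the function bodies are illustrative, not the committed implementation.

    import torch
    from transformers import DynamicCache

    def tokens_to_legacy_cache(model, device, toks):
        # Prefill one span: a single no-grad forward pass, exported as a
        # legacy tuple-of-(key, value)-per-layer cache.
        with torch.no_grad():
            out = model(
                toks["input_ids"].to(device),
                attention_mask=toks["attention_mask"].to(device),
                past_key_values=DynamicCache(),
            )
        return out.past_key_values.to_legacy_cache()

    def legacy_cache_smash(left, right):
        # Concatenate two legacy caches along the sequence axis (dim=2).
        # NB: a span prefilled in isolation was attended (and position-encoded)
        # as if it started at position 0, so smashing is an approximation of a
        # true joint prefill.
        return tuple(
            (torch.cat([lk, rk], dim=2), torch.cat([lv, rv], dim=2))
            for (lk, lv), (rk, rv) in zip(left, right)
        )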

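For intuition, step 4 ("split on cache hits") reduces the fully templated prompt to an alternating list of raw-text spans, which still need prefill, and already-cached spans. A toy run with made-up values (the prompt and placeholders below are illustrative only):

    prompt = "<|user|> Summarize: LONG_DOCUMENT Thanks! <|assistant|>"
    cached_block_keys = ["LONG_DOCUMENT"]  # values of CBlocks marked cache=True

    parts = [prompt]
    for key in cached_block_keys:
        # The assert in step 4 enforces exactly one occurrence per key.
        head, tail = parts.pop().split(key)
        parts += [head, ("<kv cache>", "<tokens>"), tail]

    # parts == ["<|user|> Summarize: ", ("<kv cache>", "<tokens>"), " Thanks! <|assistant|>"],
    # so step 5 only has to prefill the short raw-text pieces.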
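As committed, `prefilled` is built in step 5 but never handed to `generate`, so step 6 still prefills the whole prompt from scratch; presumably a later commit wires it in. One hypothetical way to do that, as a drop-in for the step-6 call (backend-specific kwargs elided; the cache is cropped by one token so `generate` still has an input position left to process):

    from transformers import DynamicCache

    cache = DynamicCache.from_legacy_cache(prefilled)
    cache.crop(input_ids.shape[1] - 1)  # leave the last input token uncached
    chat_output = self._model.generate(
        input_ids,
        past_key_values=cache,
        attention_mask=torch.ones_like(input_ids),
        return_dict_in_generate=True,
        output_scores=True,
    )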