|
26 | 26 | set_seed, |
27 | 27 | ) |
28 | 28 |
|
29 | | -from mellea.backends import BaseModelSubclass |
| 29 | +from mellea.backends import BaseModelSubclass, kv_block_helpers |
30 | 30 | from mellea.backends.aloras import Alora, AloraBackendMixin |
31 | 31 | from mellea.backends.cache import Cache, SimpleLRUCache |
32 | 32 | from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter |
@@ -272,6 +272,264 @@ def _generate_from_context_alora( |
272 | 272 | ), |
273 | 273 | ) |
274 | 274 |
|
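| | + # Class-level caches shared across backend instances, keyed by a CBlock's string value:
| | + # _cached_blocks holds that block's prefilled KV cache (legacy tuple format) and
| | + # _cached_toks holds its tokenization.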
| 275 | + _cached_blocks = {} |
| 276 | + _cached_toks = {} |
| 277 | + |
| 278 | + def _generate_from_context_with_kv_cache( |
| 279 | + self, |
| 280 | + action: Component | CBlock, |
| 281 | + ctx: Context, |
| 282 | + *, |
| 283 | + format: type[BaseModelSubclass] | None = None, |
| 284 | + model_options: dict[str, Any] = {}, |
| 285 | + generate_logs: list[GenerateLog] | None = None, |
| 286 | + tool_calls: bool = False, |
| 287 | + ) -> ModelOutputThunk: |
| 288 | + # Construct input. |
| 289 | + # If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template. |
| 290 | + # Otherwise, we will linearize the context and treat it as a raw input. |
| 291 | + decoded_result: str | None = None |
| 292 | + if ctx.is_chat_context: |
| 293 | + linearized_ctx = ctx.render_for_generation() |
| 294 | + |
| 295 | + assert linearized_ctx is not None, ( |
| 296 | + "If ctx.is_chat_context, then the context should be linearizable." |
| 297 | + ) |
| 298 | + ctx_as_message_list: list[Message] = self.formatter.to_chat_messages( |
| 299 | + linearized_ctx |
| 300 | + ) |
| 301 | + # add action |
| 302 | + ctx_as_message_list.extend(self.formatter.to_chat_messages([action])) |
| 303 | + |
| 304 | + ctx_as_conversation = [ |
| 305 | + {"role": m.role, "content": m.content} for m in ctx_as_message_list |
| 306 | + ] |
| 307 | + |
| 308 | + # Check that we didn't accidentally end up with CBlocks.
| 309 | + for msg in ctx_as_conversation: |
| 310 | + for v in msg.values(): |
| 311 | + if "CBlock" in v: |
| 312 | + FancyLogger.get_logger().error( |
| 313 | + f"Found the string `CBlock` in what should've been a stringified context: {ctx_as_conversation}" |
| 314 | + ) |
| 315 | + |
| 316 | + # handle custom system prompts. It's important that we do this before the _parse_and_**clean**_model_options step. |
| 317 | + system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, None) |
| 318 | + if system_prompt is not None: |
| 319 | + system_msg: dict[str, str] = { |
| 320 | + "role": "system", |
| 321 | + "content": system_prompt, |
| 322 | + } |
| 323 | + ctx_as_conversation.insert(0, system_msg) |
| 324 | + |
| 325 | + # Append tool call information if applicable. |
| 326 | + tools: dict[str, Callable] = dict() |
| 327 | + if tool_calls: |
| 328 | + if format: |
| 329 | + FancyLogger.get_logger().warning( |
| 330 | + f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" |
| 331 | + ) |
| 332 | + else: |
| 333 | + if isinstance(action, Component) and isinstance( |
| 334 | + action.format_for_llm(), TemplateRepresentation |
| 335 | + ): |
| 336 | + tools = get_tools_from_action(action) |
| 337 | + |
| 338 | + model_options_tools = model_options.get(ModelOption.TOOLS, None) |
| 339 | + if model_options_tools is not None: |
| 340 | + assert isinstance(model_options_tools, dict) |
| 341 | + for fn_name in model_options_tools: |
| 342 | + # invariant re: relationship between the model_options set of tools and the TemplateRepresentation set of tools |
| 343 | + assert fn_name not in tools.keys(), ( |
| 344 | + f"Cannot add tool {fn_name} because that tool was already defined in the TemplateRepresentation for the action." |
| 345 | + ) |
| 346 | + # type checking because ModelOptions is an untyped dict and the calling convention for tools isn't clearly documented at our abstraction boundaries. |
| 347 | + assert type(fn_name) is str, ( |
| 348 | + "When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function." |
| 349 | + ) |
| 350 | + assert callable(model_options_tools[fn_name]), ( |
| 351 | + "When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function." |
| 352 | + ) |
| 353 | + # Add the model_options tool to the existing set of tools. |
| 354 | + tools[fn_name] = model_options_tools[fn_name] |
| 355 | + |
| 356 | + seed = model_options.get(ModelOption.SEED, None) |
| 357 | + if seed is not None: |
| 358 | + set_seed(seed) |
| 359 | + |
| 360 | + # Explanation for the code blocks inside the use_kv_cache checks:
| 361 | + # 1. Prefill and cache every CBlock marked with `cache=True`, storing the result in _cached_blocks.
| 362 | + # 2. Record each cache hit by adding the block's string value to `cached_block_keys`.
| 363 | + # 3. Apply the chat template without tokenization.
| 364 | + # 4. Split the templated prompt on the cache hits.
| 365 | + # 5. Prefill the remaining text segments and smash everything together.
| 366 | + # 6. Generate.
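| | + # Illustrative example (hypothetical): for a context [doc (cache=True), question], step 1
| | + # prefills `doc` once; step 4 splits the templated prompt text around doc's string value;
| | + # step 5 splices doc's stored KV cache between the prefills of the surrounding text.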
| 367 | + |
| 368 | + # 1. Prefill and cache every CBlock marked with `cache=True`, storing the result in _cached_blocks.
| 369 | + # AND
| 370 | + # 2. Record each cache hit by adding the block's string value to `cached_block_keys`.
| 371 | + cached_block_keys = [] |
| 372 | + for c in linearized_ctx: |
| 373 | + match c: |
| 374 | + case CBlock() if c.cache: |
| 375 | + if c.value not in self._cached_blocks: |
| 376 | + FancyLogger.get_logger().info(f"Caching {hash(c.value)}") |
| 377 | + tokens = self._tokenizer(c.value, return_tensors="pt")
| 378 | + dc = DynamicCache() |
| 379 | + with torch.no_grad(): |
| 380 | + dc = self._model( |
| 381 | + tokens["input_ids"].to(self._device), # type: ignore |
| 382 | + attention_mask=tokens["attention_mask"].to( |
| 383 | + self._device |
| 384 | + ), # type: ignore |
| 385 | + past_key_values=dc, |
| 386 | + ).past_key_values |
| 387 | + legacy_cache = dc.to_legacy_cache() |
| 388 | + self._cached_blocks[c.value] = legacy_cache |
| 389 | + self._cached_toks[c.value] = tokens |
| 390 | + cached_block_keys.append(c.value) |
| 391 | + case _: |
| 392 | + continue |
| 393 | + |
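| | + # At this point every cache-enabled CBlock has a prefilled legacy KV cache and its tokens
| | + # stored under its string value; cached_block_keys lists the hits for this context, in order.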
| 394 | + # 3. apply the chat template without tokenization. |
| 395 | + input_text = self._tokenizer.apply_chat_template( # type: ignore |
| 396 | + ctx_as_conversation, |
| 397 | + tools=convert_tools_to_json(tools), # type: ignore |
| 398 | + **self._make_backend_specific_and_remove(model_options), |
| 399 | + tokenize=False, |
| 400 | + ) |
| 401 | + |
| 402 | + # 4. split on cache hits |
| 403 | + parts: list[str | tuple[Any, Any]] = [input_text]
| 404 | + for key in cached_block_keys: |
| 405 | + next_split = parts.pop() |
| 406 | + parts_split = next_split.split(key) |
| 407 | + assert len(parts_split) == 2, ( |
| 408 | + "Known issue: cached substring might occur more than once. We need to handle this situation earlier. Notice if this happens and keep a count." |
| 409 | + ) |
| 410 | + parts.append(parts_split[0]) |
| 411 | + parts.append((self._cached_blocks[key], self._cached_toks[key])) |
| 412 | + parts.append(parts_split[1]) |
| 413 | + |
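| | + # `parts` now alternates, in prompt order, between raw text segments and
| | + # (legacy_cache, tokens) pairs for the cached blocks.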
| 414 | + # 5. prefill + smash together everything. |
| 415 | + prefilled: Any | None = None |
| 416 | + parts_tokens: list[Any] = [] |
| 417 | + for part in parts: |
| 418 | + if type(part) is str: |
| 419 | + part_toks = self._tokenizer( |
| 420 | + part, |
| 421 | + return_tensors="pt", |
| 422 | + **self._make_backend_specific_and_remove(model_options), |
| 423 | + ) |
| 424 | + parts_tokens.append(part_toks) |
| 425 | + part_legacy_cache = kv_block_helpers.tokens_to_legacy_cache( |
| 426 | + self._model, self._device, part_toks |
| 427 | + ) |
| 428 | + prefilled = ( |
| 429 | + part_legacy_cache |
| 430 | + if prefilled is None |
| 431 | + else kv_block_helpers.legacy_cache_smash( |
| 432 | + prefilled, part_legacy_cache |
| 433 | + ) |
| 434 | + ) |
| 435 | + else: |
| 436 | + parts_tokens.append(part[1]) |
| 437 | + prefilled = ( |
| 438 | + part[0] |
| 439 | + if prefilled is None |
| 440 | + else kv_block_helpers.legacy_cache_smash( |
| 441 | + prefilled, part[0]
| 442 | + ) |
| 443 | + ) |
| 444 | + |
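| | + # `prefilled` is now the running KV cache for the whole prompt; legacy_cache_smash is a
| | + # project helper that presumably concatenates per-layer key/value tensors along the sequence axis.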
| 445 | + # also smash together the tokens. |
| 446 | + input_ids = torch.cat([toks["input_ids"] for toks in parts_tokens], dim=1) |
| 447 | + |
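| | + # `input_ids` covers the full prompt (cached and uncached segments alike), so the
| | + # `input_ids.shape[1]` offset used when decoding below still points at the first generated token.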
| 448 | + if format is None: |
| 449 | + chat_output = self._model.generate( # type: ignore |
| 450 | + input_ids, |
| 451 | + return_dict_in_generate=True, |
| 452 | + output_scores=True, |
| 453 | + **self._make_backend_specific_and_remove(model_options), |
| 454 | + ) # type: ignore |
| 455 | + |
| 456 | + else: |
| 457 | + # outlines.generate.json always parses the resulting JSON into a Python dict.
| 458 | + # However, we want to keep it as a JSON string so it can later be stored in ModelOutputThunk.
| 459 | + schema: dict[str, Any] = format.model_json_schema() |
| 460 | + schema_json: str = json.dumps(schema) |
| 461 | + regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( |
| 462 | + schema_json |
| 463 | + ) |
| 464 | + |
| 465 | + from outlines.models.transformers import TransformerTokenizer |
| 466 | + from outlines.processors import RegexLogitsProcessor |
| 467 | + from transformers import LogitsProcessorList |
| 468 | + |
| 469 | + chat_output = self._model.generate( # type: ignore |
| 470 | + input_ids, |
| 471 | + return_dict_in_generate=True, |
| 472 | + output_scores=True, |
| 473 | + logits_processor=LogitsProcessorList( |
| 474 | + [ |
| 475 | + RegexLogitsProcessor( |
| 476 | + regex_str, |
| 477 | + tokenizer=TransformerTokenizer(self._tokenizer), |
| 478 | + ) |
| 479 | + ] |
| 480 | + ), |
| 481 | + **self._make_backend_specific_and_remove(model_options), |
| 482 | + ) |
| 483 | + |
| 484 | + decoded_result = self._tokenizer.decode( |
| 485 | + chat_output.sequences[0, input_ids.shape[1] :], skip_special_tokens=True |
| 486 | + ) |
| 487 | + |
| 488 | + # Add an entry to the cache for ALora reuse. |
| 489 | + if self._use_caches: |
| 490 | + output_complete = chat_output.sequences[0] |
| 491 | + cache: DynamicCache = chat_output.past_key_values |
| 492 | + |
| 493 | + cache_info = HFAloraCacheInfo( |
| 494 | + kv_cache=cache, |
| 495 | + merged_token_ids=output_complete, |
| 496 | + merged_attention=torch.ones_like(output_complete).to(self._device), |
| 497 | + q_end=len(input_ids[0]), |
| 498 | + ) |
| 499 | + |
| 500 | + assert decoded_result is not None |
| 501 | + self.cache_put(decoded_result, cache_info) |
| 502 | + else: |
| 503 | + raise Exception("Does not yet support non-chat contexts.") |
| 504 | + |
| 505 | + assert decoded_result is not None |
| 506 | + |
| 507 | + result = ModelOutputThunk(value=decoded_result) |
| 508 | + |
| 509 | + # Only scan for tools if we are not doing structured decoding and tool calls were provided to the model. |
| 510 | + if format is None and tool_calls: |
| 511 | + result.tool_calls = self._extract_model_tool_requests(tools, decoded_result) |
| 512 | + |
| 513 | + parsed_result = self.formatter.parse(action, result) |
| 514 | + if generate_logs is not None: |
| 515 | + assert isinstance(generate_logs, list) |
| 516 | + generate_log = GenerateLog() |
| 517 | + generate_log.prompt = ctx_as_conversation |
| 518 | + generate_log.backend = f"hf::{self.model_id!s}" |
| 519 | + generate_log.model_options = model_options |
| 520 | + generate_log.date = datetime.datetime.now() |
| 521 | + generate_log.model_output = decoded_result |
| 522 | + generate_log.extra = { |
| 523 | + "format": format, |
| 524 | + "tools_available": tools, |
| 525 | + "tools_called": result.tool_calls, |
| 526 | + "seed": seed, |
| 527 | + } |
| 528 | + generate_log.action = action |
| 529 | + generate_log.result = parsed_result |
| 530 | + generate_logs.append(generate_log) |
| 531 | + return parsed_result |
| 532 | + |
275 | 533 | def _generate_from_context_standard( |
276 | 534 | self, |
277 | 535 | action: Component | CBlock, |
|