@@ -272,10 +272,20 @@ def _generate_from_context_alora(
             ),
         )

-    _cached_blocks = {}
-    _cached_toks = {}
-
-    def _generate_from_context_with_kv_cache(
+    _cached_blocks: dict[str, DynamicCache] = dict()
+
+    def _make_dc_cache(self, toks, **model_options):
+        dc = DynamicCache()
+        with torch.no_grad():
+            dc = self._model(
+                toks["input_ids"].to(self._device),
+                attention_mask=toks["attention_mask"].to(self._device),
+                past_key_values=dc,
+                **model_options,
+            ).past_key_values
+        return dc
+
+    def _generate_from_context_with_kv_cache(  # noqa: C901
         self,
         action: Component | CBlock,
         ctx: Context,
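The new _make_dc_cache helper is a single no-grad prefill pass that leaves the key/value states for `toks` inside a DynamicCache. A minimal standalone sketch of that pattern, assuming a recent transformers release that ships DynamicCache (the "gpt2" checkpoint below is illustrative only, not taken from this repository):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

toks = tokenizer("A reusable prefix worth caching.", return_tensors="pt")
cache = DynamicCache()
with torch.no_grad():
    # One forward pass fills the cache with key/value states for every prefix token.
    cache = model(
        toks["input_ids"],
        attention_mask=toks["attention_mask"],
        past_key_values=cache,
        use_cache=True,
    ).past_key_values

print(cache.get_seq_length())  # number of cached positions == prefix token count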
@@ -372,9 +382,16 @@ def _generate_from_context_with_kv_cache(
         for c in linearized_ctx:
             match c:
                 case CBlock() if c.cache:
-                    if c.value not in self._cached_blocks:
-                        FancyLogger.get_logger().info(f"Caching {hash(c.value)}")
-                        tokens = self._tokenizer(c.value)
+                    assert c.value is not None
+                    if c.value in self._cached_blocks:
+                        FancyLogger.get_logger().info(
+                            f"KV CACHE HIT for: {hash(c.value)} ({c.value[:3]}..{c.value[-3:]})"  # type: ignore
+                        )
+                    else:
+                        FancyLogger.get_logger().debug(
+                            f"HF backend is caching a CBlock with hashed contents: {hash(c.value)} ({c.value[:3]}..{c.value[-3:]})"
+                        )
+                        tokens = self._tokenizer(c.value, return_tensors="pt")
                         dc = DynamicCache()
                         with torch.no_grad():
                             dc = self._model(
@@ -383,77 +400,97 @@ def _generate_from_context_with_kv_cache(
                                     self._device
                                 ),  # type: ignore
                                 past_key_values=dc,
+                                use_cache=True,
                             ).past_key_values
-                        legacy_cache = dc.to_legacy_cache()
-                        self._cached_blocks[c.value] = legacy_cache
-                        self._cached_toks[c.value] = tokens
+                        self._cached_blocks[c.value] = dc
                     cached_block_keys.append(c.value)
                 case _:
                     continue

-        # 3. apply the chat template without tokenization.
+        # 3. apply the chat template WITHOUT tokenization.
+        # Doing this without tokenization and then gluing together the tokens is necessary because
+        # things that KV cache together must tokenize together.
         input_text = self._tokenizer.apply_chat_template(  # type: ignore
             ctx_as_conversation,
             tools=convert_tools_to_json(tools),  # type: ignore
             **self._make_backend_specific_and_remove(model_options),
             tokenize=False,
         )

-        # 4. split on cache hits
-        parts: list[str | tuple[DynamicCache, Any]] = [input_text]
+        # 4. split the input_text back up again, re-using DC where it exists.
+        str_parts = []
+        tok_parts = []
+        dc_parts = []
+        current_suffix = input_text
         for key in cached_block_keys:
-            next_split = parts.pop()
-            parts_split = next_split.split(key)
-            assert len(parts_split) == 2, (
+            assert key is not None, (
+                "Some input CBlock must not have been computed yet? The error comes far before this line."
+            )
+            assert key in current_suffix, (
+                "Could happen but would be rare. Related to the other assert in this block."
+            )
+            parts = current_suffix.split(key)  # type: ignore
+            assert len(parts) == 2, (
                 "Known issue: cached substring might occur more than once. We need to handle this situation earlier. Notice if this happens and keep a count."
             )
-            parts.append(parts_split[0])
-            parts.append((self._cached_blocks[key], self._cached_toks[key]))
-            parts.append(parts_split[1])
-
-        # 5. prefill + smash together everything.
-        prefilled: Any | None = None
-        parts_tokens: list[Any] = []
-        for part in parts:
-            if type(part) is str:
-                part_toks = self._tokenizer(
-                    part,
-                    return_tensors="pt",
-                    **self._make_backend_specific_and_remove(model_options),
-                )
-                parts_tokens.append(part_toks)
-                part_legacy_cache = kv_block_helpers.tokens_to_legacy_cache(
-                    self._model, self._device, part_toks
-                )
-                prefilled = (
-                    part_legacy_cache
-                    if prefilled is None
-                    else kv_block_helpers.legacy_cache_smash(
-                        prefilled, part_legacy_cache
-                    )
-                )
-            else:
-                parts_tokens.append(part[1])
-                prefilled = (
-                    part[0]
-                    if prefilled is None
-                    else kv_block_helpers.legacy_cache_smash(
-                        prefilled, part_legacy_cache
-                    )
+            prefix, suffix = parts
+            # Add the prefix, if any, to str+tok+dc parts.
+            if prefix != "":
+                FancyLogger.get_logger().debug(
+                    f"Doing a forward pass on uncached block which is prefix to a cached CBlock: {prefix[:3]}.{len(prefix)}.{prefix[-3:]}"
                 )
+                str_parts.append(prefix)
+                tok_parts.append(self._tokenizer(prefix, return_tensors="pt"))
+                dc_parts.append(self._make_dc_cache(tok_parts[-1]))
+            # Add the cached CBlock to str+tok+dc parts.
+            FancyLogger.get_logger().debug(
447+ f"Replacing a substring with previously computed/retrieved cache with hahs value { hash (key )} ({ key [:3 ]} ..{ key [- 3 :]} )"
+            )
+            # str_parts.append(key)
+            # tok_parts.append(self._tokenizer(key, return_tensors="pt"))
+            # dc_parts.append(self._make_dc_cache(tok_parts[-1]))  # TODO this is wrong.
+            str_parts.append(key)
+            tok_parts.append(self._tokenizer(key, return_tensors="pt"))
+            dc_parts.append(self._cached_blocks[key])
+            # set the suffix for the next loop iteration.
+            current_suffix = suffix
+        # "base" case: the final suffix.
+        if current_suffix != "":
+            FancyLogger.get_logger().debug(  # type: ignore
+                f"Doing a forward pass on final suffix, an uncached block: {current_suffix[:3]}.{len(current_suffix)}.{current_suffix[-3:]}"  # type: ignore
+            )  # type: ignore
+            str_parts.append(current_suffix)
+            tok_parts.append(self._tokenizer(current_suffix, return_tensors="pt"))
+            dc_parts.append(self._make_dc_cache(tok_parts[-1]))

-        # also smash together the tokens.
-        input_ids = torch.cat([toks["input_ids"] for toks in parts_tokens], dim=1)
+        # Smash together the caches, the input_ids, and the attention masks.
+        assert "".join(str_parts) == input_text, (
+            "Should've ended up with the same input text!"
+        )
+        input_ids = torch.cat([toks["input_ids"] for toks in tok_parts], dim=1)
+        attention_mask = torch.cat(
+            [toks["attention_mask"] for toks in tok_parts], dim=1
+        )
+        assert input_ids.shape == attention_mask.shape
+        merged_cache: DynamicCache = kv_block_helpers.merge_dynamic_caches(dc_parts)
+        # TODO: also assert that the merged cache is the correct shape given the input_ids and attention_mask shapes.
+
+        # rewind merged cache by 1 for safety.
+        merged_cache.crop(-1)

         if format is None:
             chat_output = self._model.generate(  # type: ignore
-                input_ids,
+                input_ids.to(self._device),
+                attention_mask=attention_mask.to(self._device),
+                use_cache=True,
+                past_key_values=merged_cache,
                 return_dict_in_generate=True,
                 output_scores=True,
                 **self._make_backend_specific_and_remove(model_options),
             )  # type: ignore

         else:
+            raise NotImplementedError("Copy implementation from above.")
             # outlines.generate.json always parses the resulting json into a python dict.
             # We however want to keep it as a json string for later storing it in ModelOutputThunk
             schema: dict[str, Any] = format.model_json_schema()
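The merged cache handed to generate above must hold the key/value states of every segment, concatenated in order, which is what kv_block_helpers.merge_dynamic_caches is expected to produce; that helper is not shown in this diff. A hedged sketch of one way such a merge could be implemented, assuming all segment caches come from the same model and ignoring RoPE/position-id subtleties:

import torch
from transformers import DynamicCache


def merge_dynamic_caches_sketch(caches: list[DynamicCache]) -> DynamicCache:
    # Illustrative only: concatenate per-segment caches along the sequence dimension.
    # Each legacy cache is a per-layer tuple of (key, value) tensors shaped
    # [batch, num_heads, seq_len, head_dim]; concatenate along seq_len (dim=-2).
    legacy = [cache.to_legacy_cache() for cache in caches]
    merged = tuple(
        (
            torch.cat([layers[i][0] for layers in legacy], dim=-2),
            torch.cat([layers[i][1] for layers in legacy], dim=-2),
        )
        for i in range(len(legacy[0]))
    )
    return DynamicCache.from_legacy_cache(merged)

DynamicCache.crop(-1) then drops the last cached position, so the model recomputes it during generation; that is the "rewind merged cache by 1 for safety" step in the diff.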