
Commit ead3fe8

committed
Adds KV cache smash.
1 parent a648405 commit ead3fe8

File tree

1 file changed: +89 -52 lines changed

mellea/backends/huggingface.py

Lines changed: 89 additions & 52 deletions
@@ -272,10 +272,20 @@ def _generate_from_context_alora(
             ),
         )

-    _cached_blocks = {}
-    _cached_toks = {}
-
-    def _generate_from_context_with_kv_cache(
+    _cached_blocks: dict[str, DynamicCache] = dict()
+
+    def _make_dc_cache(self, toks, **model_options):
+        dc = DynamicCache()
+        with torch.no_grad():
+            dc = self._model(
+                toks["input_ids"].to(self._device),
+                attention_mask=toks["attention_mask"].to(self._device),
+                past_key_values=dc,
+                **model_options,
+            ).past_key_values
+        return dc
+
+    def _generate_from_context_with_kv_cache(  # noqa: C901
         self,
         action: Component | CBlock,
         ctx: Context,
@@ -372,9 +382,16 @@ def _generate_from_context_with_kv_cache(
         for c in linearized_ctx:
             match c:
                 case CBlock() if c.cache:
-                    if c.value not in self._cached_blocks:
-                        FancyLogger.get_logger().info(f"Caching {hash(c.value)}")
-                        tokens = self._tokenizer(c.value)
+                    assert c.value is not None
+                    if c.value in self._cached_blocks:
+                        FancyLogger.get_logger().info(
+                            f"KV CACHE HIT for: {hash(c.value)} ({c.value[:3]}..{c.value[-3:]})"  # type: ignore
+                        )
+                    else:
+                        FancyLogger.get_logger().debug(
+                            f"HF backend is caching a CBlock with hashed contents: {hash(c.value)} ({c.value[:3]}..{c.value[-3:]})"
+                        )
+                        tokens = self._tokenizer(c.value, return_tensors="pt")
                         dc = DynamicCache()
                         with torch.no_grad():
                             dc = self._model(
@@ -383,77 +400,97 @@ def _generate_from_context_with_kv_cache(
                                    self._device
                                ),  # type: ignore
                                past_key_values=dc,
+                                use_cache=True,
                            ).past_key_values
-                        legacy_cache = dc.to_legacy_cache()
-                        self._cached_blocks[c.value] = legacy_cache
-                        self._cached_toks[c.value] = tokens
+                        self._cached_blocks[c.value] = dc
                     cached_block_keys.append(c.value)
                 case _:
                     continue

-        # 3. apply the chat template without tokenization.
+        # 3. apply the chat template WITHOUT tokenization.
+        # Doing this without tokenization and then gluing together the tokens is necessary because
+        # things that KV cache together must tokenize together.
         input_text = self._tokenizer.apply_chat_template(  # type: ignore
             ctx_as_conversation,
             tools=convert_tools_to_json(tools),  # type: ignore
             **self._make_backend_specific_and_remove(model_options),
             tokenize=False,
         )

-        # 4. split on cache hits
-        parts: list[str | tuple[DynamicCache, Any]] = [input_text]
+        # 4. split the input_text back up again, re-using DC where it exists.
+        str_parts = []
+        tok_parts = []
+        dc_parts = []
+        current_suffix = input_text
         for key in cached_block_keys:
-            next_split = parts.pop()
-            parts_split = next_split.split(key)
-            assert len(parts_split) == 2, (
+            assert key is not None, (
+                "Some input CBlock must not have been computed yet? The error comes far before this line."
+            )
+            assert key in current_suffix, (
+                "Could happen but would be rare. Related to the other assert in this block."
+            )
+            parts = current_suffix.split(key)  # type: ignore
+            assert len(parts) == 2, (
                 "Known issue: cached substring might occur more than once. We need to handle this situation earlier. Notice if this happens and keep a count."
             )
-            parts.append(parts_split[0])
-            parts.append((self._cached_blocks[key], self._cached_toks[key]))
-            parts.append(parts_split[1])
-
-        # 5. prefill + smash together everything.
-        prefilled: Any | None = None
-        parts_tokens: list[Any] = []
-        for part in parts:
-            if type(part) is str:
-                part_toks = self._tokenizer(
-                    part,
-                    return_tensors="pt",
-                    **self._make_backend_specific_and_remove(model_options),
-                )
-                parts_tokens.append(part_toks)
-                part_legacy_cache = kv_block_helpers.tokens_to_legacy_cache(
-                    self._model, self._device, part_toks
-                )
-                prefilled = (
-                    part_legacy_cache
-                    if prefilled is None
-                    else kv_block_helpers.legacy_cache_smash(
-                        prefilled, part_legacy_cache
-                    )
-                )
-            else:
-                parts_tokens.append(part[1])
-                prefilled = (
-                    part[0]
-                    if prefilled is None
-                    else kv_block_helpers.legacy_cache_smash(
-                        prefilled, part_legacy_cache
-                    )
+            prefix, suffix = parts
+            # Add the prefix, if any, to str+tok+dc parts.
+            if prefix != "":
+                FancyLogger.get_logger().debug(
+                    f"Doing a forward pass on uncached block which is a prefix to a cached CBlock: {prefix[:3]}.{len(prefix)}.{prefix[-3:]}"
                 )
+                str_parts.append(prefix)
+                tok_parts.append(self._tokenizer(prefix, return_tensors="pt"))
+                dc_parts.append(self._make_dc_cache(tok_parts[-1]))
+            # Add the cached CBlock to str+tok+dc parts.
+            FancyLogger.get_logger().debug(
+                f"Replacing a substring with previously computed/retrieved cache with hash value {hash(key)} ({key[:3]}..{key[-3:]})"
+            )
+            # str_parts.append(key)
+            # tok_parts.append(self._tokenizer(key, return_tensors="pt"))
+            # dc_parts.append(self._make_dc_cache(tok_parts[-1]))  # TODO this is wrong.
+            str_parts.append(key)
+            tok_parts.append(self._tokenizer(key, return_tensors="pt"))
+            dc_parts.append(self._cached_blocks[key])
+            # set the suffix for the next loop iteration.
+            current_suffix = suffix
+        # "base" case: the final suffix.
+        if current_suffix != "":
+            FancyLogger.get_logger().debug(  # type: ignore
+                f"Doing a forward pass on final suffix, an uncached block: {current_suffix[:3]}.{len(current_suffix)}.{current_suffix[-3:]}"  # type: ignore
+            )  # type: ignore
+            str_parts.append(current_suffix)
+            tok_parts.append(self._tokenizer(current_suffix, return_tensors="pt"))
+            dc_parts.append(self._make_dc_cache(tok_parts[-1]))

-        # also smash together the tokens.
-        input_ids = torch.cat([toks["input_ids"] for toks in parts_tokens], dim=1)
+        # Smash together the caches, the input_ids, and the attention masks.
+        assert "".join(str_parts) == input_text, (
+            "Should've ended up with the same input text!"
+        )
+        input_ids = torch.cat([toks["input_ids"] for toks in tok_parts], dim=1)
+        attention_mask = torch.cat(
+            [toks["attention_mask"] for toks in tok_parts], dim=1
+        )
+        assert input_ids.shape == attention_mask.shape
+        merged_cache: DynamicCache = kv_block_helpers.merge_dynamic_caches(dc_parts)
+        # TODO: also assert that the merged cache is the correct shape given the input_ids and attention_mask shapes.
+
+        # rewind merged cache by 1 for safety.
+        merged_cache.crop(-1)

         if format is None:
             chat_output = self._model.generate(  # type: ignore
-                input_ids,
+                input_ids.to(self._device),
+                attention_mask=attention_mask.to(self._device),
+                use_cache=True,
+                past_key_values=merged_cache,
                 return_dict_in_generate=True,
                 output_scores=True,
                 **self._make_backend_specific_and_remove(model_options),
             )  # type: ignore

         else:
+            raise NotImplementedError("Copy implementation from above.")
             # outlines.generate.json always parses the resulting json into a python dict.
             # We however want to keep it as a json string for later storing it in ModelOutputThunk
             schema: dict[str, Any] = format.model_json_schema()
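
Why step 3 applies the chat template WITHOUT tokenization: token boundaries generally do not line up with string boundaries, so the tokens of a concatenated string can differ from the concatenation of separately tokenized pieces. Because each cached CBlock was prefilled from its own tokenization, the surrounding text also has to be tokenized per segment and the token sequences glued back together, which is what steps 4 and 5 above do. A minimal illustration, not part of this commit ("gpt2" is just an example tokenizer):

# Demonstrates that tokenize(a + b) need not equal tokenize(a) + tokenize(b).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer works; gpt2 is only an example

a = "Hello wor"
b = "ld, how are you?"

joined = tok(a + b)["input_ids"]
glued = tok(a)["input_ids"] + tok(b)["input_ids"]

# The sequences usually differ at the seam ("wor" + "ld" vs. "world"),
# which is why things that KV-cache together must tokenize together.
print(joined)
print(glued)
print("identical:", joined == glued)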

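For readers new to the pattern, here is a rough standalone sketch of the "KV cache smash" idea the new code implements: prefill a DynamicCache per text segment, concatenate the per-layer key/value tensors along the sequence dimension, rewind the merged cache by one position, and hand it to generate() as past_key_values. The merge_caches helper below is written for this note only and is not the repo's kv_block_helpers.merge_dynamic_caches; "gpt2" is likewise just an example model.

# Hedged sketch of the prefill-and-smash pattern (assumed model "gpt2"; merge_caches
# is an illustrative stand-in, not the repo's kv_block_helpers implementation).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()


def prefill(text: str):
    """Forward pass over one segment; returns its tokens and its KV cache."""
    toks = tok(text, return_tensors="pt")
    cache = DynamicCache()
    with torch.no_grad():
        cache = model(
            toks["input_ids"],
            attention_mask=toks["attention_mask"],
            past_key_values=cache,
            use_cache=True,
        ).past_key_values
    return toks, cache


def merge_caches(caches):
    """Concatenate per-layer key/value tensors along the sequence dimension."""
    merged = DynamicCache()
    for layer in range(len(caches[0].key_cache)):
        keys = torch.cat([c.key_cache[layer] for c in caches], dim=-2)
        values = torch.cat([c.value_cache[layer] for c in caches], dim=-2)
        merged.update(keys, values, layer)
    return merged


segments = ["You are a helpful assistant. ", "Say hello to the user."]
tok_parts, cache_parts = zip(*(prefill(s) for s in segments))

input_ids = torch.cat([t["input_ids"] for t in tok_parts], dim=1)
attention_mask = torch.cat([t["attention_mask"] for t in tok_parts], dim=1)

merged = merge_caches(list(cache_parts))
merged.crop(-1)  # rewind by one so generate() recomputes the final prompt token

out = model.generate(
    input_ids,
    attention_mask=attention_mask,
    past_key_values=merged,
    use_cache=True,
    max_new_tokens=20,
)
print(tok.decode(out[0], skip_special_tokens=True))

Note that naively concatenating caches this way reuses each segment's original position encodings; whether and how that is corrected is up to the real kv_block_helpers implementation.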
0 commit comments
