81 changes: 81 additions & 0 deletions mlx_engine/cache.py
@@ -0,0 +1,81 @@
from typing import List, Optional, Any

from mlx_lm.models.cache import RotatingKVCache, KVCache
import mlx.nn as nn


class AlwaysTrimmableKVCache(RotatingKVCache):
"""A KV cache that can always be trimmed.

The MLX-LM implementation of RotatingKVCache does not allow trimming the
cache once the maximum KV size has been exceeded, so the cache is cleared
entirely whenever that happens and the full context must be reprocessed,
which is bad for performance. This KV cache allows trimming at any time,
which avoids that cost.
See https://github.com/lmstudio-ai/mlx-engine/issues/177 for more details.
"""

def trim(self, n) -> int:
# trim must not respect keep: we always receive some value for keep, but when
# initially processing the prompt the common prefix may be shorter than keep,
# in which case we must trim down to the common prefix length, violating keep.
# keep only matters for cache rotation when the context length is exceeded
# mid-generation, to ensure we don't lose the common prefix.
n = min(self.offset, n)
if n <= 0:
return 0

# put us back into the state before the circular buffer is full
self.keys = self._temporal_order(self.keys)
self.values = self._temporal_order(self.values)

new_length = max(self.keys.shape[2] - n, 0)
self.keys = self.keys[..., :new_length, :]
self.values = self.values[..., :new_length, :]

self.offset = new_length
self._idx = new_length
return n

def set_keep(self, keep):
# kv must be in temporal order, else we will keep the wrong thing
if self.keys is not None:
self.keys = self._temporal_order(self.keys)
if self.values is not None:
self.values = self._temporal_order(self.values)
self.keep = keep

def is_trimmable(self) -> bool:
return True


def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
keep: int = 4,
) -> List[Any]:
"""
Construct the model's cache for use in generation.
This function will defer the cache construction to the model if it has a
``make_cache`` method, otherwise it will make a default KV cache.

See https://github.com/ml-explore/mlx-lm/blob/fd9b1909636d634ac2b848248b05939c9fbfbe19/mlx_lm/models/cache.py#L10
for the MLX-LM implementation. This is a temporary extension to support more flexible
trimming than MLX-LM's original RotatingKVCache.

Args:
model (nn.Module): The language model.
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, an ``AlwaysTrimmableKVCache`` is used with a maximum
size of ``max_kv_size``.
keep (int): The number of tokens to always keep at the start of the cache
when it rotates.
"""
if hasattr(model, "make_cache"):
return model.make_cache()
num_layers = len(model.layers)
if max_kv_size is not None:
return [
AlwaysTrimmableKVCache(max_size=max_kv_size, keep=keep)
for _ in range(num_layers)
]
else:
return [KVCache() for _ in range(num_layers)]
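For context, here is a minimal standalone sketch (not part of this diff) contrasting the new cache with the stock RotatingKVCache once a prompt exceeds the maximum size. It assumes mlx_lm's RotatingKVCache constructor (max_size, keep) and the update_and_fetch/is_trimmable API that the tests below also rely on.

```python
import mlx.core as mx
from mlx_lm.models.cache import RotatingKVCache
from mlx_engine.cache import AlwaysTrimmableKVCache

# A prompt longer than max_size, so the rotating cache is immediately "full"
kv = mx.random.normal((1, 1, 16, 4))  # (batch, n_kv_heads, seq_len, head_dim)

stock = RotatingKVCache(max_size=8, keep=4)
stock.update_and_fetch(kv, kv)
print(stock.is_trimmable())  # expected False: the stock cache refuses to trim once full

trimmable = AlwaysTrimmableKVCache(max_size=8, keep=4)
trimmable.update_and_fetch(kv, kv)
print(trimmable.is_trimmable())  # True by construction
print(trimmable.trim(3))         # 3: reordered into temporal order and trimmed
```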
49 changes: 43 additions & 6 deletions mlx_engine/cache_wrapper.py
@@ -1,10 +1,11 @@
from typing import List, Optional, Any

from mlx_engine.logging import log_info, log_warn, log_error
from mlx_engine.cache import make_prompt_cache
from mlx_lm.models.cache import (
make_prompt_cache,
trim_prompt_cache,
can_trim_prompt_cache,
RotatingKVCache,
)
from mlx_lm.generate import generation_stream, maybe_quantize_kv_cache
import mlx.core as mx
@@ -26,6 +27,7 @@ def __init__(
kv_bits: Optional[int] = None,
kv_group_size: Optional[int] = None,
quantized_kv_start: Optional[int] = None,
keep: int = 4,
):
"""
Initialize the CacheWrapper.
@@ -36,7 +38,9 @@ def __init__(
"""
# utilize a simple ordered list of tokens processed so far for cache invalidation checking
self.tokens: Optional[mx.array] = None
self.cache: List[Any] = make_prompt_cache(model, max_kv_size)
self.keep = keep
self.cache: List[Any] = make_prompt_cache(model, max_kv_size, keep)
self.is_rotating = all(isinstance(c, RotatingKVCache) for c in self.cache)
self.model = model
self.draft_model: Optional[nn.Module] = None
self.max_kv_size = max_kv_size
@@ -115,7 +119,11 @@ def _get_unprocessed_tokens(
message=f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: "
"Cache is not trimmable. Clearing the cache instead.",
)
self.cache = make_prompt_cache(self.model, self.max_kv_size)
self.cache = make_prompt_cache(
self.model,
max_kv_size=self.max_kv_size if self.is_rotating else None,
keep=self.keep,
)
self.tokens = prompt_tokens
return self.tokens
tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim)
@@ -126,7 +134,11 @@ def _get_unprocessed_tokens(
message=f"Tokens trimmed from cache ({tokens_trimmed}) is less than expected "
" ({num_tokens_to_trim}). Clearing the cache.",
)
self.cache = make_prompt_cache(self.model, self.max_kv_size)
self.cache = make_prompt_cache(
self.model,
max_kv_size=self.max_kv_size if self.is_rotating else None,
keep=self.keep,
)
self.tokens = prompt_tokens
return self.tokens
log_info(
@@ -221,9 +233,11 @@ def set_draft_model(self, draft_model: nn.Module):
message="Clearing current prompt cache and adding draft model to the cache",
)
self.tokens = None
self.cache: List[Any] = make_prompt_cache(self.model)
self.cache: List[Any] = make_prompt_cache(self.model, keep=self.keep)
# the above will never return a rotating cache since there is no max_kv_size set
self.is_rotating = False
if draft_model is not None:
self.cache += make_prompt_cache(draft_model)
self.cache += make_prompt_cache(draft_model, keep=self.keep)
self.draft_model = draft_model

def unset_draft_model(self):
@@ -239,6 +253,7 @@ def update_cache(
prompt_progress_callback,
*,
num_tokens_to_exclude: int = 1,
keep: int = 4,
) -> mx.array:
"""
Set up the KV cache for the next generation.
@@ -248,6 +263,7 @@
prompt_tokens (mx.array): The prompt tokens.
prompt_progress_callback (Callable): A callback function to report prompt processing progress.
num_tokens_to_exclude (int): The number of tokens that should not be added to the cache.
keep (int): The number of tokens to always keep in the prefix of the prompt cache.

Returns:
mx.array: The prompt tokens to be used for the next generation.
@@ -257,6 +273,12 @@
def prompt_progress_callback(x):
return None

# update keep tracking
self.keep = keep
for cache in self.cache:
if hasattr(cache, "set_keep"):
cache.set_keep(keep)

num_tokens_to_exclude = max(num_tokens_to_exclude, 1)
prompt_tokens = self._get_unprocessed_tokens(
prompt_tokens, num_tokens_to_exclude
@@ -296,5 +318,20 @@ def prompt_progress_callback(x):
def record_generated_token(self, token):
"""
Add the generated token to the token list, so that we can map the token to the KV cache.

Also rotate when the cache does, so that we accurately track what's in the cache
when using a RotatingKVCache (or a subclass of it). See the rotation implemented by MLX-LM here:
https://github.com/ml-explore/mlx-lm/blob/fd9b1909636d634ac2b848248b05939c9fbfbe19/mlx_lm/models/cache.py#L371
"""
# this behavior is common to rolling window (n_keep = 0) and truncate middle
# (n_keep > 0), and we should never get here with stop at max
if (
self.max_kv_size is not None
and self.is_rotating
Comment on lines +329 to +330

Member:
Are there cases where self.max_kv_size is not None and it's not a rotating cache? Probably when a model has a custom kv cache?

Contributor Author:
No: almost all cases where a model has a custom KV cache are recurrent/hybrid models where they have to use MambaKVCaches for certain layers. There are a few architectures that override to force certain other combinations of caches, but all of them are combinations of the MLX default caches, and no model implements its own custom KV cache. So this is comprehensive. We discussed how to handle this and it seems like this is the best option since the cache itself doesn't expose an API for this.

and len(self.tokens) >= self.max_kv_size
):
# rotate the token tracking buffer
self.tokens = mx.concat(
[self.tokens[: self.keep], self.tokens[self.keep + 1 :]]
)
self.tokens = mx.concat([self.tokens, mx.array([token])])
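To make the token-tracking rotation above concrete, here is a small standalone sketch (not part of the diff) of what the two mx.concat calls do once the tracked token count reaches max_kv_size, assuming keep=4: the protected prefix stays, the oldest token after it is dropped, and the new token is appended.

```python
import mlx.core as mx

keep, max_kv_size = 4, 8
tokens = mx.array([0, 1, 2, 3, 4, 5, 6, 7])  # already at max_kv_size
new_token = 8

if len(tokens) >= max_kv_size:
    # mirror the KV cache rotation: keep the prefix, drop the oldest token after it
    tokens = mx.concat([tokens[:keep], tokens[keep + 1:]])
tokens = mx.concat([tokens, mx.array([new_token])])

print(tokens)  # [0, 1, 2, 3, 5, 6, 7, 8]
```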
2 changes: 2 additions & 0 deletions mlx_engine/generate.py
@@ -137,6 +137,7 @@ def create_generator(
max_tokens: Optional[int] = 10000000,
speculative_decoding_toggle: Optional[bool] = None,
num_draft_tokens: Optional[int] = None,
keep: int = 4,
) -> Iterator[GenerationResult]:
"""
Create a generator that streams text generation results from the model.
@@ -218,6 +219,7 @@ def create_generator(
prompt_progress_callback,
generate_args,
speculative_decoding_toggle,
keep=keep,
)
if draft_model is None:
# input embeddings not yet supported for speculative decoding in mlx-lm
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/model_kit.py
@@ -141,6 +141,7 @@ def process_prompt(
prompt_progress_callback,
generate_args,
speculative_decoding_toggle: Optional[bool] = None,
keep: int = 4,
) -> Tuple[mx.array, Optional[mx.array]]:
### TEXT-ONLY PROCESS_PROMPT ###
is_text_only_processing = images_b64 is None or len(images_b64) == 0
@@ -160,6 +161,7 @@ def process_prompt(
self.draft_model,
speculative_decoding_toggle,
prompt_progress_callback,
keep=keep,
), None
### WITH IMAGES PROMPT PROCESSING ###
if self.vision_add_on is None:
2 changes: 2 additions & 0 deletions mlx_engine/utils/prompt_processing.py
@@ -13,6 +13,7 @@ def process_prompt_text_only(
draft_model: Optional[nn.Module] = None,
speculative_decoding_toggle: Optional[bool] = None,
prompt_progress_callback: Optional[Callable[[float], None]] = None,
keep: int = 4,
):
if cache_wrapper is None:
raise ValueError("Cache wrapper is not initialized, cannot process prompt")
@@ -38,6 +39,7 @@
prompt_tokens = cache_wrapper.update_cache(
prompt_tokens,
prompt_progress_callback,
keep=keep,
)
generate_args["prompt_cache"] = cache_wrapper.cache
return prompt_tokens
38 changes: 38 additions & 0 deletions tests/test_cache_generic.py
@@ -0,0 +1,38 @@
import unittest
import mlx.core as mx
from copy import deepcopy
from mlx_engine.cache import AlwaysTrimmableKVCache


class TestCache(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""Set up test resources that will be shared across all test methods"""
cls.kv_head_dim = 4
cls.bsz = 1
cls.n_kv_heads = 1

@classmethod
def make_random_kv(cls, seqlen: int):
"""Helper method to make a random key/value tensor of the right shape"""
return mx.random.normal(
(cls.bsz, cls.n_kv_heads, seqlen, cls.kv_head_dim),
scale=1.0,
dtype=mx.float32,
)

def assertArrEqual(self, a: mx.array, b: mx.array):
"""Assert that two tensors are equal over the sequence length dimension"""
self.assertEqual(a.shape, b.shape)
self.assertTrue(mx.allclose(a, b), "Tensors are not equal")

def add_random_to_cache(
self, cache: AlwaysTrimmableKVCache, seqlen: int
) -> mx.array:
"""Add random values to the cache and return them"""
base_kv = self.make_random_kv(seqlen)
# base_kv is *assigned* to cache.keys/cache.values, so returning base_kv would
# just hand back a reference to cache.keys; copy it so callers get an
# independent reference tensor to compare against
reference = deepcopy(base_kv)
cache.update_and_fetch(base_kv, base_kv)
return reference
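As a rough illustration of how these helpers could be used (a hypothetical test, not part of this diff), the sketch below would live inside TestCache and exercises trimming past max_size. It assumes the first multi-token update_and_fetch call simply concatenates, as in mlx_lm's RotatingKVCache.

```python
def test_trim_after_exceeding_max_size(self):
    """Hypothetical example: trimming still works after the prompt exceeds max_size"""
    cache = AlwaysTrimmableKVCache(max_size=8, keep=2)
    reference = self.add_random_to_cache(cache, 16)  # longer than max_size
    self.assertTrue(cache.is_trimmable())
    self.assertEqual(cache.trim(4), 4)
    # after trimming, the cache holds the first 12 positions in temporal order
    self.assertArrEqual(cache.keys, reference[..., :12, :])
```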