61 changes: 61 additions & 0 deletions mlx_engine/cache.py
@@ -0,0 +1,61 @@
from typing import List, Optional, Any

from mlx_lm.models.cache import RotatingKVCache, KVCache
import mlx.nn as nn


class ShiftingKVCache(RotatingKVCache):
def trim(self, n) -> int:
# trim must not respect keep
n = min(self.offset, n)
if n <= 0:
return 0

# put us back into the state before the circular buffer is full
self.keys = self._temporal_order(self.keys)
self.values = self._temporal_order(self.values)

new_length = max(self.keys.shape[2] - n, 0)
self.keys = self.keys[..., :new_length, :]
self.values = self.values[..., :new_length, :]

self.offset = new_length
self._idx = new_length
return n

def set_keep(self, keep):
# kv must be in temporal order, else we will keep the wrong thing
if self.keys is not None:
self.keys = self._temporal_order(self.keys)
if self.values is not None:
self.values = self._temporal_order(self.values)
self.keep = keep

def is_trimmable(self) -> bool:
return True


def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
keep: int = 4,
) -> List[Any]:
"""
Construct the model's cache for use in generation.
This function will defer the cache construction to the model if it has a
``make_cache`` method, otherwise it will make a default KV cache.
Args:
model (nn.Module): The language model.
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``ShiftingKVCache`` is used with a maximum
size of ``max_kv_size``.
keep (int): The number of tokens at the start of the prompt that a
``ShiftingKVCache`` always retains when it rotates. Defaults to 4.
"""
if hasattr(model, "make_cache"):
return model.make_cache()
num_layers = len(model.layers)
if max_kv_size is not None:
return [
ShiftingKVCache(max_size=max_kv_size, keep=keep) for _ in range(num_layers)
]
else:
return [KVCache() for _ in range(num_layers)]
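A minimal usage sketch of the new cache (not part of the PR; the tensor shape and the max_size/keep values are hypothetical, chosen only to exercise the overridden trim, set_keep, and is_trimmable behavior added above):

import mlx.core as mx
from mlx_engine.cache import ShiftingKVCache

cache = ShiftingKVCache(max_size=8, keep=2)
kv = mx.random.normal((1, 1, 4, 64))  # (batch, n_kv_heads, seq_len, head_dim)
cache.update_and_fetch(kv, kv)        # cache.offset is now 4

trimmed = cache.trim(2)               # per the comment above, trim deliberately ignores `keep`
assert trimmed == 2 and cache.offset == 2
assert cache.is_trimmable()           # the override always reports trimmable

cache.set_keep(3)                     # puts the cache in temporal order before updating `keep`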
39 changes: 29 additions & 10 deletions mlx_engine/cache_wrapper.py
@@ -1,11 +1,8 @@
from typing import List, Optional, Any

from mlx_engine.logging import log_info, log_warn, log_error
from mlx_lm.models.cache import (
make_prompt_cache,
trim_prompt_cache,
can_trim_prompt_cache,
)
from mlx_engine.cache import make_prompt_cache
from mlx_lm.models.cache import trim_prompt_cache, can_trim_prompt_cache
from mlx_lm.generate import generation_stream, maybe_quantize_kv_cache
import mlx.core as mx
import mlx.nn as nn
@@ -26,6 +23,7 @@ def __init__(
kv_bits: Optional[int] = None,
kv_group_size: Optional[int] = None,
quantized_kv_start: Optional[int] = None,
keep: int = 4,
):
"""
Initialize the CacheWrapper.
@@ -36,7 +34,8 @@ def __init__(
"""
# utilize a simple ordered list of tokens processed so far for cache invalidation checking
self.tokens: Optional[mx.array] = None
self.cache: List[Any] = make_prompt_cache(model, max_kv_size)
self.keep = keep
self.cache: List[Any] = make_prompt_cache(model, max_kv_size, keep)
self.model = model
self.draft_model: Optional[nn.Module] = None
self.max_kv_size = max_kv_size
@@ -115,7 +114,9 @@ def _get_unprocessed_tokens(
message=f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: "
"Cache is not trimmable. Clearing the cache instead.",
)
self.cache = make_prompt_cache(self.model, self.max_kv_size)
self.cache = make_prompt_cache(
self.model, self.max_kv_size, keep=self.keep
)
self.tokens = prompt_tokens
return self.tokens
tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim)
@@ -126,7 +127,9 @@ def _get_unprocessed_tokens(
message=f"Tokens trimmed from cache ({tokens_trimmed}) is less than expected "
" ({num_tokens_to_trim}). Clearing the cache.",
)
self.cache = make_prompt_cache(self.model, self.max_kv_size)
self.cache = make_prompt_cache(
self.model, self.max_kv_size, keep=self.keep
)
self.tokens = prompt_tokens
return self.tokens
log_info(
@@ -221,9 +224,9 @@ def set_draft_model(self, draft_model: nn.Module):
message="Clearing current prompt cache and adding draft model to the cache",
)
self.tokens = None
self.cache: List[Any] = make_prompt_cache(self.model)
self.cache: List[Any] = make_prompt_cache(self.model, keep=self.keep)
if draft_model is not None:
self.cache += make_prompt_cache(draft_model)
self.cache += make_prompt_cache(draft_model, keep=self.keep)
self.draft_model = draft_model

def unset_draft_model(self):
@@ -239,6 +242,7 @@ def update_cache(
prompt_progress_callback,
*,
num_tokens_to_exclude: int = 1,
keep: int = 4,
) -> mx.array:
"""
Set up the KV cache for the next generation.
@@ -248,6 +252,7 @@ def update_cache(
prompt_tokens (mx.array): The prompt tokens.
prompt_progress_callback (Callable): A callback function to report prompt processing progress.
num_tokens_to_exclude (int): The number of tokens that should not be added to the cache.
keep (int): The number of tokens to always keep in the prefix of the prompt cache.

Returns:
mx.array: The prompt tokens to be used for the next generation.
@@ -257,6 +262,12 @@ def prompt_progress_callback(x):
def prompt_progress_callback(x):
return None

# update keep tracking
self.keep = keep
for cache in self.cache:
if hasattr(cache, "set_keep"):
cache.set_keep(keep)

num_tokens_to_exclude = max(num_tokens_to_exclude, 1)
prompt_tokens = self._get_unprocessed_tokens(
prompt_tokens, num_tokens_to_exclude
@@ -296,5 +307,13 @@ def prompt_progress_callback(x):
def record_generated_token(self, token):
"""
Add the generated token to the token list, so that we can map the token to the KV cache.

Also roll the token list when the cache rolls, so that we accurately track what is in the cache.
"""
# this behavior is common to rolling window (n_keep = 0) and truncate middle
# (n_keep > 0), and we should never get here with stop at max
if self.max_kv_size is not None and len(self.tokens) >= self.max_kv_size:
self.tokens = mx.concat(
[self.tokens[: self.keep], self.tokens[self.keep + 1 :]]
)
self.tokens = mx.concat([self.tokens, mx.array([token])])
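A standalone sketch of the token-list bookkeeping added in record_generated_token above (hypothetical values, not part of the PR): with max_kv_size=6 and keep=2, appending a seventh token keeps the prefix and drops the oldest non-kept token, mirroring the cache rotation.

import mlx.core as mx

keep, max_kv_size = 2, 6
tokens = mx.array([10, 11, 12, 13, 14, 15])  # already at capacity
if len(tokens) >= max_kv_size:
    # drop the token immediately after the kept prefix
    tokens = mx.concat([tokens[:keep], tokens[keep + 1 :]])
tokens = mx.concat([tokens, mx.array([16])])
# tokens is now [10, 11, 13, 14, 15, 16]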
2 changes: 2 additions & 0 deletions mlx_engine/generate.py
@@ -137,6 +137,7 @@ def create_generator(
max_tokens: Optional[int] = 10000000,
speculative_decoding_toggle: Optional[bool] = None,
num_draft_tokens: Optional[int] = None,
keep: int = 4,
) -> Iterator[GenerationResult]:
"""
Create a generator that streams text generation results from the model.
@@ -218,6 +219,7 @@
prompt_progress_callback,
generate_args,
speculative_decoding_toggle,
keep=keep,
)
if draft_model is None:
# input embeddings not yet supported for speculative decoding in mlx-lm
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/model_kit.py
@@ -141,6 +141,7 @@ def process_prompt(
prompt_progress_callback,
generate_args,
speculative_decoding_toggle: Optional[bool] = None,
keep: int = 4,
) -> Tuple[mx.array, Optional[mx.array]]:
### TEXT-ONLY PROCESS_PROMPT ###
is_text_only_processing = images_b64 is None or len(images_b64) == 0
@@ -160,6 +161,7 @@
self.draft_model,
speculative_decoding_toggle,
prompt_progress_callback,
keep=keep,
), None
### WITH IMAGES PROMPT PROCESSING ###
if self.vision_add_on is None:
2 changes: 2 additions & 0 deletions mlx_engine/utils/prompt_processing.py
@@ -13,6 +13,7 @@ def process_prompt_text_only(
draft_model: Optional[nn.Module] = None,
speculative_decoding_toggle: Optional[bool] = None,
prompt_progress_callback: Optional[Callable[[float], None]] = None,
keep: int = 4,
):
if cache_wrapper is None:
raise ValueError("Cache wrapper is not initialized, cannot process prompt")
@@ -38,6 +39,7 @@
prompt_tokens = cache_wrapper.update_cache(
prompt_tokens,
prompt_progress_callback,
keep=keep,
)
generate_args["prompt_cache"] = cache_wrapper.cache
return prompt_tokens
36 changes: 36 additions & 0 deletions tests/test_cache_generic.py
@@ -0,0 +1,36 @@
import unittest
import mlx.core as mx
from copy import deepcopy
from mlx_engine.cache import ShiftingKVCache


class TestCache(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""Set up test resources that will be shared across all test methods"""
cls.kv_head_dim = 4
cls.bsz = 1
cls.n_kv_heads = 1

@classmethod
def make_random_kv(cls, seqlen: int):
"""Helper method to make a random key/value tensor of the right shape"""
return mx.random.normal(
(cls.bsz, cls.n_kv_heads, seqlen, cls.kv_head_dim),
scale=1.0,
dtype=mx.float32,
)

def assertArrEqual(self, a: mx.array, b: mx.array):
"""Assert that two tensors are equal over the sequence length dimension"""
self.assertEqual(a.shape, b.shape)
self.assertTrue(mx.allclose(a, b), "Tensors are not equal")

def add_random_to_cache(self, cache: ShiftingKVCache, seqlen: int) -> mx.array:
"""Add random values to the cache and return them"""
base_kv = self.make_random_kv(seqlen)
# base_kv is *assigned* to cache.keys/cache.values, so returning base_kv
# would return a reference to cache.keys, which is pointless. Copy it instead
reference = deepcopy(base_kv)
cache.update_and_fetch(base_kv, base_kv)
return reference
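A sketch of a test that could build on these helpers (not part of the PR; it would live inside TestCache, and the expected results follow the trim implementation in mlx_engine/cache.py):

def test_trim_drops_most_recent_entries(self):
    """Trimming removes entries from the end of the temporally ordered cache"""
    cache = ShiftingKVCache(max_size=8, keep=2)
    reference = self.add_random_to_cache(cache, seqlen=4)
    self.assertEqual(cache.trim(2), 2)
    self.assertEqual(cache.offset, 2)
    # the two oldest entries survive the trim
    self.assertArrEqual(cache.keys, reference[..., :2, :])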