Initial cache fixes #192
Changes from all commits: 6809f05, 5e7834d, 27fae1c, cce5927, 0a71d2a, 71f1d53, 559d054, 395cafb, d8240d7
mlx_engine/cache.py (new file)
@@ -0,0 +1,81 @@
from typing import List, Optional, Any

from mlx_lm.models.cache import RotatingKVCache, KVCache
import mlx.nn as nn


class AlwaysTrimmableKVCache(RotatingKVCache):
    """A KV cache that can always be trimmed.

    The MLX-LM implementation of the RotatingKVCache does not allow trimming
    the cache once the maximum KV size has been exceeded, which results in
    the cache being nuked every time this happens. This forces the entire context
    to be reprocessed regularly, which is not ideal for performance. This KV cache
    allows trimming the cache at any time, which circumvents this issue.
    See https://github.com/lmstudio-ai/mlx-engine/issues/177 for more details.
    """

    def trim(self, n) -> int:
        # trim must not respect keep: we always receive some value for keep, but
        # when initially processing the prompt, it may be that the common prefix
        # is shorter than keep. In that case we must trim to the common prefix length,
        # which violates keep. keep is only used for the cache rotation when exceeding
        # the context length mid-generation to ensure we don't lose the common prefix.
        n = min(self.offset, n)
        if n <= 0:
            return 0

        # put us back into the state before the circular buffer is full
        self.keys = self._temporal_order(self.keys)
        self.values = self._temporal_order(self.values)

        new_length = max(self.keys.shape[2] - n, 0)
        self.keys = self.keys[..., :new_length, :]
        self.values = self.values[..., :new_length, :]

        self.offset = new_length
        self._idx = new_length
        return n

    def set_keep(self, keep):
        # kv must be in temporal order, else we will keep the wrong thing
        if self.keys is not None:
            self.keys = self._temporal_order(self.keys)
        if self.values is not None:
            self.values = self._temporal_order(self.values)
        self.keep = keep

    def is_trimmable(self) -> bool:
        return True


def make_prompt_cache(
    model: nn.Module,
    max_kv_size: Optional[int] = None,
    keep: int = 4,
) -> List[Any]:
    """
    Construct the model's cache for use in generation.
    This function will defer the cache construction to the model if it has a
    ``make_cache`` method, otherwise it will make a default KV cache.

    See https://github.com/ml-explore/mlx-lm/blob/fd9b1909636d634ac2b848248b05939c9fbfbe19/mlx_lm/models/cache.py#L10
    for the MLX-LM implementation. This is a temporary extension to support more flexible
    trimming than MLX-LM's original RotatingKVCache.

    Args:
        model (nn.Module): The language model.
        max_kv_size (Optional[int]): If provided and the model does not have a
            ``make_cache`` method, an ``AlwaysTrimmableKVCache`` is used with a maximum
            size of ``max_kv_size``
    """
    if hasattr(model, "make_cache"):
        return model.make_cache()
    num_layers = len(model.layers)
    if max_kv_size is not None:
        return [
            AlwaysTrimmableKVCache(max_size=max_kv_size, keep=keep)
            for _ in range(num_layers)
        ]
    else:
        return [KVCache() for _ in range(num_layers)]
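As a rough usage sketch (illustrative only, not part of this diff; it assumes the RotatingKVCache constructor signature (max_size, keep) and MLX-LM's update_and_fetch API used above), the new cache stays trimmable even after more tokens than max_size have been written to it:

import mlx.core as mx

from mlx_engine.cache import AlwaysTrimmableKVCache

# Illustrative sketch only, not part of the PR. The KV tensors follow the
# (batch, n_kv_heads, seq_len, head_dim) convention used by the tests below.
cache = AlwaysTrimmableKVCache(max_size=8, keep=2)

kv = mx.random.normal((1, 1, 12, 4))  # 12 tokens, more than max_size=8
cache.update_and_fetch(kv, kv)

# The stock RotatingKVCache stops being trimmable once max_size has been
# exceeded; the subclass always allows it and reports the tokens removed.
assert cache.is_trimmable()
assert cache.trim(5) == 5  # trim() returns min(self.offset, n)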
@@ -1,10 +1,11 @@
 from typing import List, Optional, Any

 from mlx_engine.logging import log_info, log_warn, log_error
+from mlx_engine.cache import make_prompt_cache
 from mlx_lm.models.cache import (
-    make_prompt_cache,
     trim_prompt_cache,
     can_trim_prompt_cache,
+    RotatingKVCache,
 )
 from mlx_lm.generate import generation_stream, maybe_quantize_kv_cache
 import mlx.core as mx
@@ -26,6 +27,7 @@ def __init__(
         kv_bits: Optional[int] = None,
         kv_group_size: Optional[int] = None,
         quantized_kv_start: Optional[int] = None,
+        keep: int = 4,
     ):
         """
         Initialize the CacheWrapper.
@@ -36,7 +38,9 @@ def __init__(
""" | ||
# utilize a simple ordered list of tokens processed so far for cache invalidation checking | ||
self.tokens: Optional[mx.array] = None | ||
self.cache: List[Any] = make_prompt_cache(model, max_kv_size) | ||
self.keep = keep | ||
self.cache: List[Any] = make_prompt_cache(model, max_kv_size, keep) | ||
self.is_rotating = all(isinstance(c, RotatingKVCache) for c in self.cache) | ||
self.model = model | ||
self.draft_model: Optional[nn.Module] = None | ||
self.max_kv_size = max_kv_size | ||
|
@@ -115,7 +119,11 @@ def _get_unprocessed_tokens(
message=f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: " | ||
"Cache is not trimmable. Clearing the cache instead.", | ||
) | ||
self.cache = make_prompt_cache(self.model, self.max_kv_size) | ||
self.cache = make_prompt_cache( | ||
self.model, | ||
max_kv_size=self.max_kv_size if self.is_rotating else None, | ||
keep=self.keep, | ||
) | ||
self.tokens = prompt_tokens | ||
return self.tokens | ||
tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim) | ||
|
@@ -126,7 +134,11 @@ def _get_unprocessed_tokens(
message=f"Tokens trimmed from cache ({tokens_trimmed}) is less than expected " | ||
" ({num_tokens_to_trim}). Clearing the cache.", | ||
) | ||
self.cache = make_prompt_cache(self.model, self.max_kv_size) | ||
self.cache = make_prompt_cache( | ||
self.model, | ||
max_kv_size=self.max_kv_size if self.is_rotating else None, | ||
keep=self.keep, | ||
) | ||
self.tokens = prompt_tokens | ||
return self.tokens | ||
log_info( | ||
|
@@ -221,9 +233,11 @@ def set_draft_model(self, draft_model: nn.Module):
message="Clearing current prompt cache and adding draft model to the cache", | ||
) | ||
self.tokens = None | ||
self.cache: List[Any] = make_prompt_cache(self.model) | ||
self.cache: List[Any] = make_prompt_cache(self.model, keep=self.keep) | ||
# the above will never return a rotating cache since there is no max_kv_size set | ||
self.is_rotating = False | ||
if draft_model is not None: | ||
self.cache += make_prompt_cache(draft_model) | ||
self.cache += make_prompt_cache(draft_model, keep=self.keep) | ||
self.draft_model = draft_model | ||
|
||
def unset_draft_model(self): | ||
|
@@ -239,6 +253,7 @@ def update_cache(
         prompt_progress_callback,
         *,
         num_tokens_to_exclude: int = 1,
+        keep: int = 4,
     ) -> mx.array:
         """
         Set up the KV cache for the next generation.
@@ -248,6 +263,7 @@
             prompt_tokens (mx.array): The prompt tokens.
             prompt_progress_callback (Callable): A callback function to report prompt processing progress.
             num_tokens_to_exclude (int): The number of tokens that should not be added to the cache.
+            keep (int): The number of tokens to always keep in the prefix of the prompt cache.

         Returns:
             mx.array: The prompt tokens to be used for the next generation.
@@ -257,6 +273,12 @@
             def prompt_progress_callback(x):
                 return None

+        # update keep tracking
+        self.keep = keep
+        for cache in self.cache:
+            if hasattr(cache, "set_keep"):
+                cache.set_keep(keep)
+
         num_tokens_to_exclude = max(num_tokens_to_exclude, 1)
         prompt_tokens = self._get_unprocessed_tokens(
             prompt_tokens, num_tokens_to_exclude
@@ -296,5 +318,20 @@ def prompt_progress_callback(x):
     def record_generated_token(self, token):
         """
         Add the generated token to the token list, so that we can map the token to the KV cache.
+
+        Also loop when the cache does so that we accurately track what's in cache, if we're using
+        a RotatingKVCache or subclass of such. See the rotation implemented by MLX-LM here:
+        https://github.com/ml-explore/mlx-lm/blob/fd9b1909636d634ac2b848248b05939c9fbfbe19/mlx_lm/models/cache.py#L371
         """
+        # this behavior is common to rolling window (n_keep = 0) and truncate middle
+        # (n_keep > 0), and we should never get here with stop at max
+        if (
+            self.max_kv_size is not None
+            and self.is_rotating
+            and len(self.tokens) >= self.max_kv_size
+        ):
+            # rotate the token tracking buffer
+            self.tokens = mx.concat(
+                [self.tokens[: self.keep], self.tokens[self.keep + 1 :]]
+            )
         self.tokens = mx.concat([self.tokens, mx.array([token])])

Review comment on lines +329 to +330 (the self.max_kv_size is not None / and self.is_rotating check): Are there cases where […]

Reply: No: almost all cases where a model has a custom KV cache are recurrent/hybrid models where they have to use […]
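To make the token-buffer rotation above concrete, here is a small standalone sketch (illustrative only; the token values are made up) of what record_generated_token does to self.tokens once the rotating cache is full: it drops the oldest token after the kept prefix, then appends the newly generated token, mirroring the KV cache rotation.

import mlx.core as mx

keep = 4
# a full token-tracking buffer (max_kv_size = 8 in this toy example)
tokens = mx.array([101, 102, 103, 104, 201, 202, 203, 204])

# drop the oldest token that is not part of the kept prefix ...
tokens = mx.concat([tokens[:keep], tokens[keep + 1 :]])
# ... then append the newly generated token
tokens = mx.concat([tokens, mx.array([205])])
# tokens is now [101, 102, 103, 104, 202, 203, 204, 205]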
@@ -0,0 +1,38 @@
import unittest
import mlx.core as mx
from copy import deepcopy
from mlx_engine.cache import AlwaysTrimmableKVCache


class TestCache(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        """Set up test resources that will be shared across all test methods"""
        cls.kv_head_dim = 4
        cls.bsz = 1
        cls.n_kv_heads = 1

    @classmethod
    def make_random_kv(cls, seqlen: int):
        """Helper method to make a random key/value tensor of the right shape"""
        return mx.random.normal(
            (cls.bsz, cls.n_kv_heads, seqlen, cls.kv_head_dim),
            scale=1.0,
            dtype=mx.float32,
        )

    def assertArrEqual(self, a: mx.array, b: mx.array):
        """Assert that two tensors are equal over the sequence length dimension"""
        self.assertEqual(a.shape, b.shape)
        self.assertTrue(mx.allclose(a, b), "Tensors are not equal")

    def add_random_to_cache(
        self, cache: AlwaysTrimmableKVCache, seqlen: int
    ) -> mx.array:
        """Add random values to the cache and return them"""
        base_kv = self.make_random_kv(seqlen)
        # base_kv is *assigned* to cache.keys/cache.values, so returning base_kv
        # would return a reference to cache.keys, which is pointless. So copy it
        reference = deepcopy(base_kv)
        cache.update_and_fetch(base_kv, base_kv)
        return reference
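A test along these lines (hypothetical, not part of this diff) could use the helpers above to check that the cache stays trimmable after max_size has been exceeded, which is exactly what the stock RotatingKVCache does not allow:

    def test_trim_after_exceeding_max_size(self):
        """Hypothetical example: trimming still works past max_size"""
        cache = AlwaysTrimmableKVCache(max_size=8, keep=2)
        self.add_random_to_cache(cache, 12)  # more tokens than max_size
        self.assertTrue(cache.is_trimmable())
        # trim() returns min(offset, n), so all 5 requested tokens are removed
        self.assertEqual(cache.trim(5), 5)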