Initial cache fixes #192
Status: Closed. christian-lms wants to merge 9 commits into lmstudio-ai:main from christian-lms:christian/first_cache_fixes.
Changes shown are from 3 of the 9 commits.
Commits (all by christian-lms):
- 6809f05 coef
- 5e7834d tests
- 27fae1c please the linter
- cce5927 test that trim stops nuking
- 0a71d2a address comments
- 71f1d53 fix draft models
- 559d054 permalink to rotation
- 395cafb gate logic
- d8240d7 Update mlx_engine/cache.py
mlx_engine/cache.py (new file, +61 lines):

```python
from typing import List, Optional, Any

from mlx_lm.models.cache import RotatingKVCache, KVCache
import mlx.nn as nn


class ShiftingKVCache(RotatingKVCache):
    def trim(self, n) -> int:
        # trim must not respect keep
        n = min(self.offset, n)
        if n <= 0:
            return 0

        # put us back into the state before the circular buffer is full
        self.keys = self._temporal_order(self.keys)
        self.values = self._temporal_order(self.values)

        new_length = max(self.keys.shape[2] - n, 0)
        self.keys = self.keys[..., :new_length, :]
        self.values = self.values[..., :new_length, :]

        self.offset = new_length
        self._idx = new_length
        return n

    def set_keep(self, keep):
        # kv must be in temporal order, else we will keep the wrong thing
        if self.keys is not None:
            self.keys = self._temporal_order(self.keys)
        if self.values is not None:
            self.values = self._temporal_order(self.values)
        self.keep = keep

    def is_trimmable(self) -> bool:
        return True


def make_prompt_cache(
    model: nn.Module,
    max_kv_size: Optional[int] = None,
    keep: int = 4,
) -> List[Any]:
    """
    Construct the model's cache for use in generation.

    This function will defer the cache construction to the model if it has a
    ``make_cache`` method, otherwise it will make a default KV cache.

    Args:
        model (nn.Module): The language model.
        max_kv_size (Optional[int]): If provided and the model does not have a
            ``make_cache`` method, a ``ShiftingKVCache`` is used with a maximum
            size of ``max_kv_size``.
    """
    if hasattr(model, "make_cache"):
        return model.make_cache()
    num_layers = len(model.layers)
    if max_kv_size is not None:
        return [
            ShiftingKVCache(max_size=max_kv_size, keep=keep) for _ in range(num_layers)
        ]
    else:
        return [KVCache() for _ in range(num_layers)]
```
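The key idea in `trim` above is that a rotating cache stores entries in rotated (circular-buffer) order, so the cache must first be restored to temporal order before the most recent `n` entries can be sliced off. The following is an illustrative sketch only: a plain-Python stand-in for the rotating buffer, not the MLX implementation. The class name and internals (`ToyRotatingBuffer`, the list-based storage) are invented for this example; only the `_temporal_order`/`trim`/`offset`/`_idx` naming mirrors the PR.

```python
class ToyRotatingBuffer:
    """Toy list-based analogue of a rotating KV buffer (not the MLX code)."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.data = []   # entries, possibly in rotated order once full
        self._idx = 0    # next write position
        self.offset = 0  # total entries seen

    def append(self, token):
        if len(self.data) < self.max_size:
            self.data.append(token)
            self._idx = len(self.data) % self.max_size
        else:
            # buffer is full: overwrite the oldest slot
            self.data[self._idx] = token
            self._idx = (self._idx + 1) % self.max_size
        self.offset += 1

    def _temporal_order(self):
        # rotate so the oldest entry comes first (the "unrotated" view)
        if len(self.data) == self.max_size:
            return self.data[self._idx:] + self.data[:self._idx]
        return list(self.data)

    def trim(self, n):
        # drop the n most recent entries, ignoring any `keep` region
        n = min(self.offset, n)
        if n <= 0:
            return 0
        ordered = self._temporal_order()
        new_length = max(len(ordered) - n, 0)
        self.data = ordered[:new_length]
        self.offset = new_length
        self._idx = new_length
        return n


buf = ToyRotatingBuffer(max_size=4)
for t in [1, 2, 3, 4, 5, 6]:  # 5 and 6 overwrite slots 0 and 1
    buf.append(t)
buf.trim(2)                    # drop the two most recent entries (5, 6)
print(buf.data)                # [3, 4]
```

Trimming the rotated storage directly (`[5, 6, 3, 4]`) would cut `3` and `4`, the oldest entries, instead of the newest; restoring temporal order first is what makes the slice drop the right end.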
Tests (new file, +36 lines):

```python
import unittest
import mlx.core as mx
from copy import deepcopy
from mlx_engine.cache import ShiftingKVCache


class TestCache(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        """Set up test resources that will be shared across all test methods"""
        cls.kv_head_dim = 4
        cls.bsz = 1
        cls.n_kv_heads = 1

    @classmethod
    def make_random_kv(cls, seqlen: int):
        """Helper method to make a random key/value tensor of the right shape"""
        return mx.random.normal(
            (cls.bsz, cls.n_kv_heads, seqlen, cls.kv_head_dim),
            scale=1.0,
            dtype=mx.float32,
        )

    def assertArrEqual(self, a: mx.array, b: mx.array):
        """Assert that two tensors are equal over the sequence length dimension"""
        self.assertEqual(a.shape, b.shape)
        self.assertTrue(mx.allclose(a, b), "Tensors are not equal")

    def add_random_to_cache(self, cache: ShiftingKVCache, seqlen: int) -> mx.array:
        """Add random values to the cache and return them"""
        base_kv = self.make_random_kv(seqlen)
        # base_kv is *assigned* to cache.keys/cache.values, so returning base_kv
        # would return a reference to cache.keys, which is pointless; copy it
        reference = deepcopy(base_kv)
        cache.update_and_fetch(base_kv, base_kv)
        return reference
```
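The `deepcopy` in `add_random_to_cache` guards against aliasing: if the cache stores the very array it was given, later mutations of the cache's storage would silently change the "expected" tensor the test compares against. A minimal stand-in illustrating the pitfall, using a hypothetical `ToyCache` and plain lists rather than MLX arrays:

```python
from copy import deepcopy


class ToyCache:
    """Hypothetical stand-in for a KV cache that stores what it is given."""

    def __init__(self):
        self.keys = None

    def update_and_fetch(self, keys, values):
        self.keys = keys  # stores a reference, not a copy
        return self.keys


cache = ToyCache()
base_kv = [1.0, 2.0, 3.0]
reference = deepcopy(base_kv)  # snapshot taken before handing it to the cache
cache.update_and_fetch(base_kv, base_kv)

cache.keys[0] = 99.0  # the cache later mutates its storage in place
print(base_kv)        # [99.0, 2.0, 3.0] -- aliased with cache.keys
print(reference)      # [1.0, 2.0, 3.0] -- safe to compare against
```

Returning `reference` instead of `base_kv` is what keeps assertions like `assertArrEqual` meaningful after the cache rotates or trims its storage.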