diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py
new file mode 100644
index 0000000..f22bc61
--- /dev/null
+++ b/mlx_engine/cache.py
@@ -0,0 +1,81 @@
+from typing import List, Optional, Any
+
+from mlx_lm.models.cache import RotatingKVCache, KVCache
+import mlx.nn as nn
+
+
+class AlwaysTrimmableKVCache(RotatingKVCache):
+    """A KV cache that can always be trimmed.
+
+    The MLX-LM implementation of the RotatingKVCache does not allow trimming
+    the cache once the maximum KV size has been exceeded, which results in
+    the cache being nuked every time this happens. This forces the entire context
+    to be reprocessed regularly, which is not ideal for performance. This KV cache
+    allows trimming the cache at any time, which circumvents this issue.
+    See https://github.com/lmstudio-ai/mlx-engine/issues/177 for more details.
+    """
+
+    def trim(self, n) -> int:
+        # trim must not respect keep: we always receive some value for keep, but
+        # when initially processing the prompt, it may be that the common prefix
+        # is shorter than keep. in that case we must trim to the common prefix length,
+        # which violates keep. keep is only used for the cache rotation when exceeding
+        # the context length mid-generation to ensure we don't lose the common prefix.
+        n = min(self.offset, n)
+        if n <= 0:
+            return 0
+
+        # put us back into the state before the circular buffer is full
+        self.keys = self._temporal_order(self.keys)
+        self.values = self._temporal_order(self.values)
+
+        new_length = max(self.keys.shape[2] - n, 0)
+        self.keys = self.keys[..., :new_length, :]
+        self.values = self.values[..., :new_length, :]
+
+        self.offset = new_length
+        self._idx = new_length
+        return n
+
+    def set_keep(self, keep):
+        # kv must be in temporal order, else we will keep the wrong thing
+        if self.keys is not None:
+            self.keys = self._temporal_order(self.keys)
+        if self.values is not None:
+            self.values = self._temporal_order(self.values)
+        self.keep = keep
+
+    def is_trimmable(self) -> bool:
+        return True
+
+
+def make_prompt_cache(
+    model: nn.Module,
+    max_kv_size: Optional[int] = None,
+    keep: int = 4,
+) -> List[Any]:
+    """
+    Construct the model's cache for use in generation.
+    This function will defer the cache construction to the model if it has a
+    ``make_cache`` method, otherwise it will make a default KV cache.
+
+    See https://github.com/ml-explore/mlx-lm/blob/fd9b1909636d634ac2b848248b05939c9fbfbe19/mlx_lm/models/cache.py#L10
+    for the MLX-LM implementation. This is a temporary extension to support more flexible
+    trimming than MLX-LM's original RotatingKVCache.
+
+    Args:
+        model (nn.Module): The language model.
+        max_kv_size (Optional[int]): If provided and the model does not have a
+            ``make_cache`` method, an ``AlwaysTrimmableKVCache`` is used with a maximum
+            size of ``max_kv_size``.
+    """
+    if hasattr(model, "make_cache"):
+        return model.make_cache()
+    num_layers = len(model.layers)
+    if max_kv_size is not None:
+        return [
+            AlwaysTrimmableKVCache(max_size=max_kv_size, keep=keep)
+            for _ in range(num_layers)
+        ]
+    else:
+        return [KVCache() for _ in range(num_layers)]
diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py
index 36498fa..3fb13de 100644
--- a/mlx_engine/cache_wrapper.py
+++ b/mlx_engine/cache_wrapper.py
@@ -1,10 +1,11 @@
 from typing import List, Optional, Any
 
 from mlx_engine.logging import log_info, log_warn, log_error
+from mlx_engine.cache import make_prompt_cache
 from mlx_lm.models.cache import (
-    make_prompt_cache,
     trim_prompt_cache,
     can_trim_prompt_cache,
+    RotatingKVCache,
 )
 from mlx_lm.generate import generation_stream, maybe_quantize_kv_cache
 import mlx.core as mx
@@ -26,6 +27,7 @@ def __init__(
         kv_bits: Optional[int] = None,
         kv_group_size: Optional[int] = None,
         quantized_kv_start: Optional[int] = None,
+        keep: int = 4,
     ):
         """
         Initialize the CacheWrapper.
@@ -36,7 +38,9 @@ def __init__(
         """
         # utilize a simple ordered list of tokens processed so far for cache invalidation checking
         self.tokens: Optional[mx.array] = None
-        self.cache: List[Any] = make_prompt_cache(model, max_kv_size)
+        self.keep = keep
+        self.cache: List[Any] = make_prompt_cache(model, max_kv_size, keep)
+        self.is_rotating = all(isinstance(c, RotatingKVCache) for c in self.cache)
         self.model = model
         self.draft_model: Optional[nn.Module] = None
         self.max_kv_size = max_kv_size
@@ -115,7 +119,11 @@ def _get_unprocessed_tokens(
                 message=f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: "
                 "Cache is not trimmable. Clearing the cache instead.",
             )
-            self.cache = make_prompt_cache(self.model, self.max_kv_size)
+            self.cache = make_prompt_cache(
+                self.model,
+                max_kv_size=self.max_kv_size if self.is_rotating else None,
+                keep=self.keep,
+            )
             self.tokens = prompt_tokens
             return self.tokens
         tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim)
@@ -126,7 +134,11 @@ def _get_unprocessed_tokens(
                 message=f"Tokens trimmed from cache ({tokens_trimmed}) is less than expected "
                 " ({num_tokens_to_trim}). Clearing the cache.",
             )
-            self.cache = make_prompt_cache(self.model, self.max_kv_size)
+            self.cache = make_prompt_cache(
+                self.model,
+                max_kv_size=self.max_kv_size if self.is_rotating else None,
+                keep=self.keep,
+            )
             self.tokens = prompt_tokens
             return self.tokens
         log_info(
@@ -221,9 +233,11 @@ def set_draft_model(self, draft_model: nn.Module):
             message="Clearing current prompt cache and adding draft model to the cache",
         )
         self.tokens = None
-        self.cache: List[Any] = make_prompt_cache(self.model)
+        self.cache: List[Any] = make_prompt_cache(self.model, keep=self.keep)
+        # the above will never return a rotating cache since there is no max_kv_size set
+        self.is_rotating = False
         if draft_model is not None:
-            self.cache += make_prompt_cache(draft_model)
+            self.cache += make_prompt_cache(draft_model, keep=self.keep)
         self.draft_model = draft_model
 
     def unset_draft_model(self):
@@ -239,6 +253,7 @@ def update_cache(
         prompt_progress_callback,
         *,
         num_tokens_to_exclude: int = 1,
+        keep: int = 4,
     ) -> mx.array:
         """
         Set up the KV cache for the next generation.
 
         Args:
             prompt_tokens (mx.array): The prompt tokens.
             prompt_progress_callback (Callable): A callback function to report prompt processing progress.
             num_tokens_to_exclude (int): The number of tokens that should not be added to the cache.
+            keep (int): The number of tokens to always keep in the prefix of the prompt cache.
 
         Returns:
             mx.array: The prompt tokens to be used for the next generation.
         """
@@ -257,6 +273,12 @@ def update_cache(
             def prompt_progress_callback(x):
                 return None
 
+        # update keep tracking
+        self.keep = keep
+        for cache in self.cache:
+            if hasattr(cache, "set_keep"):
+                cache.set_keep(keep)
+
         num_tokens_to_exclude = max(num_tokens_to_exclude, 1)
         prompt_tokens = self._get_unprocessed_tokens(
             prompt_tokens, num_tokens_to_exclude
@@ -296,5 +318,20 @@ def prompt_progress_callback(x):
     def record_generated_token(self, token):
         """
         Add the generated token to the token list, so that we can map the token to the KV cache.
+
+        Also loop when the cache does so that we accurately track what's in cache, if we're using
+        a RotatingKVCache or subclass of such. See the rotation implemented by MLX-LM here:
+        https://github.com/ml-explore/mlx-lm/blob/fd9b1909636d634ac2b848248b05939c9fbfbe19/mlx_lm/models/cache.py#L371
         """
+        # this behavior is common to rolling window (n_keep = 0) and truncate middle
+        # (n_keep > 0), and we should never get here with stop at max
+        if (
+            self.max_kv_size is not None
+            and self.is_rotating
+            and len(self.tokens) >= self.max_kv_size
+        ):
+            # rotate the token tracking buffer
+            self.tokens = mx.concat(
+                [self.tokens[: self.keep], self.tokens[self.keep + 1 :]]
+            )
         self.tokens = mx.concat([self.tokens, mx.array([token])])
diff --git a/mlx_engine/generate.py b/mlx_engine/generate.py
index 6019cf1..1a80eec 100644
--- a/mlx_engine/generate.py
+++ b/mlx_engine/generate.py
@@ -137,6 +137,7 @@ def create_generator(
     max_tokens: Optional[int] = 10000000,
     speculative_decoding_toggle: Optional[bool] = None,
     num_draft_tokens: Optional[int] = None,
+    keep: int = 4,
 ) -> Iterator[GenerationResult]:
     """
     Create a generator that streams text generation results from the model.
@@ -218,6 +219,7 @@ def create_generator(
         prompt_progress_callback,
         generate_args,
         speculative_decoding_toggle,
+        keep=keep,
     )
     if draft_model is None:
         # input embeddings not yet supported for speculative decoding in mlx-lm
diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py
index a389dc6..4f93437 100644
--- a/mlx_engine/model_kit/model_kit.py
+++ b/mlx_engine/model_kit/model_kit.py
@@ -141,6 +141,7 @@ def process_prompt(
         prompt_progress_callback,
         generate_args,
         speculative_decoding_toggle: Optional[bool] = None,
+        keep: int = 4,
     ) -> Tuple[mx.array, Optional[mx.array]]:
         ### TEXT-ONLY PROCESS_PROMPT ###
         is_text_only_processing = images_b64 is None or len(images_b64) == 0
@@ -160,6 +161,7 @@ def process_prompt(
                 self.draft_model,
                 speculative_decoding_toggle,
                 prompt_progress_callback,
+                keep=keep,
             ), None
         ### WITH IMAGES PROMPT PROCESSING ###s
         if self.vision_add_on is None:
diff --git a/mlx_engine/utils/prompt_processing.py b/mlx_engine/utils/prompt_processing.py
index 78a687d..380cf24 100644
--- a/mlx_engine/utils/prompt_processing.py
+++ b/mlx_engine/utils/prompt_processing.py
@@ -13,6 +13,7 @@ def process_prompt_text_only(
     draft_model: Optional[nn.Module] = None,
     speculative_decoding_toggle: Optional[bool] = None,
     prompt_progress_callback: Optional[Callable[[float], None]] = None,
+    keep: int = 4,
 ):
     if cache_wrapper is None:
         raise ValueError("Cache wrapper is not initialized, cannot process prompt")
@@ -38,6 +39,7 @@ def process_prompt_text_only(
     prompt_tokens = cache_wrapper.update_cache(
         prompt_tokens,
         prompt_progress_callback,
+        keep=keep,
     )
     generate_args["prompt_cache"] = cache_wrapper.cache
     return prompt_tokens
diff --git a/tests/test_cache_generic.py b/tests/test_cache_generic.py
new file mode 100644
index 0000000..bb5c905
--- /dev/null
+++ b/tests/test_cache_generic.py
@@ -0,0 +1,38 @@
+import unittest
+import mlx.core as mx
+from copy import deepcopy
+from mlx_engine.cache import AlwaysTrimmableKVCache
+
+
+class TestCache(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        """Set up test resources that will be shared across all test methods"""
+        cls.kv_head_dim = 4
+        cls.bsz = 1
+        cls.n_kv_heads = 1
+
+    @classmethod
+    def make_random_kv(cls, seqlen: int):
+        """Helper method to make a random key/value tensor of the right shape"""
+        return mx.random.normal(
+            (cls.bsz, cls.n_kv_heads, seqlen, cls.kv_head_dim),
+            scale=1.0,
+            dtype=mx.float32,
+        )
+
+    def assertArrEqual(self, a: mx.array, b: mx.array):
+        """Assert that two tensors are equal over the sequence length dimension"""
+        self.assertEqual(a.shape, b.shape)
+        self.assertTrue(mx.allclose(a, b), "Tensors are not equal")
+
+    def add_random_to_cache(
+        self, cache: AlwaysTrimmableKVCache, seqlen: int
+    ) -> mx.array:
+        """Add random values to the cache and return them"""
+        base_kv = self.make_random_kv(seqlen)
+        # base_kv is *assigned* to cache.keys/cache.values so returning base_kv
+        # would return a reference to cache.keys, which is pointless. so copy it
+        reference = deepcopy(base_kv)
+        cache.update_and_fetch(base_kv, base_kv)
+        return reference
diff --git a/tests/test_cache_trim.py b/tests/test_cache_trim.py
new file mode 100644
index 0000000..ae0301f
--- /dev/null
+++ b/tests/test_cache_trim.py
@@ -0,0 +1,180 @@
+import unittest
+import mlx.core as mx
+from mlx_engine.cache import AlwaysTrimmableKVCache
+from tests.test_cache_generic import TestCache
+
+
+def idx(v: mx.array, i: int):
+    """Helper function to index into a 4D tensor at the sequence length dimension"""
+    return v[:, :, i : i + 1, :]
+
+
+class TestAlwaysTrimmableKVCache(TestCache):
+    def test_overwriting(self):
+        """Test overwriting when the cache reaches max_size"""
+        cache = AlwaysTrimmableKVCache(max_size=3, keep=1)
+
+        # fill cache -> 123
+        reference = self.add_random_to_cache(cache, 3)
+        self.assertEqual(cache.offset, 3)
+
+        # attempt to write another element 4 -> 143
+        overwrite = self.add_random_to_cache(cache, 1)
+        # access k/v as cache.state[0]/[1] due to possibly empty buffer
+        keys = cache.state[0]
+
+        self.assertArrEqual(idx(keys, 0), idx(reference, 0))
+        self.assertArrEqual(idx(keys, 1), overwrite)
+        self.assertArrEqual(idx(keys, 2), idx(reference, 2))
+        self.assertEqual(cache.offset, 4)
+
+    def test_ensure_update_increases_offset_indefinitely(self):
+        """Test single-token updates that should increase offset"""
+        cache = AlwaysTrimmableKVCache(max_size=3, keep=1)
+
+        for i in range(10):
+            self.add_random_to_cache(cache, 1)
+            self.assertEqual(cache.offset - 1, i)
+
+    def test_ensure_reasonable_size_and_shift(self):
+        """Test behavior when the cache gets a KV batch-written that is much larger
+        than max_size. The default behavior of the cache is to write the entire thing,
+        then trim it back down when the next KV is written.
+ """ + cache = AlwaysTrimmableKVCache(max_size=3, keep=1) + + # fill cache -> 0123456789 + reference = self.add_random_to_cache(cache, 10) + keys = cache.state[0] + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 10, self.kv_head_dim)) + self.assertEqual(cache.offset, 10) + + # trigger trim -> 0X9 -> (rope) 021 + overwrite = self.add_random_to_cache(cache, 1) + keys = cache.state[0] + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 3, self.kv_head_dim)) + self.assertEqual(cache.offset, 11) + + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), overwrite) + self.assertArrEqual(idx(keys, 2), idx(reference, 9)) + + # make sure pos embs are right + cache.keys = cache._temporal_order(cache.keys) + cache.values = cache._temporal_order(cache.values) + keys = cache.state[0] + + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), idx(reference, 9)) + self.assertArrEqual(idx(keys, 2), overwrite) + self.assertEqual(cache.offset, 11) + + # ensure offset keeps increasing + self.add_random_to_cache(cache, 1) + self.assertEqual(cache.offset, 12) + + self.add_random_to_cache(cache, 1) + self.assertEqual(cache.offset, 13) + + def test_update_keep_on_the_fly(self): + """Test changing the keep value on the fly""" + cache = AlwaysTrimmableKVCache(max_size=4, keep=1) + + # fill cache -> 1234 + reference = self.add_random_to_cache(cache, 4) + + # attempt to write another element 5 -> 1534 + overwrite = self.add_random_to_cache(cache, 1) + self.assertEqual(cache.offset, 5) + + # update keep -> 1345 -> 1234 implicitly + # and attempt to write another element 5 -> 1254 + # offset updates after set_keep (anytime we reorder/rope shift) + cache.set_keep(2) + self.assertEqual(cache.offset, 5) + overwrite2 = self.add_random_to_cache(cache, 1) + self.assertEqual(cache.offset, 6) + keys = cache.state[0] + + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), idx(reference, 2)) + self.assertArrEqual(idx(keys, 2), overwrite2) + self.assertArrEqual(idx(keys, 3), overwrite) + + def test_trim_before_full(self): + """Test trimming from the end before the cache is full""" + cache = AlwaysTrimmableKVCache(max_size=3, keep=1) + + # fill cache -> 12 + reference = self.add_random_to_cache(cache, 2) + + # trim 1 from end -> 1 + cache.trim(1) + keys = cache.state[0] + + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 1, self.kv_head_dim)) + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertEqual(cache.offset, 1) + + # ensure adding another value works fine + new_kv = self.add_random_to_cache(cache, 1) + keys = cache.state[0] + self.assertEqual(cache.offset, 2) + + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), new_kv) + self.assertEqual(cache.offset, 2) + + def test_trim_after_overwrite(self): + """Test trimming from the end when we've written past the cache max""" + cache = AlwaysTrimmableKVCache(max_size=3, keep=1) + + # fill cache -> 123 + reference = self.add_random_to_cache(cache, 3) + self.assertEqual(cache.offset, 3) + + # overwrite so offset goes over max_size -> 143 + base_kv = self.make_random_kv(1) + cache.update_and_fetch(base_kv, base_kv) + self.assertEqual(cache.offset, 4) + + # trim 1 from end -> 13 -> 12 (rope), ideally + cache.trim(1) + keys = cache.state[0] + + should_be_kv = mx.concatenate( + [reference[:, :, :1, :], reference[:, :, 2:3, 
+        self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim))
+        self.assertArrEqual(keys, should_be_kv)
+        self.assertEqual(cache.offset, 2)
+
+    def test_trim_after_full(self):
+        """Test trimming from the end when the cache is oversize"""
+        cache = AlwaysTrimmableKVCache(max_size=3, keep=1)
+
+        # fill cache oversize already -> 1234
+        reference = self.add_random_to_cache(cache, 4)
+        self.assertEqual(cache.offset, 4)
+
+        # trim 2 from end -> 12
+        cache.trim(2)
+        keys = cache.state[0]
+
+        self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim))
+        self.assertArrEqual(keys, reference[:, :, :2, :])
+        self.assertEqual(cache.offset, 2)
+
+        # ensure adding more values works fine
+        new_kv = self.add_random_to_cache(cache, 2)
+        keys = cache.state[0]
+        self.assertEqual(cache.offset, 4)
+
+        self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 4, self.kv_head_dim))
+        self.assertArrEqual(keys[:, :, :2, :], reference[:, :, :2, :])
+        self.assertArrEqual(keys[:, :, 2:, :], new_kv)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2, failfast=True)
diff --git a/tests/test_text_models.py b/tests/test_text_models.py
index 67dba02..1f4dd94 100644
--- a/tests/test_text_models.py
+++ b/tests/test_text_models.py
@@ -208,6 +208,79 @@ def generate(text_accumulator: list) -> None:
         )
         self.assertEqual(generated_text_1, generated_text_2)
 
+    def test_cache_nuke_qwen2_5(self):
+        model_path = model_getter("lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit")
+        model_kit = load_model(model_path=model_path, max_kv_size=32)
+        prompt = """<|im_start|>user
+Explain how the universe works. What was the Big Bang? What's redshifting?
+<|im_end|>
+<|im_start|>assistant
+"""
+        prompt_tokens = tokenize(model_kit, prompt)
+        log_info(
+            prefix="test_cache_nuke",
+            message=f"Generation 1 number of prompt tokens: {len(prompt_tokens)}",
+        )
+        generated_text_list_1 = []
+        prompt_progress_callback_times_called = 0
+
+        def prompt_progress_callback(progress: float) -> None:
+            nonlocal prompt_progress_callback_times_called
+            prompt_progress_callback_times_called += 1
+            print(f"Prompt Progress: {progress:.2f}")
+
+        # accumulating to list allows pass by reference
+        def generate(text_accumulator: list) -> None:
+            for result in create_generator(
+                model_kit=model_kit,
+                prompt_tokens=prompt_tokens,
+                seed=0,
+                max_tokens=100,
+                temp=0.0,
+                prompt_progress_callback=prompt_progress_callback,
+            ):
+                print(result.text, end="", flush=True)
+                text_accumulator.append(result.text)
+                if result.stop_condition:
+                    break
+            print("\n", flush=True)
+
+        ### Generation 1 - fills cache
+        generate(text_accumulator=generated_text_list_1)
+        generated_text_1 = "".join(generated_text_list_1)
+        self.assertEqual(prompt_progress_callback_times_called, 2)
+        self.assertGreater(
+            len(generated_text_1), 0, "Model failed to generate any text"
+        )
+        gen1_cache_layer0 = model_kit.cache_wrapper.cache[0]
+
+        ### Generation 2 - trims cache
+        prompt = """<|im_start|>user
+Explain how the universe works. What was the Big Bang?
+<|im_end|>
+<|im_start|>assistant
+"""
+        prompt_tokens = tokenize(model_kit, prompt)
+        log_info(
+            prefix="test_cache_nuke",
+            message=f"Generation 2 number of prompt tokens: {len(prompt_tokens)}",
+        )
+        generated_text_list_2 = []
+        prompt_progress_callback_times_called = 0
+        generate(text_accumulator=generated_text_list_2)
+        generated_text_2 = "".join(generated_text_list_2)
+        # Expect the prompt cache to be intact for the first half of the prompt, so we should get 1
+        # intermediate callback this time
+        self.assertEqual(prompt_progress_callback_times_called, 2)
+        self.assertGreater(
+            len(generated_text_2), 0, "Model failed to generate any text"
+        )
+        gen2_cache_layer0 = model_kit.cache_wrapper.cache[0]
+
+        # if we nuked cache, these will reference different locations in memory
+        # if we didn't, they'll refer to the same object
+        self.assertTrue(gen1_cache_layer0 is gen2_cache_layer0)
+
 
 class TestStructuredGen(unittest.TestCase):
     def setUp(self):
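
A minimal standalone sketch of the trimming behavior the new tests above exercise; it is not part of the patch itself and assumes mlx plus the patched mlx_engine.cache module are importable. The stock RotatingKVCache does not allow trimming once max_size has been exceeded (see the AlwaysTrimmableKVCache docstring), which is why mlx-engine previously had to clear the whole cache; the subclass keeps trim() usable, mirroring test_trim_after_full.

import mlx.core as mx
from mlx_engine.cache import AlwaysTrimmableKVCache

# shapes follow the tests: (batch, n_kv_heads, seq_len, head_dim)
cache = AlwaysTrimmableKVCache(max_size=3, keep=1)
kv = mx.random.normal((1, 1, 4, 4))

# batch-write 4 positions into a max_size=3 cache; offset tracks all 4
cache.update_and_fetch(kv, kv)
print(cache.offset)  # 4

# past max_size the stock class reports the cache as non-trimmable and callers
# rebuild it from scratch; the subclass still trims, so only the mismatched
# suffix of a new prompt has to be reprocessed
trimmed = cache.trim(2)
print(trimmed, cache.offset)  # 2 2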
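
A hypothetical call site showing how the keep argument added above flows from create_generator through process_prompt and process_prompt_text_only into CacheWrapper.update_cache, which forwards it to set_keep() on each layer's cache. The import location, model path, and token counts are assumptions for illustration only; the keep keyword itself is what the patch introduces.

from mlx_engine import load_model, create_generator, tokenize

# a max_kv_size makes make_prompt_cache build an AlwaysTrimmableKVCache per layer
model_kit = load_model(model_path="path/to/model", max_kv_size=4096)
prompt_tokens = tokenize(model_kit, "<|im_start|>user\nhello\n<|im_end|>\n<|im_start|>assistant\n")

# keep=8 pins the first 8 prompt tokens across cache rotation; update_cache()
# calls set_keep(8) on every layer cache before processing the prompt
for result in create_generator(
    model_kit=model_kit,
    prompt_tokens=prompt_tokens,
    max_tokens=64,
    keep=8,
):
    print(result.text, end="", flush=True)
    if result.stop_condition:
        break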