diff --git a/mlx_engine/cache.py b/mlx_engine/cache.py
new file mode 100644
index 0000000..9f7088a
--- /dev/null
+++ b/mlx_engine/cache.py
@@ -0,0 +1,107 @@
+from typing import List, Optional, Any
+
+from mlx_lm.models.cache import RotatingKVCache, KVCache
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class ShiftingKVCache(RotatingKVCache):
+    def __init__(self, max_size=256, keep=0, step=256):
+        self.reuse_queue = []
+        super().__init__(max_size=max_size, keep=keep, step=step)
+
+    def reuse_section(
+        self, write_start_idx: int, reuse_start_idx: int, reuse_length: int
+    ) -> None:
+        # queue for reuse: everything is done in one pass at the end in do_reuse
+        self.reuse_queue.append((write_start_idx, reuse_start_idx, reuse_length))
+
+    def do_reuse(self) -> None:
+        if not self.reuse_queue:
+            return
+
+        # restore temporal order first, in case the rotating buffer has wrapped
+        self.keys = self._temporal_order(self.keys)
+        self.values = self._temporal_order(self.values)
+
+        # sort by write position so segments are assembled left to right
+        self.reuse_queue.sort(key=lambda x: x[0])
+
+        key_segments = []
+        value_segments = []
+        current_pos = 0
+
+        for write_start_idx, reuse_start_idx, reuse_length in self.reuse_queue:
+            # add any gap before this write position
+            if current_pos < write_start_idx:
+                key_segments.append(self.keys[..., current_pos:write_start_idx, :])
+                value_segments.append(self.values[..., current_pos:write_start_idx, :])
+
+            reuse_end_idx = reuse_start_idx + reuse_length
+            current_pos = write_start_idx + reuse_length
+
+            key_segments.append(self.keys[..., reuse_start_idx:reuse_end_idx, :])
+            value_segments.append(self.values[..., reuse_start_idx:reuse_end_idx, :])
+
+        self.keys = mx.concatenate(key_segments, axis=2)
+        self.values = mx.concatenate(value_segments, axis=2)
+
+        # clean up
+        self.reuse_queue = []
+        self._idx = self.keys.shape[2]
+        self.offset = self.keys.shape[2]
+
+    def trim(self, n) -> int:
+        # trim must not respect keep
+        n = min(self.offset, n)
+        if n <= 0:
+            return 0
+
+        # put us back into the state before the circular buffer is full
+        self.keys = self._temporal_order(self.keys)
+        self.values = self._temporal_order(self.values)
+
+        new_length = max(self.keys.shape[2] - n, 0)
+        self.keys = self.keys[..., :new_length, :]
+        self.values = self.values[..., :new_length, :]
+
+        self.offset = new_length
+        self._idx = new_length
+        return n
+
+    def set_keep(self, keep):
+        # kv must be in temporal order, else we will keep the wrong thing
+        if self.keys is not None:
+            self.keys = self._temporal_order(self.keys)
+        if self.values is not None:
+            self.values = self._temporal_order(self.values)
+        self.keep = keep
+
+    def is_trimmable(self) -> bool:
+        return True
+
+
+def make_prompt_cache(
+    model: nn.Module,
+    max_kv_size: Optional[int] = None,
+    keep: int = 4,
+) -> List[Any]:
+    """
+    Construct the model's cache for use in generation.
+    This function will defer the cache construction to the model if it has a
+    ``make_cache`` method, otherwise it will make a default KV cache.
+    Args:
+        model (nn.Module): The language model.
+ max_kv_size (Optional[int]): If provided and the model does not have a + ``make_cache`` method, a ``ShiftingKVCache`` is used with a maximum + size of ``max_kv_size`` + """ + if hasattr(model, "make_cache"): + return model.make_cache() + num_layers = len(model.layers) + if max_kv_size is not None: + return [ + ShiftingKVCache(max_size=max_kv_size, keep=keep) for _ in range(num_layers) + ] + else: + return [KVCache() for _ in range(num_layers)] diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 36498fa..7cbd5c3 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -1,11 +1,8 @@ from typing import List, Optional, Any from mlx_engine.logging import log_info, log_warn, log_error -from mlx_lm.models.cache import ( - make_prompt_cache, - trim_prompt_cache, - can_trim_prompt_cache, -) +from mlx_engine.cache import make_prompt_cache +from mlx_lm.models.cache import trim_prompt_cache, can_trim_prompt_cache from mlx_lm.generate import generation_stream, maybe_quantize_kv_cache import mlx.core as mx import mlx.nn as nn @@ -26,6 +23,7 @@ def __init__( kv_bits: Optional[int] = None, kv_group_size: Optional[int] = None, quantized_kv_start: Optional[int] = None, + keep: int = 4, ): """ Initialize the CacheWrapper. @@ -36,7 +34,8 @@ def __init__( """ # utilize a simple ordered list of tokens processed so far for cache invalidation checking self.tokens: Optional[mx.array] = None - self.cache: List[Any] = make_prompt_cache(model, max_kv_size) + self.keep = keep + self.cache: List[Any] = make_prompt_cache(model, max_kv_size, keep) self.model = model self.draft_model: Optional[nn.Module] = None self.max_kv_size = max_kv_size @@ -48,41 +47,87 @@ def __init__( ) @staticmethod - def _find_common_prefix( - current_tokens: mx.array, prompt_tokens: mx.array, num_tokens_to_exclude: int + def _find_matching_sequence_length( + tokens1: mx.array, + tokens2: mx.array, + start1: int = 0, + start2: int = 0, ) -> int: """ - Determine the common prefix length between the current tokens and the prompt tokens. + Find the length of matching token sequence between two token arrays. Args: - current_tokens (mx.array): The cached tokens (self.tokens). - prompt_tokens (mx.array): The prompt tokens. - num_tokens_to_exclude (int): The minimum length of the remaining prompt tokens array. + tokens1: First token array + start1: Starting position in first array + tokens2: Second token array + start2: Starting position in second array Returns: - int: The length of the common prefix. 
+ int: Length of matching sequence """ - prompt_tokens = prompt_tokens - current_tokens = current_tokens - # Find the minimum length between the two arrays - min_length = min(len(current_tokens), len(prompt_tokens)) - - # Compare elements up to the minimum length - mask = prompt_tokens[:min_length] == current_tokens[:min_length] - - # Find the index where the first mismatch occurs - if mx.any(mask == False): # noqa E712 - common_length = int(mx.argmax(mask == False)) # noqa E712 - else: - common_length = int(min_length) - - # Ensure that the prompt is at least num_tokens_to_exclude long - uncached_prompt_tokens_length = len(prompt_tokens[common_length:]) - length_adjustment = max( - 0, num_tokens_to_exclude - uncached_prompt_tokens_length - ) - common_length = max(common_length - length_adjustment, 0) - return common_length + # Calculate actual bounds + max_len1 = len(tokens1) - start1 + max_len2 = len(tokens2) - start2 + min_length = int(min(max_len1, max_len2)) + + # Extract subsequences to compare + seq1 = tokens1[start1 : start1 + min_length] + seq2 = tokens2[start2 : start2 + min_length] + + # Find first mismatch + mask = seq1 == seq2 + return int(mx.argmax(mask == False)) if mx.any(mask == False) else min_length # noqa E712 + + def _truncate_cache( + self, + prompt_tokens: mx.array, + common_prefix_len: int, + non_prefix_reuse_min_seq_len: int = 256, + ) -> int: + cache_size = len(self.tokens) + prompt_size = len(prompt_tokens) + + # start scanning from after the common prefix + cache_head_idx = common_prefix_len + prompt_head_idx = common_prefix_len + total_reused = 0 + + if self.verbose: + print( + f"Looking for non-prefix sequences of length >= {non_prefix_reuse_min_seq_len}", + file=sys.stderr, + ) + + while cache_head_idx < cache_size and prompt_head_idx < prompt_size: + match_length = self._find_matching_sequence_length( + prompt_tokens, self.tokens, prompt_head_idx, cache_head_idx + ) + + if match_length < non_prefix_reuse_min_seq_len: + # sequence too short - advance cache pointer to find next potential match + cache_head_idx += 1 + else: + if self.verbose: + print(f"Reusing {match_length} tokens from cache", file=sys.stderr) + + # found reusable sequence - shift cache content + for cache in self.cache: + cache.reuse_section(prompt_head_idx, cache_head_idx, match_length) + + # update the tokens to reflect the reused sequence + for i in range(match_length): + self.tokens[prompt_head_idx + i] = self.tokens[cache_head_idx + i] + + # advance pointers + cache_head_idx += match_length + prompt_head_idx += match_length + total_reused += match_length + + for cache in self.cache: + cache.do_reuse() + self.tokens = self.tokens[: common_prefix_len + total_reused] + + return total_reused def _get_unprocessed_tokens( self, prompt_tokens: mx.array, num_tokens_to_exclude: int @@ -102,12 +147,30 @@ def _get_unprocessed_tokens( return self.tokens # Find common KV between the last generation and the current prompt - common_prefix = self._find_common_prefix( - self.tokens, prompt_tokens, num_tokens_to_exclude + common_prefix = self._find_matching_sequence_length( + self.tokens, + prompt_tokens, ) + # do reuse but only if the cache has it + if hasattr(self.cache[0], "reuse_section"): + n_reused_tokens = self._truncate_cache( + prompt_tokens, + common_prefix, + ) + log_info( + prefix="CacheWrapper", + message=f"Reused {n_reused_tokens} tokens from the cache", + ) + common_prefix += n_reused_tokens + + # exclude some tokens from end, e.g. 
for kicking off generation + if common_prefix >= len(prompt_tokens) - num_tokens_to_exclude: + common_prefix = len(prompt_tokens) - num_tokens_to_exclude + # Trim the cache if the common prefix is shorter than the current cache - num_tokens_to_trim = self.cache[0].offset - common_prefix + # state[0] is an alias for keys that accounts for partially filled buffers + num_tokens_to_trim = self.cache[0].state[0].shape[2] - common_prefix if num_tokens_to_trim > 0: if not can_trim_prompt_cache(self.cache): log_warn( @@ -115,7 +178,9 @@ def _get_unprocessed_tokens( message=f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: " "Cache is not trimmable. Clearing the cache instead.", ) - self.cache = make_prompt_cache(self.model, self.max_kv_size) + self.cache = make_prompt_cache( + self.model, self.max_kv_size, keep=self.keep + ) self.tokens = prompt_tokens return self.tokens tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim) @@ -126,7 +191,9 @@ def _get_unprocessed_tokens( message=f"Tokens trimmed from cache ({tokens_trimmed}) is less than expected " " ({num_tokens_to_trim}). Clearing the cache.", ) - self.cache = make_prompt_cache(self.model, self.max_kv_size) + self.cache = make_prompt_cache( + self.model, self.max_kv_size, keep=self.keep + ) self.tokens = prompt_tokens return self.tokens log_info( @@ -221,9 +288,9 @@ def set_draft_model(self, draft_model: nn.Module): message="Clearing current prompt cache and adding draft model to the cache", ) self.tokens = None - self.cache: List[Any] = make_prompt_cache(self.model) + self.cache: List[Any] = make_prompt_cache(self.model, keep=self.keep) if draft_model is not None: - self.cache += make_prompt_cache(draft_model) + self.cache += make_prompt_cache(draft_model, keep=self.keep) self.draft_model = draft_model def unset_draft_model(self): @@ -239,6 +306,7 @@ def update_cache( prompt_progress_callback, *, num_tokens_to_exclude: int = 1, + keep: int = 4, ) -> mx.array: """ Set up the KV cache for the next generation. @@ -248,6 +316,7 @@ def update_cache( prompt_tokens (mx.array): The prompt tokens. prompt_progress_callback (Callable): A callback function to report prompt processing progress. num_tokens_to_exclude (int): The number of tokens that should not be added to the cache. + keep (int): The number of tokens to always keep in the prefix of the prompt cache. Returns: mx.array: The prompt tokens to be used for the next generation. @@ -257,6 +326,12 @@ def update_cache( def prompt_progress_callback(x): return None + # update keep tracking + self.keep = keep + for cache in self.cache: + if hasattr(cache, "set_keep"): + cache.set_keep(keep) + num_tokens_to_exclude = max(num_tokens_to_exclude, 1) prompt_tokens = self._get_unprocessed_tokens( prompt_tokens, num_tokens_to_exclude @@ -296,5 +371,13 @@ def prompt_progress_callback(x): def record_generated_token(self, token): """ Add the generated token to the token list, so that we can map the token to the KV cache. + + Also loop when the cache does so that we accurately track what's in cache. 
""" + # this behavior is common to rolling window (n_keep = 0) and truncate middle + # (n_keep > 0), and we should never get here with stop at max + if len(self.tokens) >= self.max_kv_size: + self.tokens = mx.concat( + [self.tokens[: self.keep], self.tokens[self.keep + 1 :]] + ) self.tokens = mx.concat([self.tokens, mx.array([token])]) diff --git a/mlx_engine/generate.py b/mlx_engine/generate.py index 6019cf1..54c581a 100644 --- a/mlx_engine/generate.py +++ b/mlx_engine/generate.py @@ -137,6 +137,7 @@ def create_generator( max_tokens: Optional[int] = 10000000, speculative_decoding_toggle: Optional[bool] = None, num_draft_tokens: Optional[int] = None, + keep: Optional[int] = 4, ) -> Iterator[GenerationResult]: """ Create a generator that streams text generation results from the model. @@ -171,6 +172,8 @@ def create_generator( if a draft model is loaded. If set to true, draft model must be loaded or else error. If set to false, speculative decoding is disabled even if a draft model is loaded. num_draft_tokens (Optional[int]): Number of tokens to draft when using speculative decoding + keep (Optional[int]): Number of tokens to always keep in the prefix of the prompt cache. + Defaults to 4, which is the minimum number of tokens needed for a valid prompt. Yields: GenerationResult: A named tuple containing: @@ -218,6 +221,7 @@ def create_generator( prompt_progress_callback, generate_args, speculative_decoding_toggle, + keep=keep, ) if draft_model is None: # input embeddings not yet supported for speculative decoding in mlx-lm diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index a389dc6..4f93437 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -141,6 +141,7 @@ def process_prompt( prompt_progress_callback, generate_args, speculative_decoding_toggle: Optional[bool] = None, + keep: int = 4, ) -> Tuple[mx.array, Optional[mx.array]]: ### TEXT-ONLY PROCESS_PROMPT ### is_text_only_processing = images_b64 is None or len(images_b64) == 0 @@ -160,6 +161,7 @@ def process_prompt( self.draft_model, speculative_decoding_toggle, prompt_progress_callback, + keep=keep, ), None ### WITH IMAGES PROMPT PROCESSING ###s if self.vision_add_on is None: diff --git a/mlx_engine/utils/prompt_processing.py b/mlx_engine/utils/prompt_processing.py index 78a687d..380cf24 100644 --- a/mlx_engine/utils/prompt_processing.py +++ b/mlx_engine/utils/prompt_processing.py @@ -13,6 +13,7 @@ def process_prompt_text_only( draft_model: Optional[nn.Module] = None, speculative_decoding_toggle: Optional[bool] = None, prompt_progress_callback: Optional[Callable[[float], None]] = None, + keep: int = 4, ): if cache_wrapper is None: raise ValueError("Cache wrapper is not initialized, cannot process prompt") @@ -38,6 +39,7 @@ def process_prompt_text_only( prompt_tokens = cache_wrapper.update_cache( prompt_tokens, prompt_progress_callback, + keep=keep, ) generate_args["prompt_cache"] = cache_wrapper.cache return prompt_tokens diff --git a/tests/test_cache_generic.py b/tests/test_cache_generic.py new file mode 100644 index 0000000..4a7b21f --- /dev/null +++ b/tests/test_cache_generic.py @@ -0,0 +1,36 @@ +import unittest +import mlx.core as mx +from copy import deepcopy +from mlx_engine.cache import ShiftingKVCache + + +class TestCache(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Set up test resources that will be shared across all test methods""" + cls.kv_head_dim = 4 + cls.bsz = 1 + cls.n_kv_heads = 1 + + @classmethod + def make_random_kv(cls, seqlen: 
int): + """Helper method to make a random key/value tensor of the right shape""" + return mx.random.normal( + (cls.bsz, cls.n_kv_heads, seqlen, cls.kv_head_dim), + scale=1.0, + dtype=mx.float32, + ) + + def assertArrEqual(self, a: mx.array, b: mx.array): + """Assert that two tensors are equal over the sequence length dimension""" + self.assertEqual(a.shape, b.shape) + self.assertTrue(mx.allclose(a, b), "Tensors are not equal") + + def add_random_to_cache(self, cache: ShiftingKVCache, seqlen: int) -> mx.array: + """Add random values to the cache and return them""" + base_kv = self.make_random_kv(seqlen) + # base_kv is *assigned* to cache.keys/cache.values so returning base_kv + # would return a reference to cache.keys, which is pointless. so copy it + reference = deepcopy(base_kv) + cache.update_and_fetch(base_kv, base_kv) + return reference diff --git a/tests/test_cache_shift.py b/tests/test_cache_shift.py new file mode 100644 index 0000000..0ae0100 --- /dev/null +++ b/tests/test_cache_shift.py @@ -0,0 +1,229 @@ +import unittest +import mlx.core as mx +from mlx_engine.cache import ShiftingKVCache +from tests.test_cache_generic import TestCache + + +def idx(v: mx.array, i: int): + """Helper function to index into a 4D tensor at the sequence length dimension""" + return v[:, :, i : i + 1, :] + + +class TestShiftingKVCache(TestCache): + def test_overwriting(self): + """Test overwriting when the cache reaches max_size""" + cache = ShiftingKVCache(max_size=3, keep=1) + + # fill cache -> 123 + reference = self.add_random_to_cache(cache, 3) + self.assertEqual(cache.offset, 3) + + # attempt to write another element 4 -> 143 + overwrite = self.add_random_to_cache(cache, 1) + # access k/v as cache.state[0]/[1] due to possibly empty buffer + keys = cache.state[0] + + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), overwrite) + self.assertArrEqual(idx(keys, 2), idx(reference, 2)) + self.assertEqual(cache.offset, 4) + + def test_ensure_update_increases_offset_indefinitely(self): + """Test single-token updates that should increase offset""" + cache = ShiftingKVCache(max_size=3, keep=1) + + for i in range(10): + self.add_random_to_cache(cache, 1) + self.assertEqual(cache.offset - 1, i) + + def test_ensure_reasonable_size_and_shift(self): + """Test behavior when the cache gets a KV batch-written that is much larger + than max_size. The default behavior of the cache is to write the entire thing, + then trim it back down when the next KV is written. 
+ """ + cache = ShiftingKVCache(max_size=3, keep=1) + + # fill cache -> 0123456789 + reference = self.add_random_to_cache(cache, 10) + keys = cache.state[0] + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 10, self.kv_head_dim)) + self.assertEqual(cache.offset, 10) + + # trigger trim -> 0X9 -> (rope) 021 + overwrite = self.add_random_to_cache(cache, 1) + keys = cache.state[0] + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 3, self.kv_head_dim)) + self.assertEqual(cache.offset, 11) + + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), overwrite) + self.assertArrEqual(idx(keys, 2), idx(reference, 9)) + + # make sure pos embs are right + cache.keys = cache._temporal_order(cache.keys) + cache.values = cache._temporal_order(cache.values) + keys = cache.state[0] + + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), idx(reference, 9)) + self.assertArrEqual(idx(keys, 2), overwrite) + self.assertEqual(cache.offset, 11) + + # ensure offset keeps increasing + self.add_random_to_cache(cache, 1) + self.assertEqual(cache.offset, 12) + + self.add_random_to_cache(cache, 1) + self.assertEqual(cache.offset, 13) + + def test_update_keep_on_the_fly(self): + """Test changing the keep value on the fly""" + cache = ShiftingKVCache(max_size=4, keep=1) + + # fill cache -> 1234 + reference = self.add_random_to_cache(cache, 4) + + # attempt to write another element 5 -> 1534 + overwrite = self.add_random_to_cache(cache, 1) + self.assertEqual(cache.offset, 5) + + # update keep -> 1345 -> 1234 implicitly + # and attempt to write another element 5 -> 1254 + # offset updates after set_keep (anytime we reorder/rope shift) + cache.set_keep(2) + self.assertEqual(cache.offset, 5) + overwrite2 = self.add_random_to_cache(cache, 1) + self.assertEqual(cache.offset, 6) + keys = cache.state[0] + + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), idx(reference, 2)) + self.assertArrEqual(idx(keys, 2), overwrite2) + self.assertArrEqual(idx(keys, 3), overwrite) + + def test_trim_before_full(self): + """Test trimming from the end before the cache is full""" + cache = ShiftingKVCache(max_size=3, keep=1) + + # fill cache -> 12 + reference = self.add_random_to_cache(cache, 2) + + # trim 1 from end -> 1 + cache.trim(1) + keys = cache.state[0] + + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 1, self.kv_head_dim)) + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertEqual(cache.offset, 1) + + # ensure adding another value works fine + new_kv = self.add_random_to_cache(cache, 1) + keys = cache.state[0] + self.assertEqual(cache.offset, 2) + + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) + self.assertArrEqual(idx(keys, 0), idx(reference, 0)) + self.assertArrEqual(idx(keys, 1), new_kv) + self.assertEqual(cache.offset, 2) + + def test_trim_after_overwrite(self): + """Test trimming from the end when we've written past the cache max""" + cache = ShiftingKVCache(max_size=3, keep=1) + + # fill cache -> 123 + reference = self.add_random_to_cache(cache, 3) + self.assertEqual(cache.offset, 3) + + # overwrite so offset goes over max_size -> 143 + base_kv = self.make_random_kv(1) + cache.update_and_fetch(base_kv, base_kv) + self.assertEqual(cache.offset, 4) + + # trim 1 from end -> 13 -> 12 (rope), ideally + cache.trim(1) + keys = cache.state[0] + + should_be_kv = mx.concatenate( + [reference[:, :, :1, :], reference[:, :, 2:3, :]], axis=2 + ) + 
self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) + self.assertArrEqual(keys, should_be_kv) + self.assertEqual(cache.offset, 2) + + def test_trim_after_full(self): + """Test trimming from the end when the cache is oversize""" + cache = ShiftingKVCache(max_size=3, keep=1) + + # fill cache oversize already -> 1234 + reference = self.add_random_to_cache(cache, 4) + self.assertEqual(cache.offset, 4) + + # trim 2 from end -> 12 + cache.trim(2) + keys = cache.state[0] + + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 2, self.kv_head_dim)) + self.assertArrEqual(keys, reference[:, :, :2, :]) + self.assertEqual(cache.offset, 2) + + # ensure adding more values works fine + new_kv = self.add_random_to_cache(cache, 2) + keys = cache.state[0] + self.assertEqual(cache.offset, 4) + + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 4, self.kv_head_dim)) + self.assertArrEqual(keys[:, :, :2, :], reference[:, :, :2, :]) + self.assertArrEqual(keys[:, :, 2:, :], new_kv) + + def test_reuse(self): + """Test basic reuse APIs""" + cache = ShiftingKVCache(max_size=8, keep=1) + + # fill cache -> 12345678 + reference = self.add_random_to_cache(cache, 8) + + # reuse a specific section (hardcoded), dynamic reuse is in test_cache_wrapper + cache.reuse_section(3, 4, 2) + cache.do_reuse() + keys = cache.state[0] + + # this is what the remaining cache should look like + should_be_keys = mx.concatenate( + [reference[:, :, :3, :], reference[:, :, 4:6, :]], axis=2 + ) + + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 5, self.kv_head_dim)) + self.assertArrEqual(keys, should_be_keys) + self.assertEqual(cache.offset, 5) + + def test_reuse_after_overwrite(self): + """Test basic reuse APIs after an overwrite""" + cache = ShiftingKVCache(max_size=8, keep=1) + + # fill cache -> 12345678 + reference = self.add_random_to_cache(cache, 8) + news = self.add_random_to_cache(cache, 1) # overwrite to 13456789 after TO + self.assertArrEqual( + cache.state[0], mx.concatenate( + [reference[:, :, :1, :], news, reference[:, :, 2:8, :]], axis=2 + ) + ) + + # suppose the prompt coming in is now 13678 + # reuse from 2 to 4 length 3 + cache.reuse_section(2, 4, 3) + cache.do_reuse() + keys = cache.state[0] + + # the remaining cache should be 13678 + should_be_keys = mx.concatenate( + [reference[:, :, :1, :], reference[:,:,2:3,:], reference[:, :, 5:8, :]], axis=2 + ) + + self.assertEqual(keys.shape, (self.bsz, self.n_kv_heads, 5, self.kv_head_dim)) + self.assertArrEqual(keys, should_be_keys) + self.assertEqual(cache.offset, 5) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index d1d6456..61149a6 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -1,42 +1,308 @@ import unittest import mlx.core as mx from mlx_engine.cache_wrapper import CacheWrapper +from mlx_engine.cache import ShiftingKVCache +from tests.test_cache_generic import TestCache +from tests.utils import DummyModel, model_getter +from mlx_engine.generate import load_model, create_generator -class TestCacheWrapper(unittest.TestCase): - def test_find_common_prefix_with_mismatch(self): +class TestCacheWrapper(TestCache): + def test_find_matching_sequence_length_with_mismatch(self): """Test when there's a mismatch in the tokens""" # Create two arrays with a known common prefix [1, 2, 3] current_tokens = mx.array([1, 2, 3, 4, 5]) prompt_tokens = mx.array([1, 2, 3, 6, 7]) # Mismatch at index 3 - num_tokens_to_exclude = 1 
print("\nTest with mismatch:") print(f"current_tokens: {current_tokens}") print(f"prompt_tokens: {prompt_tokens}") - result = CacheWrapper._find_common_prefix( - current_tokens, prompt_tokens, num_tokens_to_exclude + result = CacheWrapper._find_matching_sequence_length( + current_tokens, prompt_tokens ) self.assertEqual(result, 3) # Should find 3 matching tokens - def test_find_common_prefix_all_match(self): + def test_find_matching_sequence_length_all_match(self): """Test when all tokens match""" # Create two identical arrays current_tokens = mx.array([1, 2, 3, 4, 5]) prompt_tokens = mx.array([1, 2, 3, 4, 5]) # All tokens match - num_tokens_to_exclude = 1 print("\nTest with all matching:") print(f"current_tokens: {current_tokens}") print(f"prompt_tokens: {prompt_tokens}") - result = CacheWrapper._find_common_prefix( - current_tokens, prompt_tokens, num_tokens_to_exclude + result = CacheWrapper._find_matching_sequence_length( + current_tokens, prompt_tokens ) - self.assertEqual( - result, 4 - ) # Should find 4 matching tokens (5-1 due to num_tokens_to_exclude) + self.assertEqual(result, 5) # Should find 5 matching tokens + + def test_find_matching_sequence_length_no_match(self): + """Test when no tokens match""" + # Create two arrays with no common prefix + current_tokens = mx.array([1, 2, 3, 4, 5]) + prompt_tokens = mx.array([6, 7, 8, 9, 10]) + + print("\nTest with no matching tokens:") + print(f"current_tokens: {current_tokens}") + print(f"prompt_tokens: {prompt_tokens}") + + result = CacheWrapper._find_matching_sequence_length( + current_tokens, prompt_tokens + ) + self.assertEqual(result, 0) # No matching tokens should return 0 + + def test_find_matching_sequence_length_offset_starts(self): + """Test when the current tokens start with a different offset""" + # Create two arrays where the current tokens start with a different offset + current_tokens = mx.array([2, 3, 4, 5]) + prompt_tokens = mx.array([1, 2, 3, 4, 5]) + + print("\nTest with offset starts:") + print(f"current_tokens: {current_tokens}") + print(f"prompt_tokens: {prompt_tokens}") + + result = CacheWrapper._find_matching_sequence_length( + current_tokens, + prompt_tokens, + start2=1, + ) + self.assertEqual(result, 4) + + def test_find_matching_sequence_length_more_offsets(self): + """Test when the current tokens have more offsets""" + # Create two arrays where the current tokens have more offsets + current_tokens = mx.array([1, 2, 3, 4, 5, 6]) + prompt_tokens = mx.array([0, 9, 10, 3, 4, 7, 8]) + + print("\nTest with more offsets:") + print(f"current_tokens: {current_tokens}") + print(f"prompt_tokens: {prompt_tokens}") + + result = CacheWrapper._find_matching_sequence_length( + current_tokens, prompt_tokens + ) + self.assertEqual(result, 0) + + result = CacheWrapper._find_matching_sequence_length( + current_tokens, + prompt_tokens, + start1=2, + start2=3, + ) + self.assertEqual(result, 2) + + def test_record_generated_token_loops(self): + cache = CacheWrapper( + model=DummyModel(), + max_kv_size=5, + keep=2, + ) + cache.tokens = mx.array([]) + cache.record_generated_token(1) + cache.record_generated_token(2) + cache.record_generated_token(3) + cache.record_generated_token(4) + cache.record_generated_token(5) + self.assertListEqual( + cache.tokens.tolist(), + [1, 2, 3, 4, 5], + ) + cache.record_generated_token(6) + self.assertListEqual( + cache.tokens.tolist(), + [1, 2, 4, 5, 6], + ) + + def test_cache_reuse_heavy(self): + cache = CacheWrapper(DummyModel(), 10, keep=2) + cache.cache[0] = ShiftingKVCache(max_size=10, keep=2) + 
+ # set up pretend cache + cached_tokens = mx.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + cache_kv = self.make_random_kv(10) + cache.tokens = cached_tokens + cache.cache[0].update_and_fetch(cache_kv, cache_kv) + + # set up pretend prompt + prompt_tokens = mx.array([1, 2, 4, 7, 8, 9, 11]) + + prefix_len = cache._find_matching_sequence_length( + cached_tokens, prompt_tokens, 0 + ) + self.assertEqual(prefix_len, 2) + + total_reused = cache._truncate_cache( + prompt_tokens=prompt_tokens, + common_prefix_len=prefix_len, + non_prefix_reuse_min_seq_len=1, + ) + + # prepare references + def idx(v, a, b): + return v[:, :, a:b, :] + + should_be_tokens = mx.array([1, 2, 4, 7, 8, 9]) + should_be_kv = mx.concatenate( + [ + idx(cache_kv, 0, 2), + idx(cache_kv, 3, 4), + idx(cache_kv, 6, 9), + ], + axis=2, + ) + + self.assertEqual(total_reused, 4) + self.assertArrEqual(cache.tokens, should_be_tokens) + self.assertArrEqual(cache.cache[0].keys, should_be_kv) + + # ensure updating works as intended + new_kv = self.make_random_kv(1) + keys, _ = cache.cache[0].update_and_fetch(new_kv, new_kv) + should_be_kv = mx.concatenate([should_be_kv, new_kv], axis=2) + self.assertArrEqual(keys, should_be_kv) + + # ensure batch concat works as intended + new_kv = self.make_random_kv(2) + keys, _ = cache.cache[0].update_and_fetch(new_kv, new_kv) + should_be_kv = mx.concatenate([should_be_kv, new_kv], axis=2) + self.assertArrEqual(keys, should_be_kv) + + def test_cache_reuse_overwrite_heavy(self): + cache = CacheWrapper(DummyModel(), 10, keep=2) + cache.cache[0] = ShiftingKVCache(max_size=10, keep=2) + + # set up pretend cache + cached_tokens = mx.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + cache_kv = self.make_random_kv(10) + for i in range(10): + cache.record_generated_token(cached_tokens[i]) + cache.cache[0].update_and_fetch(cache_kv, cache_kv) + + # append another one to overwrite + cache.record_generated_token(11) + cache_new_kv = self.make_random_kv(1) + cache.cache[0].update_and_fetch(cache_new_kv, cache_new_kv) + + print(cache.tokens) + self.assertArrEqual(cache.tokens, mx.array([1, 2, 4, 5, 6, 7, 8, 9, 10, 11])) + self.assertEqual(cache.cache[0].keys.shape[2], 10) + + # set up pretend prompt + prompt_tokens = mx.array([1, 2, 4, 7, 8, 9, 12]) + + prefix_len = cache._find_matching_sequence_length( + cached_tokens, prompt_tokens, 0 + ) + self.assertEqual(prefix_len, 2) + + # prepare references + def idx(v, a, b): + return v[:, :, a:b, :] + + should_be_tokens = mx.array([1, 2, 4, 7, 8, 9]) + should_be_kv = mx.concatenate( + [ + idx(cache_kv, 0, 2), + idx(cache_kv, 3, 4), + idx(cache_kv, 6, 9), + ], + axis=2, + ) + + total_reused = cache._truncate_cache( + prompt_tokens=prompt_tokens, + common_prefix_len=prefix_len, + non_prefix_reuse_min_seq_len=1, + ) + + self.assertEqual(total_reused, 4) + self.assertArrEqual(cache.tokens, should_be_tokens) + self.assertArrEqual(cache.cache[0].keys, should_be_kv) + + # ensure updating works as intended + new_kv = self.make_random_kv(1) + keys, _ = cache.cache[0].update_and_fetch(new_kv, new_kv) + should_be_kv = mx.concatenate([should_be_kv, new_kv], axis=2) + self.assertArrEqual(keys, should_be_kv) + + # ensure batch concat works as intended + new_kv = self.make_random_kv(2) + keys, _ = cache.cache[0].update_and_fetch(new_kv, new_kv) + should_be_kv = mx.concatenate([should_be_kv, new_kv], axis=2) + self.assertArrEqual(keys, should_be_kv) + + def test_update_cache_heavy(self): + """Test that the cache updates correctly during generation""" + # TODO(christian-lms): you need to pipe in 
nonprefix reuse min seq len
+        model_path = model_getter("lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit")
+        model_kit = load_model(model_path=model_path, max_kv_size=10)
+
+        # set up pretend cache
+        prompt_tokens = mx.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
+        non_prefill_tokens = model_kit.cache_wrapper.update_cache(prompt_tokens, prompt_progress_callback=None, keep=2)
+        layer_0_cache = model_kit.cache_wrapper.cache[0]
+        from copy import deepcopy
+        original_keys = deepcopy(layer_0_cache.state[0])
+
+        # generate a few tokens
+        for result in create_generator(
+            model_kit=model_kit,
+            prompt_tokens=prompt_tokens,
+            seed=0,
+            max_tokens=2,
+            temp=0.0,
+            prompt_progress_callback=None,
+            keep=2,
+        ):
+            print(model_kit.cache_wrapper.tokens.tolist())
+            print(result.tokens)
+
+        result_tokens = mx.array([1, 2, 6, 7, 8, 9, 10, 4999, 1725, 1725])
+        self.assertArrEqual(model_kit.cache_wrapper.tokens, result_tokens)
+
+        _compA = model_kit.cache_wrapper.cache[0]._temporal_order(model_kit.cache_wrapper.cache[0].state[0])
+        compA = _compA[..., :7, :]
+        print(_compA[0,0,:,:1].tolist())
+        compB = mx.concat(
+            [original_keys[..., :2, :], original_keys[..., 4:, :]], axis=2)
+        self.assertArrEqual(compA, compB)
+        print("--- ---")
+
+        new_prompt_tokens = mx.array([1, 2, 8, 9, 10, 4999, 1725, 1725])
+        for result in create_generator(
+            model_kit=model_kit,
+            prompt_tokens=new_prompt_tokens,
+            seed=0,
+            max_tokens=2,
+            temp=0.0,
+            prompt_progress_callback=None,
+            keep=2,
+        ):
+            self.assertEqual(len(model_kit.cache_wrapper.tokens), model_kit.cache_wrapper.cache[0].state[0].shape[2])
+            print(model_kit.cache_wrapper.tokens.tolist())
+            print(result.tokens)
+
+        print(model_kit.cache_wrapper.tokens.tolist())
+        new_result_tokens = mx.array([1, 2, 9, 10, 4999, 1725, 1725, 21002, 1177, 1177])
+        self.assertArrEqual(model_kit.cache_wrapper.tokens, new_result_tokens)
+
+        _compC = model_kit.cache_wrapper.cache[0]._temporal_order(model_kit.cache_wrapper.cache[0].state[0])
+        compC = _compC[..., :3, :]
+        print(_compC[0,0,:,:1].tolist())
+        print(original_keys[0,0,:,:1].tolist())
+        compD = mx.concat(
+            [original_keys[..., :2, :], original_keys[..., 8:, :]], axis=2)
+        self.assertArrEqual(compC, compD)
+        compE = _compC[..., 3:6, :]
+        compF = _compA[..., 7:, :]
+        print("--- ---")
+        print(_compC[0,0,2:5,:1].tolist())
+        print(_compA[0,0,7:,:1].tolist())
+        self.assertArrEqual(compE, compF)


 if __name__ == "__main__":
diff --git a/tests/utils.py b/tests/utils.py
index 6d1c95d..ad3fd8e 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -5,6 +5,12 @@
 from mlx_engine.generate import load_model, load_draft_model, tokenize


+class DummyModel:
+    """Dummy model class for testing"""
+
+    layers = [0]
+
+
 def model_getter(model_name: str):
     """Helper method to get a model, prompt user to download if not found"""
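
Note on the reuse flow: the splice that test_reuse exercises above reduces to the short sketch below. It is illustrative only; it assumes mlx and mlx-lm are installed and that mlx_engine.cache from this diff is importable, and it uses the (batch, n_kv_heads, seq_len, head_dim) layout from the tests.

import mlx.core as mx
from mlx_engine.cache import ShiftingKVCache

# A cache that holds at most 8 positions and always protects the first one.
cache = ShiftingKVCache(max_size=8, keep=1)

# Prefill with 8 positions of keys/values; offset is now 8.
kv = mx.random.normal((1, 1, 8, 4))
cache.update_and_fetch(kv, kv)

# Queue a splice: keep cached positions 0-2 in place, then reuse cached
# positions 4-5 starting at write position 3. do_reuse() rebuilds keys and
# values in one pass and drops everything after the reused segment.
cache.reuse_section(write_start_idx=3, reuse_start_idx=4, reuse_length=2)
cache.do_reuse()

print(cache.state[0].shape)  # (1, 1, 5, 4), as asserted in test_reuse
print(cache.offset)          # 5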
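
A second sketch shows how the new make_prompt_cache chooses a cache type. The DummyModel stand-in here is hypothetical (modeled on the one added to tests/utils.py); real callers pass an nn.Module, and only the layers attribute and the absence of a make_cache method matter for this path.

from mlx_engine.cache import KVCache, ShiftingKVCache, make_prompt_cache


class DummyModel:
    """Hypothetical stand-in: two layers, no make_cache method."""

    layers = [0, 1]


# With max_kv_size set, every layer gets a bounded ShiftingKVCache that can
# rotate, trim, and splice-reuse sections while protecting the `keep` prefix.
caches = make_prompt_cache(DummyModel(), max_kv_size=512, keep=4)
assert len(caches) == 2
assert all(isinstance(c, ShiftingKVCache) for c in caches)

# Without max_kv_size, the unbounded KVCache from mlx_lm is used instead.
caches = make_prompt_cache(DummyModel())
assert all(isinstance(c, KVCache) for c in caches)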