Draft
Changes from all commits
Commits (40)
f11604b
loop record generated token
christian-lms Jul 7, 2025
64113b3
shifting kv cache
christian-lms Jul 7, 2025
19cde88
override another method to use rope shift
christian-lms Jul 7, 2025
f4822f0
add testing asserts
christian-lms Jul 7, 2025
e297092
warn
christian-lms Jul 7, 2025
ee316db
move cache into a separate file
christian-lms Jul 7, 2025
3df5d31
begin raw unit tests
christian-lms Jul 7, 2025
514b6c5
initial uncommented tests
christian-lms Jul 7, 2025
3f64566
manual rope stuff
christian-lms Jul 7, 2025
22d8a83
pipe in n_keep
christian-lms Jul 8, 2025
d165867
a few more cache wrapper tests
christian-lms Jul 8, 2025
199d231
ruff formatting lmao
christian-lms Jul 8, 2025
8dcb82d
record_generated_token test and fix
christian-lms Jul 8, 2025
56d34e0
prelim reuse code
christian-lms Jul 8, 2025
3741509
maybe reuse unit test
christian-lms Jul 8, 2025
a677f86
code reuse!
christian-lms Jul 8, 2025
77e523c
cache shift test comments
christian-lms Jul 8, 2025
2a8855a
stop rope shifting values and set keep
christian-lms Jul 8, 2025
447a134
cache is a list, and exclude tokens in the right place
christian-lms Jul 8, 2025
16fc7a1
same for tests
christian-lms Jul 8, 2025
85f2241
apply that to tests too oops
christian-lms Jul 8, 2025
6ac8d2f
decouple from rotatingkvcache since so much of it was rewritten anywa…
christian-lms Jul 8, 2025
d124c0e
working reuse test
christian-lms Jul 8, 2025
8ac2bae
cache offsets ooooooooooooooooooooops
christian-lms Jul 8, 2025
e373181
refactor trim/temporal order internal interfaces to operate on both k…
christian-lms Jul 8, 2025
4b938be
more test fixes
christian-lms Jul 8, 2025
d7c4ce7
refactor tests
christian-lms Jul 8, 2025
5d60f13
technically if you ran this it would work
christian-lms Jul 8, 2025
642a8c3
Merge branch 'lmstudio-ai:main' into christian/cache_reuse_again
christian-lms Jul 10, 2025
9c378e6
properly works now (i think)
christian-lms Jul 10, 2025
aed67a9
try to remove rope
christian-lms Jul 10, 2025
1b9af40
simplify cache again
christian-lms Jul 10, 2025
a1521d8
more reductionism
christian-lms Jul 10, 2025
dd205ba
remove prints
christian-lms Jul 10, 2025
a740335
??? oops
christian-lms Jul 10, 2025
349b5a8
final fixes for now
christian-lms Jul 10, 2025
1c4cf24
more fixes
christian-lms Jul 10, 2025
bf66e2b
make linter happy
christian-lms Jul 10, 2025
3d51d58
fix trim
christian-lms Jul 11, 2025
18b6dc3
extra tests (in progress)
christian-lms Jul 11, 2025
107 changes: 107 additions & 0 deletions mlx_engine/cache.py
@@ -0,0 +1,107 @@
from typing import List, Optional, Any

from mlx_lm.models.cache import RotatingKVCache, KVCache
import mlx.core as mx
import mlx.nn as nn


class ShiftingKVCache(RotatingKVCache):
def __init__(self, max_size=256, keep=0, step=256):
self.reuse_queue = []
super().__init__(max_size=max_size, keep=keep, step=step)

def reuse_section(
self, write_start_idx: int, reuse_start_idx: int, reuse_length: int
) -> None:
# queue for reuse: everything is done in one pass at the end in do_reuse
self.reuse_queue.append((write_start_idx, reuse_start_idx, reuse_length))

def do_reuse(self) -> None:
if not self.reuse_queue:
return

# ensure keys/values are in temporal order before slicing reuse segments
self.keys = self._temporal_order(self.keys)
self.values = self._temporal_order(self.values)

# process queued sections in write order so segments are assembled left to right
self.reuse_queue.sort(key=lambda x: x[0])

key_segments = []
value_segments = []
current_pos = 0

for write_start_idx, reuse_start_idx, reuse_length in self.reuse_queue:
# add any gap before this write position
if current_pos < write_start_idx:
key_segments.append(self.keys[..., current_pos:write_start_idx, :])
value_segments.append(self.values[..., current_pos:write_start_idx, :])

reuse_end_idx = reuse_start_idx + reuse_length
current_pos = write_start_idx + reuse_length

key_segments.append(self.keys[..., reuse_start_idx:reuse_end_idx, :])
value_segments.append(self.values[..., reuse_start_idx:reuse_end_idx, :])

self.keys = mx.concatenate(key_segments, axis=2)
self.values = mx.concatenate(value_segments, axis=2)

# clean up
self.reuse_queue = []
self._idx = self.keys.shape[2]
self.offset = self.keys.shape[2]

def trim(self, n) -> int:
# trim must not respect keep
n = min(self.offset, n)
if n <= 0:
return 0

# put us back into the state before the circular buffer is full
self.keys = self._temporal_order(self.keys)
self.values = self._temporal_order(self.values)

new_length = max(self.keys.shape[2] - n, 0)
self.keys = self.keys[..., :new_length, :]
self.values = self.values[..., :new_length, :]

self.offset = new_length
self._idx = new_length
return n

def set_keep(self, keep):
# kv must be in temporal order, else we will keep the wrong thing
if self.keys is not None:
self.keys = self._temporal_order(self.keys)
if self.values is not None:
self.values = self._temporal_order(self.values)
self.keep = keep

def is_trimmable(self) -> bool:
return True

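To make the reuse API above concrete, here is a minimal usage sketch (an illustration only, not part of cache.py); it assumes mlx and mlx_lm are installed, and the (batch, heads, sequence, head_dim) shapes are arbitrary:

import mlx.core as mx
from mlx_engine.cache import ShiftingKVCache

cache = ShiftingKVCache(max_size=64, keep=4)

# Write a 16-token prompt into the cache.
keys = mx.random.normal((1, 8, 16, 32))
values = mx.random.normal((1, 8, 16, 32))
cache.update_and_fetch(keys, values)

# Suppose cached positions 8..15 match positions 4..11 of a new prompt:
# queue the match, then compact the cache in a single pass.
cache.reuse_section(write_start_idx=4, reuse_start_idx=8, reuse_length=8)
cache.do_reuse()

# The cache now holds positions 0..3 followed by the reused positions 8..15.
assert cache.offset == 12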

def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
keep: int = 4,
) -> List[Any]:
"""
Construct the model's cache for use in generation.
This function will defer the cache construction to the model if it has a
``make_cache`` method, otherwise it will make a default KV cache.
Args:
model (nn.Module): The language model.
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``ShiftingKVCache`` is used with a maximum
size of ``max_kv_size``.
keep (int): The number of tokens to always keep at the start of the
cache when it rotates or sections are reused. Defaults to 4.
"""
if hasattr(model, "make_cache"):
return model.make_cache()
num_layers = len(model.layers)
if max_kv_size is not None:
return [
ShiftingKVCache(max_size=max_kv_size, keep=keep) for _ in range(num_layers)
]
else:
return [KVCache() for _ in range(num_layers)]
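A quick sketch of the fallback path in make_prompt_cache, using a stand-in module rather than a real mlx_lm model (the class name and layer count are made up for illustration):

import mlx.nn as nn
from mlx_engine.cache import make_prompt_cache, ShiftingKVCache

class DummyModel(nn.Module):
    # Stand-in with a `layers` attribute and no `make_cache` method.
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 4) for _ in range(2)]

caches = make_prompt_cache(DummyModel(), max_kv_size=512, keep=4)
assert len(caches) == 2 and all(isinstance(c, ShiftingKVCache) for c in caches)

With max_kv_size=None the same call would return plain KVCache instances instead.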
165 changes: 124 additions & 41 deletions mlx_engine/cache_wrapper.py
@@ -1,11 +1,8 @@
from typing import List, Optional, Any

from mlx_engine.logging import log_info, log_warn, log_error
from mlx_lm.models.cache import (
make_prompt_cache,
trim_prompt_cache,
can_trim_prompt_cache,
)
from mlx_engine.cache import make_prompt_cache
from mlx_lm.models.cache import trim_prompt_cache, can_trim_prompt_cache
from mlx_lm.generate import generation_stream, maybe_quantize_kv_cache
import mlx.core as mx
import mlx.nn as nn
@@ -26,6 +23,7 @@ def __init__(
kv_bits: Optional[int] = None,
kv_group_size: Optional[int] = None,
quantized_kv_start: Optional[int] = None,
keep: int = 4,
):
"""
Initialize the CacheWrapper.
@@ -36,7 +34,8 @@ def __init__(
"""
# utilize a simple ordered list of tokens processed so far for cache invalidation checking
self.tokens: Optional[mx.array] = None
self.cache: List[Any] = make_prompt_cache(model, max_kv_size)
self.keep = keep
self.cache: List[Any] = make_prompt_cache(model, max_kv_size, keep)
self.model = model
self.draft_model: Optional[nn.Module] = None
self.max_kv_size = max_kv_size
@@ -48,41 +47,87 @@ def __init__(
)

@staticmethod
def _find_common_prefix(
current_tokens: mx.array, prompt_tokens: mx.array, num_tokens_to_exclude: int
def _find_matching_sequence_length(
tokens1: mx.array,
tokens2: mx.array,
start1: int = 0,
start2: int = 0,
) -> int:
"""
Determine the common prefix length between the current tokens and the prompt tokens.
Find the length of the matching token sequence between two token arrays.

Args:
current_tokens (mx.array): The cached tokens (self.tokens).
prompt_tokens (mx.array): The prompt tokens.
num_tokens_to_exclude (int): The minimum length of the remaining prompt tokens array.
tokens1: First token array
tokens2: Second token array
start1: Starting position in first array
start2: Starting position in second array

Returns:
int: The length of the common prefix.
int: Length of matching sequence
"""
prompt_tokens = prompt_tokens
current_tokens = current_tokens
# Find the minimum length between the two arrays
min_length = min(len(current_tokens), len(prompt_tokens))

# Compare elements up to the minimum length
mask = prompt_tokens[:min_length] == current_tokens[:min_length]

# Find the index where the first mismatch occurs
if mx.any(mask == False): # noqa E712
common_length = int(mx.argmax(mask == False)) # noqa E712
else:
common_length = int(min_length)

# Ensure that the prompt is at least num_tokens_to_exclude long
uncached_prompt_tokens_length = len(prompt_tokens[common_length:])
length_adjustment = max(
0, num_tokens_to_exclude - uncached_prompt_tokens_length
)
common_length = max(common_length - length_adjustment, 0)
return common_length
# Calculate actual bounds
max_len1 = len(tokens1) - start1
max_len2 = len(tokens2) - start2
min_length = int(min(max_len1, max_len2))

# Extract subsequences to compare
seq1 = tokens1[start1 : start1 + min_length]
seq2 = tokens2[start2 : start2 + min_length]

# Find first mismatch
mask = seq1 == seq2
return int(mx.argmax(mask == False)) if mx.any(mask == False) else min_length # noqa E712

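As an aside on the vectorized comparison above, the first mismatch index is recovered with argmax over the boolean mismatch mask; a tiny self-contained illustration (values are arbitrary):

import mlx.core as mx

a = mx.array([5, 6, 7, 8])
b = mx.array([5, 6, 9, 8])
mask = a == b  # [True, True, False, True]
mismatch = mask == False  # noqa: E712 -- elementwise, as in the code above
first_mismatch = int(mx.argmax(mismatch)) if mx.any(mismatch) else len(a)
print(first_mismatch)  # 2, i.e. the first two tokens match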
def _truncate_cache(
self,
prompt_tokens: mx.array,
common_prefix_len: int,
non_prefix_reuse_min_seq_len: int = 256,
) -> int:
cache_size = len(self.tokens)
prompt_size = len(prompt_tokens)

# start scanning from after the common prefix
cache_head_idx = common_prefix_len
prompt_head_idx = common_prefix_len
total_reused = 0

if self.verbose:
print(
f"Looking for non-prefix sequences of length >= {non_prefix_reuse_min_seq_len}",
file=sys.stderr,
)

while cache_head_idx < cache_size and prompt_head_idx < prompt_size:
match_length = self._find_matching_sequence_length(
prompt_tokens, self.tokens, prompt_head_idx, cache_head_idx
)

if match_length < non_prefix_reuse_min_seq_len:
# sequence too short - advance cache pointer to find next potential match
cache_head_idx += 1
else:
if self.verbose:
print(f"Reusing {match_length} tokens from cache", file=sys.stderr)

# found reusable sequence - shift cache content
for cache in self.cache:
cache.reuse_section(prompt_head_idx, cache_head_idx, match_length)

# update the tokens to reflect the reused sequence
for i in range(match_length):
self.tokens[prompt_head_idx + i] = self.tokens[cache_head_idx + i]

# advance pointers
cache_head_idx += match_length
prompt_head_idx += match_length
total_reused += match_length

for cache in self.cache:
cache.do_reuse()
self.tokens = self.tokens[: common_prefix_len + total_reused]

return total_reused

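The scanning loop above can be hard to picture from the pointer bookkeeping alone; the following plain-Python illustration (hypothetical helper names, and a toy min_run in place of the 256-token default) shows the same idea of finding reusable non-prefix runs:

def matching_run(a, b, i, j):
    # Length of the run where a[i:] and b[j:] agree element by element.
    n = 0
    while i + n < len(a) and j + n < len(b) and a[i + n] == b[j + n]:
        n += 1
    return n

def plan_reuse(cached, prompt, prefix_len, min_run=3):
    # Returns (write_start, reuse_start, length) tuples, mirroring reuse_section.
    plans, cache_idx, prompt_idx = [], prefix_len, prefix_len
    while cache_idx < len(cached) and prompt_idx < len(prompt):
        run = matching_run(prompt, cached, prompt_idx, cache_idx)
        if run < min_run:
            cache_idx += 1  # skip one stale cached token and try again
        else:
            plans.append((prompt_idx, cache_idx, run))
            cache_idx += run
            prompt_idx += run
    return plans

cached = [1, 2, 3, 9, 9, 4, 5, 6, 7]
prompt = [1, 2, 3, 4, 5, 6, 7, 8]
print(plan_reuse(cached, prompt, prefix_len=3))  # [(3, 5, 4)]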
def _get_unprocessed_tokens(
self, prompt_tokens: mx.array, num_tokens_to_exclude: int
@@ -102,20 +147,40 @@ def _get_unprocessed_tokens(
return self.tokens

# Find common KV between the last generation and the current prompt
common_prefix = self._find_common_prefix(
self.tokens, prompt_tokens, num_tokens_to_exclude
common_prefix = self._find_matching_sequence_length(
self.tokens,
prompt_tokens,
)

# attempt non-prefix reuse, but only if the cache type supports it
if hasattr(self.cache[0], "reuse_section"):
n_reused_tokens = self._truncate_cache(
prompt_tokens,
common_prefix,
)
log_info(
prefix="CacheWrapper",
message=f"Reused {n_reused_tokens} tokens from the cache",
)
common_prefix += n_reused_tokens

# exclude some tokens from end, e.g. for kicking off generation
if common_prefix >= len(prompt_tokens) - num_tokens_to_exclude:
common_prefix = len(prompt_tokens) - num_tokens_to_exclude

# Trim the cache if the common prefix is shorter than the current cache
num_tokens_to_trim = self.cache[0].offset - common_prefix
# state[0] is an alias for keys that accounts for partially filled buffers
num_tokens_to_trim = self.cache[0].state[0].shape[2] - common_prefix
if num_tokens_to_trim > 0:
if not can_trim_prompt_cache(self.cache):
log_warn(
prefix="CacheWrapper",
message=f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: "
"Cache is not trimmable. Clearing the cache instead.",
)
self.cache = make_prompt_cache(self.model, self.max_kv_size)
self.cache = make_prompt_cache(
self.model, self.max_kv_size, keep=self.keep
)
self.tokens = prompt_tokens
return self.tokens
tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim)
@@ -126,7 +191,9 @@ def _get_unprocessed_tokens(
message=f"Tokens trimmed from cache ({tokens_trimmed}) is less than expected "
" ({num_tokens_to_trim}). Clearing the cache.",
)
self.cache = make_prompt_cache(self.model, self.max_kv_size)
self.cache = make_prompt_cache(
self.model, self.max_kv_size, keep=self.keep
)
self.tokens = prompt_tokens
return self.tokens
log_info(
@@ -221,9 +288,9 @@ def set_draft_model(self, draft_model: nn.Module):
message="Clearing current prompt cache and adding draft model to the cache",
)
self.tokens = None
self.cache: List[Any] = make_prompt_cache(self.model)
self.cache: List[Any] = make_prompt_cache(self.model, keep=self.keep)
if draft_model is not None:
self.cache += make_prompt_cache(draft_model)
self.cache += make_prompt_cache(draft_model, keep=self.keep)
self.draft_model = draft_model

def unset_draft_model(self):
@@ -239,6 +306,7 @@ def update_cache(
prompt_progress_callback,
*,
num_tokens_to_exclude: int = 1,
keep: int = 4,
) -> mx.array:
"""
Set up the KV cache for the next generation.
@@ -248,6 +316,7 @@ def update_cache(
prompt_tokens (mx.array): The prompt tokens.
prompt_progress_callback (Callable): A callback function to report prompt processing progress.
num_tokens_to_exclude (int): The number of tokens that should not be added to the cache.
keep (int): The number of tokens to always keep in the prefix of the prompt cache.

Returns:
mx.array: The prompt tokens to be used for the next generation.
@@ -257,6 +326,12 @@ def update_cache(
def prompt_progress_callback(x):
return None

# update keep tracking
self.keep = keep
for cache in self.cache:
if hasattr(cache, "set_keep"):
cache.set_keep(keep)

num_tokens_to_exclude = max(num_tokens_to_exclude, 1)
prompt_tokens = self._get_unprocessed_tokens(
prompt_tokens, num_tokens_to_exclude
@@ -296,5 +371,13 @@ def prompt_progress_callback(x):
def record_generated_token(self, token):
"""
Add the generated token to the token list, so that we can map the token to the KV cache.

Also rotate when the cache rotates, so that the token list accurately tracks what is in the cache.
"""
# this behavior is common to rolling window (n_keep = 0) and truncate middle
# (n_keep > 0), and we should never get here with stop at max
if len(self.tokens) >= self.max_kv_size:
self.tokens = mx.concat(
[self.tokens[: self.keep], self.tokens[self.keep + 1 :]]
)
self.tokens = mx.concat([self.tokens, mx.array([token])])
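A tiny plain-Python illustration of the rotation above (the keep and max_kv_size values are arbitrary): when the token list is full, the token immediately after the kept prefix is dropped before the new one is appended.

tokens, keep, max_kv_size = [0, 1, 2, 3, 4, 5], 2, 6
new_token = 6
if len(tokens) >= max_kv_size:
    tokens = tokens[:keep] + tokens[keep + 1 :]
tokens.append(new_token)
print(tokens)  # [0, 1, 3, 4, 5, 6]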
4 changes: 4 additions & 0 deletions mlx_engine/generate.py
@@ -137,6 +137,7 @@ def create_generator(
max_tokens: Optional[int] = 10000000,
speculative_decoding_toggle: Optional[bool] = None,
num_draft_tokens: Optional[int] = None,
keep: Optional[int] = 4,
) -> Iterator[GenerationResult]:
"""
Create a generator that streams text generation results from the model.
@@ -171,6 +172,8 @@ def create_generator(
if a draft model is loaded. If set to true, draft model must be loaded or else error.
If set to false, speculative decoding is disabled even if a draft model is loaded.
num_draft_tokens (Optional[int]): Number of tokens to draft when using speculative decoding
keep (Optional[int]): Number of tokens to always keep in the prefix of the prompt cache.
Defaults to 4, which is the minimum number of tokens needed for a valid prompt.

Yields:
GenerationResult: A named tuple containing:
@@ -218,6 +221,7 @@ def create_generator(
prompt_progress_callback,
generate_args,
speculative_decoding_toggle,
keep=keep,
)
if draft_model is None:
# input embeddings not yet supported for speculative decoding in mlx-lm
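For context, a hypothetical caller could thread the new parameter through as below; model loading and tokenization are elided, and everything except max_tokens and keep (which appear in this diff) is assumed from the surrounding API rather than shown here.

from mlx_engine.generate import create_generator

def stream_with_keep(model_kit, prompt_tokens):
    # model_kit / prompt_tokens are assumed to come from the usual mlx_engine
    # load/tokenize helpers (not shown); the GenerationResult field name
    # `text` is assumed as well.
    for result in create_generator(
        model_kit,
        prompt_tokens,
        max_tokens=256,
        keep=8,  # pin the first 8 prompt tokens in the rotating KV cache
    ):
        print(result.text, end="")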