Description
Describe the bug
When using VLMModelFactory to load a multimodal model (specifically Qwen 3.5), reusing a ModelContainer for multiple sequential requests fails if manual KV cache prefix matching is employed.
While this pattern (finding a matching KV cache "slot" and initializing a TokenIterator with remainingTokens) works reliably for text-only models, VLM models appear to maintain internal state that is not cleared or updated correctly between perform blocks unless the entire ModelContainer is reloaded via VLMModelFactory.shared.loadContainer().
To Reproduce
- Load a Qwen 3.5 VLM model (for example: https://huggingface.co/mlx-community/Qwen3.5-9B-MLX-4bit) using VLMModelFactory.shared.loadContainer.
- Perform the first inference request using any KV cache (it always succeeds, since the ModelContainer is fresh).
- For the second request, load a KV cache that is different from the one used for the first request and attempt to initialize a TokenIterator for the remaining tokens.
On step 3 the app crashes with a broadcast_shapes mismatch error. Note that I can work around the issue by completely reloading the model before step 3, but that somewhat defeats the purpose of KV cache reuse.
Example code snippet
```swift
// Current workaround: uncommenting the two lines below for VLM models makes everything work
// let freshModelContainer = try await VLMModelFactory.shared.loadContainer(configuration: modelContainer.configuration)
// modelContainer = freshModelContainer

// Find the best matching KV cache or create a new one
let preparedUserInput = try await modelContainer.prepare(input: userInput)
let preparedTokens = preparedUserInput.text.tokens.asArray(Int.self)
let (bestCache, remainingTokens) = findBestCache(preparedTokens)

// Perform inference using the loaded KV cache
try await modelContainer.perform { context in
    let tokenIterator = try TokenIterator(
        input: LMInput(tokens: MLXArray(remainingTokens).reshaped([1, -1])),
        model: context.model,
        cache: bestCache
    )
    generateTokenTask(
        promptTokenCount: preparedTokens.count, // use the full prompt size instead of remainingTokens.count so the stats are correct
        modelConfiguration: context.configuration,
        tokenizer: context.tokenizer,
        iterator: tokenIterator
    )
}
```
I should reiterate that this same code works fine across multiple requests with every regular (text-only) LLM model I have tried, so the issue appears to be unique to VLM models. It also works reliably for VLM models with the workaround in the snippet above, but that makes things much slower.
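For context, the prefix matching behind findBestCache can be sketched as below. This is a minimal, dependency-free illustration with hypothetical names (CacheSlot, bestSlot are not mlx-swift-lm APIs): each saved slot remembers the token sequence already in its KV cache, and a slot is reusable only if its entire cached sequence is a prefix of the new prompt.

```swift
// Hypothetical sketch of KV cache prefix matching; names are illustrative,
// not part of the actual mlx-swift-lm API.
struct CacheSlot {
    let id: Int
    var cachedTokens: [Int]   // tokens whose keys/values are already in this slot's KV cache
}

func bestSlot(for prompt: [Int], among slots: [CacheSlot]) -> (slot: CacheSlot?, remaining: [Int]) {
    var best: CacheSlot? = nil
    var bestLen = 0
    for slot in slots {
        // Length of the shared prefix between this slot's cached tokens and the prompt.
        let n = zip(slot.cachedTokens, prompt).prefix(while: { $0 == $1 }).count
        // Only a slot whose *entire* cached sequence is a prompt prefix is reusable.
        if n == slot.cachedTokens.count, n > bestLen {
            best = slot
            bestLen = n
        }
    }
    // The caller feeds only the remaining (uncached) tokens to the TokenIterator.
    return (best, Array(prompt.dropFirst(bestLen)))
}
```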
Expected behavior
VLM models should allow TokenIterator initialization from an existing KV cache state, provided the prefix matches, without requiring a full container reload. Internal vision-related state (such as image grid offsets or positional indicators) should be reset or exposed so it can be managed alongside the KV cache.
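One possible shape for this is sketched below, with entirely hypothetical names (VisionState, VLMCacheSlot, restore are not current mlx-swift-lm APIs): snapshot whatever per-request vision state the model keeps next to each KV cache slot, and restore both together before reuse instead of reloading the whole container.

```swift
// Hypothetical sketch: bundle the VLM's per-request state with the KV cache
// so both can be saved and restored together. Names are illustrative only.
struct VisionState {
    var imageGridOffset: Int = 0
    var positionOffset: Int = 0
}

struct VLMCacheSlot {
    var cachedTokens: [Int]       // tokens already in this slot's KV cache
    var visionState: VisionState  // snapshot taken when the request finished
}

// Restoring a slot would reset the model's internal vision state to the
// snapshot before building the TokenIterator, avoiding a full reload.
func restore(_ slot: VLMCacheSlot, into state: inout VisionState) {
    state = slot.visionState
}
```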
Desktop (please complete the following information):
- OS Version: macOS 15
- Device: M4 MacBook Pro
- mlx-swift-lm Version: latest main from 3 days ago (https://github.com/ml-explore/mlx-swift-lm/tree/3a7503df30e2cd100be933322dcf9b25d6459803)