[BUG] Qwen 3.5 VLM model crashes on subsequent requests with manual KV cache prefix matching #157

@neon52

Description

Describe the bug
When using VLMModelFactory to load a multimodal model (specifically Qwen 3.5), reusing a ModelContainer for multiple sequential requests fails if manual KV cache prefix matching is employed.

While this pattern (finding a matching KV cache "slot" and initializing a TokenIterator with remainingTokens) works reliably for text-only models, VLM models appear to maintain internal state that is not cleared or updated correctly between perform blocks unless the entire ModelContainer is reloaded via VLMModelFactory.shared.loadContainer().

To Reproduce

  1. Load a Qwen 3.5 VLM model (for example: https://huggingface.co/mlx-community/Qwen3.5-9B-MLX-4bit) using VLMModelFactory.shared.loadContainer.
  2. Perform the first inference request using any KV cache (this always succeeds because the ModelContainer is fresh).
  3. For the second request, load a KV cache different from the one used for the first request and attempt to initialize a TokenIterator for the remaining tokens.

On step 3 the app crashes with a broadcast_shapes mismatch error. I can work around the issue by completely reloading the model before step 3, but that somewhat defeats the purpose of KV cache reuse.
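For context, the prefix-matching step in (2) and (3) is a plain longest-common-prefix search over previously saved token sequences. A minimal sketch of what my `findBestCache` does (the `CacheSlot` type and `savedSlots` store are names from my app, not part of mlx-swift, and the saved KVCache itself is elided here):

```swift
// Hypothetical cache slot: the token sequence that produced a saved KV cache.
// The real slot also holds the [KVCache] from MLXLMCommon; omitted in this sketch.
struct CacheSlot {
    var tokens: [Int]  // prompt tokens already processed into the saved cache
}

var savedSlots: [CacheSlot] = []

/// Returns the slot whose tokens are the longest full prefix of `promptTokens`,
/// plus the suffix of tokens that still needs to be fed to the model.
func findBestCache(_ promptTokens: [Int]) -> (slot: CacheSlot?, remaining: [Int]) {
    var best: CacheSlot? = nil
    var bestLen = 0
    for slot in savedSlots {
        // Only a complete match of the slot's tokens counts: the KV cache holds
        // state for exactly slot.tokens, so a partial match is unusable without
        // trimming the cache.
        if promptTokens.count >= slot.tokens.count,
            Array(promptTokens.prefix(slot.tokens.count)) == slot.tokens,
            slot.tokens.count > bestLen
        {
            best = slot
            bestLen = slot.tokens.count
        }
    }
    return (best, Array(promptTokens.dropFirst(bestLen)))
}
```

When no slot matches, `remaining` is the full prompt and a fresh cache is created; this path works for VLM models too, which is why the first request always succeeds.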

Example code snippet

// Current workaround: uncommenting the lines below for VLM models only makes everything work
//let freshModelContainer = try await VLMModelFactory.shared.loadContainer(configuration: modelContainer.configuration)
//modelContainer = freshModelContainer

// Find the best matching KV cache or create a new one
let preparedUserInput = try await modelContainer.prepare(input: userInput)
let preparedTokens = preparedUserInput.text.tokens.asArray(Int.self)
let (bestCache, remainingTokens) = findBestCache(preparedTokens)

// Perform inference using the loaded KV cache
try await modelContainer.perform { context in // perform rethrows; TokenIterator.init can throw
    let tokenIterator = try TokenIterator(
        input: LMInput(tokens: MLXArray(remainingTokens).reshaped([1, -1])), 
        model: context.model,
        cache: bestCache
    )
    generateTokenTask(
        promptTokenCount: preparedTokens.count, // use the full prompt size here instead of remainingTokens.count so the stats are correct
        modelConfiguration: context.configuration,
        tokenizer: context.tokenizer,
        iterator: tokenIterator
    )
}

I should reiterate that this same code works fine across multiple requests for every regular text-only LLM I have tried, so the issue appears to be unique to VLM models. VLM models also work reliably with the reload workaround shown in the code snippet above, but that makes things much slower.

Expected behavior
VLMModel should allow for TokenIterator initialization from an existing KVCache state, provided the prefix matches, without requiring a full container reload. Internal vision-related states (like image grid offsets or positional indicators) should be reset or exposed so they can be managed alongside the KV cache.
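One possible shape for that, purely as a sketch (none of these names exist in mlx-swift today, and the stub class only stands in for the real Qwen VLM to show intent): a protocol the VLM model type could adopt so callers can clear per-request vision state before reusing an externally managed KV cache.

```swift
// Hypothetical protocol; not part of MLXVLM. Sketches one way the per-request
// vision state (image grid offsets, rotary position indices, etc.) could be
// exposed for manual KV cache management.
protocol VisionStateResettable {
    /// Clear any per-request vision state so the next TokenIterator can start
    /// cleanly from a previously saved KV cache.
    func resetVisionState()
}

// Stub standing in for the real VLM model, just to illustrate the contract.
final class StubVLM: VisionStateResettable {
    private(set) var visionTokenOffset = 42  // leftover state from a prior request
    func resetVisionState() { visionTokenOffset = 0 }
}
```

Caller side, this would slot in right before building the TokenIterator, e.g. `(context.model as? VisionStateResettable)?.resetVisionState()`.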
