Skip to content

Commit 1df5086

Browse files
committed
improve dual path routing
1 parent e9a3730 commit 1df5086

File tree

5 files changed

+191
-10
lines changed

5 files changed

+191
-10
lines changed

Libraries/MLXLMCommon/Batching/InferenceScheduler.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ public actor InferenceScheduler {
565565
/// Check if a request is compatible with batch generation.
566566
///
567567
/// Returns `false` for:
568-
/// - VLMs (input contains images or video)
568+
/// - Multimodal inputs (images or video)
569569
/// - Hybrid SSM models (cache contains `MambaCache` or `CacheList`)
570570
/// - Requests with `kvBits` set (QuantizedKVCache incompatible)
571571
/// - Caches containing `QuantizedKVCache`
@@ -578,7 +578,7 @@ public actor InferenceScheduler {
578578
cache: [KVCache]?,
579579
model: any LanguageModel
580580
) -> Bool {
581-
// VLM check: images or video present
581+
// Multimodal check: images or video present
582582
if input.image != nil || input.video != nil {
583583
return false
584584
}

Libraries/MLXLMCommon/ModelContainer.swift

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import Tokenizers
3333
/// ```
3434
public final class ModelContainer: Sendable {
3535
private let context: SerialAccessContainer<ModelContext>
36+
private let loadedAsVLM: Bool
3637

3738
/// Optional inference scheduler for transparent batching support.
3839
///
@@ -71,6 +72,7 @@ public final class ModelContainer: Sendable {
7172
}
7273

7374
public init(context: consuming ModelContext, scheduler: InferenceScheduler? = nil) {
75+
self.loadedAsVLM = context.loadedAsVLM
7476
self.context = .init(context)
7577
self.scheduler = scheduler
7678
}
@@ -196,10 +198,10 @@ public final class ModelContainer: Sendable {
196198
let input = SendableBox(input)
197199

198200
// When a scheduler is set, route through InferenceScheduler for
199-
// transparent batching. The scheduler handles batch compatibility
200-
// checks internally — incompatible requests (VLMs, kvBits, SSM models)
201-
// automatically fall back to the single TokenIterator path.
202-
if let scheduler {
201+
// transparent batching. VLMs are excluded at this level (!loadedAsVLM);
202+
// the scheduler handles remaining compatibility checks (multimodal
203+
// inputs, kvBits, SSM models) and falls back to single TokenIterator.
204+
if let scheduler, !loadedAsVLM {
203205
let lmInput = input.consume()
204206

205207
// Read model, tokenizer, and configuration from the context.

Libraries/MLXLMCommon/ModelFactory.swift

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,18 @@ public struct ModelContext {
6868
public var model: any LanguageModel
6969
public var processor: any UserInputProcessor
7070
public var tokenizer: Tokenizer
71+
public var loadedAsVLM: Bool
7172

7273
public init(
7374
configuration: ModelConfiguration, model: any LanguageModel,
74-
processor: any UserInputProcessor, tokenizer: any Tokenizer
75+
processor: any UserInputProcessor, tokenizer: any Tokenizer,
76+
loadedAsVLM: Bool = false
7577
) {
7678
self.configuration = configuration
7779
self.model = model
7880
self.processor = processor
7981
self.tokenizer = tokenizer
82+
self.loadedAsVLM = loadedAsVLM
8083
}
8184
}
8285

@@ -364,11 +367,11 @@ final public class ModelFactoryRegistry: @unchecked Sendable {
364367
private init() {
365368
self.trampolines = [
366369
{
367-
(NSClassFromString("MLXVLM.TrampolineModelFactory") as? ModelFactoryTrampoline.Type)?
370+
(NSClassFromString("MLXLLM.TrampolineModelFactory") as? ModelFactoryTrampoline.Type)?
368371
.modelFactory()
369372
},
370373
{
371-
(NSClassFromString("MLXLLM.TrampolineModelFactory") as? ModelFactoryTrampoline.Type)?
374+
(NSClassFromString("MLXVLM.TrampolineModelFactory") as? ModelFactoryTrampoline.Type)?
372375
.modelFactory()
373376
},
374377
]

Libraries/MLXVLM/VLMModelFactory.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ public final class VLMModelFactory: ModelFactory {
377377

378378
return .init(
379379
configuration: mutableConfiguration, model: model, processor: processor,
380-
tokenizer: tokenizer)
380+
tokenizer: tokenizer, loadedAsVLM: true)
381381
}
382382

383383
}
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
import Foundation
4+
import MLX
5+
@preconcurrency @testable import MLXLMCommon
6+
import MLXNN
7+
import Tokenizers
8+
import XCTest
9+
10+
// MARK: - Factory Resolution Order Tests
11+
12+
class DualPathRoutingTests: XCTestCase {
13+
14+
/// Verify that ModelFactoryRegistry lists LLM before VLM by default.
15+
///
16+
/// The default trampoline order should try MLXLLM first, then MLXVLM.
17+
/// This ensures dual-path models (e.g. Qwen 3.5) resolve as LLM
18+
/// when loaded via the generic `loadModel`/`loadModelContainer` APIs.
19+
func testFactoryRegistryPrefersLLMOverVLM() {
20+
let factories = ModelFactoryRegistry.shared.modelFactories()
21+
22+
// Both factories should be available in the test environment
23+
guard factories.count >= 2 else {
24+
// In unit test context without both modules linked, we can at least
25+
// verify the trampoline array order via the registry's public API.
26+
// If only one factory is available, the ordering test is moot.
27+
return
28+
}
29+
30+
// The first factory should be the LLM factory.
31+
// LLMModelFactory's modelRegistry is LLMRegistry; VLMModelFactory's is VLMRegistry.
32+
let firstFactory = factories[0]
33+
let secondFactory = factories[1]
34+
35+
// LLMModelFactory uses LLMRegistry, VLMModelFactory uses VLMRegistry.
36+
// We distinguish by checking the type name of the model registry.
37+
let firstName = String(describing: type(of: firstFactory))
38+
let secondName = String(describing: type(of: secondFactory))
39+
40+
XCTAssertTrue(
41+
firstName.contains("LLM"),
42+
"First factory should be LLM, got \(firstName)")
43+
XCTAssertTrue(
44+
secondName.contains("VLM"),
45+
"Second factory should be VLM, got \(secondName)")
46+
}
47+
48+
// MARK: - VLM-Loaded Container Bypasses Scheduler
49+
50+
/// A minimal mock model for testing the VLM guard in ModelContainer.generate().
51+
private class MinimalMockModel: Module, LanguageModel, KVCacheDimensionProvider,
52+
@unchecked Sendable
53+
{
54+
let vocabSize = 32
55+
var kvHeads: [Int] { [4] }
56+
57+
func prepare(_ input: LMInput, cache: [KVCache], windowSize: Int?) throws -> PrepareResult {
58+
.tokens(input.text)
59+
}
60+
61+
func callAsFunction(
62+
_ input: LMInput.Text, cache: [KVCache]?, state: LMOutput.State?
63+
) -> LMOutput {
64+
let B = input.tokens.dim(0)
65+
let S = input.tokens.dim(1)
66+
// Return logits with token 0 as the highest probability (will hit EOS quickly)
67+
var flat = [Float](repeating: -100.0, count: B * S * vocabSize)
68+
for i in stride(from: 0, to: flat.count, by: vocabSize) {
69+
flat[i] = 0.0 // token 0 = EOS
70+
}
71+
return LMOutput(logits: MLXArray(flat, [B, S, vocabSize]))
72+
}
73+
74+
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
75+
weights
76+
}
77+
}
78+
79+
/// Verify that a VLM-loaded ModelContainer with a scheduler set
80+
/// bypasses the scheduler and uses the direct TokenIterator path.
81+
func testVLMLoadedContainerBypassesScheduler() async throws {
82+
try skipIfMetalUnavailable()
83+
let model = MinimalMockModel()
84+
let tokenizer = TestTokenizer()
85+
let config = ModelConfiguration(id: "test-vlm-model")
86+
let processor = TestInputProcessor()
87+
88+
// Create a ModelContext with loadedAsVLM = true
89+
let context = ModelContext(
90+
configuration: config,
91+
model: model,
92+
processor: processor,
93+
tokenizer: tokenizer,
94+
loadedAsVLM: true
95+
)
96+
97+
// Create container WITH a scheduler — should be bypassed for VLM
98+
let scheduler = InferenceScheduler()
99+
let container = ModelContainer(context: context, scheduler: scheduler)
100+
101+
// The scheduler should be set on the container
102+
XCTAssertNotNil(container.scheduler, "Scheduler should be set on container")
103+
104+
// Submit a text-only request
105+
let input = LMInput(tokens: MLXArray([Int32(1), Int32(2), Int32(3)]))
106+
let params = GenerateParameters(maxTokens: 3, temperature: 0)
107+
108+
let stream = try await container.generate(
109+
input: input,
110+
parameters: params
111+
)
112+
113+
// The scheduler should NOT have been used — its state should still be idle
114+
let schedulerState = await scheduler.currentState
115+
XCTAssertEqual(
116+
schedulerState, "idle",
117+
"Scheduler should remain idle when container is VLM-loaded, got: \(schedulerState)")
118+
119+
// Consume the stream to verify it completes (via direct TokenIterator path)
120+
var receivedOutput = false
121+
for await generation in stream {
122+
if generation.chunk != nil || generation.info != nil {
123+
receivedOutput = true
124+
}
125+
}
126+
XCTAssertTrue(receivedOutput, "Should receive output via direct TokenIterator path")
127+
}
128+
129+
/// Verify that a non-VLM ModelContainer with a scheduler actually uses the scheduler.
130+
func testLLMLoadedContainerUsesScheduler() async throws {
131+
try skipIfMetalUnavailable()
132+
let model = MinimalMockModel()
133+
let tokenizer = TestTokenizer()
134+
let config = ModelConfiguration(id: "test-llm-model")
135+
let processor = TestInputProcessor()
136+
137+
// Create a ModelContext with loadedAsVLM = false (default)
138+
let context = ModelContext(
139+
configuration: config,
140+
model: model,
141+
processor: processor,
142+
tokenizer: tokenizer
143+
)
144+
145+
let scheduler = InferenceScheduler()
146+
let container = ModelContainer(context: context, scheduler: scheduler)
147+
148+
let input = LMInput(tokens: MLXArray([Int32(1), Int32(2), Int32(3)]))
149+
let params = GenerateParameters(maxTokens: 3, temperature: 0)
150+
151+
let stream = try await container.generate(
152+
input: input,
153+
parameters: params
154+
)
155+
156+
// The scheduler should have been used — its state should NOT be idle
157+
let schedulerState = await scheduler.currentState
158+
XCTAssertNotEqual(
159+
schedulerState, "idle",
160+
"Scheduler should be active for LLM-loaded container, got: \(schedulerState)")
161+
162+
// Consume the stream
163+
for await _ in stream {}
164+
}
165+
166+
/// Verify that ModelContext defaults loadedAsVLM to false.
167+
func testModelContextDefaultsLoadedAsVLMToFalse() {
168+
let context = ModelContext(
169+
configuration: ModelConfiguration(id: "test"),
170+
model: MinimalMockModel(),
171+
processor: TestInputProcessor(),
172+
tokenizer: TestTokenizer()
173+
)
174+
XCTAssertFalse(context.loadedAsVLM, "loadedAsVLM should default to false")
175+
}
176+
}

0 commit comments

Comments (0)