Skip to content

Commit 15ede77

Browse files
committed
Add raw token batching
1 parent 0ae9d19 commit 15ede77

File tree

2 files changed

+89
-6
lines changed

2 files changed

+89
-6
lines changed

Libraries/MLXLMCommon/Batching/InferenceScheduler.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,10 @@ public actor InferenceScheduler {
796796
}
797797

798798
if token == unknownTokenId || stopTokenIDs.contains(token) {
799+
if case .rawTokens(includeStopToken: true) = handler.mode {
800+
tokenCount += 1
801+
generatedTokenIds.append(token)
802+
}
799803
// For raw-token mode, emit stop token if requested
800804
_ = handler.processStopToken(token)
801805
stopReason = .stop
@@ -996,6 +1000,10 @@ public actor InferenceScheduler {
9961000
}
9971001

9981002
if token == unknownTokenId || stopTokenIDs.contains(token) {
1003+
if case .rawTokens(includeStopToken: true) = handler.mode {
1004+
tokenCount += 1
1005+
generatedTokenIds.append(token)
1006+
}
9991007
_ = handler.processStopToken(token)
10001008
stopReason = .stop
10011009
break
@@ -1326,6 +1334,10 @@ public actor InferenceScheduler {
13261334
if stopTokenIDs.contains(token)
13271335
|| token == tokenizer.unknownTokenId
13281336
{
1337+
if case .rawTokens(includeStopToken: true) = handler.mode {
1338+
tokenCounts[uid, default: 0] += 1
1339+
generatedTokenIds[uid, default: []].append(token)
1340+
}
13291341
// For raw-token mode, emit stop token if requested
13301342
_ = handler.processStopToken(token)
13311343
} else {

Tests/MLXLMTests/SchedulerTokenHandlerTests.swift

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,77 @@ class SchedulerTokenHandlerTests: XCTestCase {
220220
}
221221
}
222222

223+
// MARK: - Stop Token Accounting
224+
225+
/// Verifies that a raw-token handler created with `includeStopToken: true`
226+
/// emits the stop token on its stream after the regular tokens. InferenceScheduler
227+
/// gates its tokenCount/generatedTokenIds accounting on this same handler mode.
228+
func testRawTokenHandlerIncludeStopTokenCountsInOutput() async {
229+
let (stream, continuation) = AsyncStream<TokenGeneration>.makeStream()
230+
231+
let handler = SchedulerTokenHandler.rawToken(
232+
continuation: continuation,
233+
includeStopToken: true
234+
)
235+
236+
// Verify the handler exposes its mode, which the scheduler pattern-matches to decide stop-token accounting
237+
if case .rawTokens(let includeStop) = handler.mode {
238+
XCTAssertTrue(includeStop)
239+
} else {
240+
XCTFail("Expected .rawTokens mode")
241+
}
242+
243+
XCTAssertTrue(handler.processToken(10))
244+
XCTAssertTrue(handler.processToken(20))
245+
// Stop token should be emitted and counted
246+
XCTAssertTrue(handler.processStopToken(0))
247+
handler.finish()
248+
249+
var allTokens = [Int]()
250+
for await gen in stream {
251+
if case .token(let id) = gen {
252+
allTokens.append(id)
253+
}
254+
}
255+
256+
// 2 regular tokens + 1 stop token = 3 total
257+
XCTAssertEqual(allTokens, [10, 20, 0])
258+
XCTAssertEqual(allTokens.count, 3, "Stop token must be counted in output")
259+
}
260+
261+
/// Verifies that when `includeStopToken: false`, the stop token is NOT in
262+
/// the stream — the scheduler should not count it in tokenCount either.
263+
func testRawTokenHandlerExcludeStopTokenOmitsFromOutput() async {
264+
let (stream, continuation) = AsyncStream<TokenGeneration>.makeStream()
265+
266+
let handler = SchedulerTokenHandler.rawToken(
267+
continuation: continuation,
268+
includeStopToken: false
269+
)
270+
271+
if case .rawTokens(let includeStop) = handler.mode {
272+
XCTAssertFalse(includeStop)
273+
} else {
274+
XCTFail("Expected .rawTokens mode")
275+
}
276+
277+
XCTAssertTrue(handler.processToken(10))
278+
XCTAssertTrue(handler.processToken(20))
279+
XCTAssertTrue(handler.processStopToken(0))
280+
handler.finish()
281+
282+
var allTokens = [Int]()
283+
for await gen in stream {
284+
if case .token(let id) = gen {
285+
allTokens.append(id)
286+
}
287+
}
288+
289+
// Only 2 regular tokens, stop token omitted
290+
XCTAssertEqual(allTokens, [10, 20])
291+
XCTAssertEqual(allTokens.count, 2, "Stop token must NOT be counted in output")
292+
}
293+
223294
// MARK: - Cancellation
224295

225296
func testOnCancellationCallbackFires() async {
@@ -236,12 +307,12 @@ class SchedulerTokenHandlerTests: XCTestCase {
236307
expectation.fulfill()
237308
}
238309

239-
// Trigger cancellation by dropping the stream consumer
240-
// (this calls finish which triggers onTermination)
241-
continuation.finish()
310+
// Start a consumer task, then cancel it — cancelling the consumer should terminate the stream with .cancelled and fire the callback
311+
let task = Task {
312+
for await _ in stream {}
313+
}
314+
task.cancel()
242315

243-
// The stream should complete; the onTermination is only triggered on
244-
// .cancelled, not .finished. So we just verify it doesn't crash.
245-
for await _ in stream {}
316+
await fulfillment(of: [expectation], timeout: 2.0)
246317
}
247318
}

0 commit comments

Comments
 (0)