@@ -220,6 +220,77 @@ class SchedulerTokenHandlerTests: XCTestCase {
220220 }
221221 }
222222
223+ // MARK: - Stop Token Accounting
224+
225+ /// Verifies that when `includeStopToken: true`, the stop token is included
226+ /// in the stream output count — matching the accounting fix in
227+ /// InferenceScheduler where tokenCount/generatedTokenIds must include it.
228+ func testRawTokenHandlerIncludeStopTokenCountsInOutput( ) async {
229+ let ( stream, continuation) = AsyncStream< TokenGeneration> . makeStream( )
230+
231+ let handler = SchedulerTokenHandler . rawToken (
232+ continuation: continuation,
233+ includeStopToken: true
234+ )
235+
236+ // Verify mode allows the scheduler to gate on it
237+ if case . rawTokens( let includeStop) = handler. mode {
238+ XCTAssertTrue ( includeStop)
239+ } else {
240+ XCTFail ( " Expected .rawTokens mode " )
241+ }
242+
243+ XCTAssertTrue ( handler. processToken ( 10 ) )
244+ XCTAssertTrue ( handler. processToken ( 20 ) )
245+ // Stop token should be emitted and counted
246+ XCTAssertTrue ( handler. processStopToken ( 0 ) )
247+ handler. finish ( )
248+
249+ var allTokens = [ Int] ( )
250+ for await gen in stream {
251+ if case . token( let id) = gen {
252+ allTokens. append ( id)
253+ }
254+ }
255+
256+ // 2 regular tokens + 1 stop token = 3 total
257+ XCTAssertEqual ( allTokens, [ 10 , 20 , 0 ] )
258+ XCTAssertEqual ( allTokens. count, 3 , " Stop token must be counted in output " )
259+ }
260+
261+ /// Verifies that when `includeStopToken: false`, the stop token is NOT in
262+ /// the stream — the scheduler should not count it in tokenCount either.
263+ func testRawTokenHandlerExcludeStopTokenOmitsFromOutput( ) async {
264+ let ( stream, continuation) = AsyncStream< TokenGeneration> . makeStream( )
265+
266+ let handler = SchedulerTokenHandler . rawToken (
267+ continuation: continuation,
268+ includeStopToken: false
269+ )
270+
271+ if case . rawTokens( let includeStop) = handler. mode {
272+ XCTAssertFalse ( includeStop)
273+ } else {
274+ XCTFail ( " Expected .rawTokens mode " )
275+ }
276+
277+ XCTAssertTrue ( handler. processToken ( 10 ) )
278+ XCTAssertTrue ( handler. processToken ( 20 ) )
279+ XCTAssertTrue ( handler. processStopToken ( 0 ) )
280+ handler. finish ( )
281+
282+ var allTokens = [ Int] ( )
283+ for await gen in stream {
284+ if case . token( let id) = gen {
285+ allTokens. append ( id)
286+ }
287+ }
288+
289+ // Only 2 regular tokens, stop token omitted
290+ XCTAssertEqual ( allTokens, [ 10 , 20 ] )
291+ XCTAssertEqual ( allTokens. count, 2 , " Stop token must NOT be counted in output " )
292+ }
293+
223294 // MARK: - Cancellation
224295
225296 func testOnCancellationCallbackFires( ) async {
@@ -236,12 +307,12 @@ class SchedulerTokenHandlerTests: XCTestCase {
236307 expectation. fulfill ( )
237308 }
238309
239- // Trigger cancellation by dropping the stream consumer
240- // (this calls finish which triggers onTermination)
241- continuation. finish ( )
310+ // Start a consumer task then cancel it — this triggers .cancelled
311+ let task = Task {
312+ for await _ in stream { }
313+ }
314+ task. cancel ( )
242315
243- // The stream should complete; the onTermination is only triggered on
244- // .cancelled, not .finished. So we just verify it doesn't crash.
245- for await _ in stream { }
316+ await fulfillment ( of: [ expectation] , timeout: 2.0 )
246317 }
247318}
0 commit comments