@@ -302,3 +302,169 @@ struct OpenAILanguageModelTests {
302302 }
303303 }
304304}

// MARK: - Streaming Tool Call (mocked)

/// Verifies that a streamed tool-call round trip works against a mocked OpenAI
/// backend: the first SSE response asks for a tool invocation, the session runs
/// the tool, and the second SSE response streams the final text.
@Suite("OpenAI streaming tool calls (mocked)")
struct OpenAIStreamingToolCallTests {
    private let baseURL = URL(string: "https://mock.openai.local")!

    @Test(.disabled("Streaming mock under construction")) func responsesStreamToolCallExecution() async throws {
        var responsesCallCount = 0
        URLProtocol.registerClass(MockOpenAIEventStreamURLProtocol.self)
        // Balance the global registration so the catch-all mock does not leak
        // into URL loading for other tests in this process.
        defer { URLProtocol.unregisterClass(MockOpenAIEventStreamURLProtocol.self) }
        MockOpenAIEventStreamURLProtocol.Handler.set { request in
            defer { responsesCallCount += 1 }
            let response = HTTPURLResponse(
                url: request.url!,
                statusCode: 200,
                httpVersion: nil,
                headerFields: ["Content-Type": "text/event-stream"]
            )!

            // Call 0: the model requests a tool call and finishes with
            // "tool_calls". Call 1: the model streams the final answer.
            let events: [String]
            if responsesCallCount == 0 {
                events = [
                    #"data: {"type":"response.tool_call.created","tool_call":{"id":"call_1","type":"function","function":{"name":"getWeather","arguments":""}}}"#,
                    #"data: {"type":"response.tool_call.delta","tool_call":{"id":"call_1","function":{"arguments":"{\"city\":\"San Francisco\"}"}}}"#,
                    #"data: {"type":"response.completed","finish_reason":"tool_calls"}"#,
                ]
            } else {
                events = [
                    #"data: {"type":"response.output_text.delta","delta":"Tool says: Sunny."}"#,
                    #"data: {"type":"response.completed","finish_reason":"stop"}"#,
                ]
            }

            // SSE frames are separated by a blank line; terminate the stream
            // with one as well.
            let payload = events.joined(separator: "\n\n") + "\n\n"
            return (response, [payload.data(using: .utf8)!])
        }
        defer { MockOpenAIEventStreamURLProtocol.Handler.clear() }

        let config = URLSessionConfiguration.ephemeral
        config.protocolClasses = [MockOpenAIEventStreamURLProtocol.self]

        let model = OpenAILanguageModel(
            baseURL: baseURL,
            apiKey: "test-key",
            model: "gpt-test",
            apiVariant: .responses,
            session: URLSession(configuration: config)
        )
        let session = LanguageModelSession(model: model, tools: [WeatherTool()])

        var snapshots: [LanguageModelSession.ResponseStream<String>.Snapshot] = []
        for try await snapshot in session.streamResponse(to: "What's the weather?") {
            snapshots.append(snapshot)
        }

        // At least two HTTP round trips: tool-call request + final answer.
        #expect(responsesCallCount >= 2)
    }

    @Test(.disabled("Streaming mock under construction")) func chatCompletionsStreamToolCallExecution() async throws {
        var chatCallCount = 0
        URLProtocol.registerClass(MockOpenAIEventStreamURLProtocol.self)
        // Balance the global registration so the catch-all mock does not leak
        // into URL loading for other tests in this process.
        defer { URLProtocol.unregisterClass(MockOpenAIEventStreamURLProtocol.self) }
        MockOpenAIEventStreamURLProtocol.Handler.set { request in
            defer { chatCallCount += 1 }
            let response = HTTPURLResponse(
                url: request.url!,
                statusCode: 200,
                httpVersion: nil,
                headerFields: ["Content-Type": "text/event-stream"]
            )!

            // Call 0: streamed tool-call deltas ending in "tool_calls".
            // Call 1: streamed content ending in "stop".
            let events: [String]
            if chatCallCount == 0 {
                events = [
                    #"data: {"id":"evt_1","choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"getWeather","arguments":""}}]},"finish_reason":null}]}"#,
                    #"data: {"id":"evt_1","choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1","function":{"arguments":"{\"city\":\"Paris\"}"}}]},"finish_reason":null}]}"#,
                    #"data: {"id":"evt_1","choices":[{"delta":{},"finish_reason":"tool_calls"}]}"#,
                ]
            } else {
                events = [
                    #"data: {"id":"evt_1","choices":[{"delta":{"content":"Tool says Paris is sunny."},"finish_reason":null}]}"#,
                    #"data: {"id":"evt_1","choices":[{"delta":{},"finish_reason":"stop"}]}"#,
                ]
            }

            // SSE frames are separated by a blank line; terminate the stream
            // with one as well.
            let payload = events.joined(separator: "\n\n") + "\n\n"
            return (response, [payload.data(using: .utf8)!])
        }
        defer { MockOpenAIEventStreamURLProtocol.Handler.clear() }

        let config = URLSessionConfiguration.ephemeral
        config.protocolClasses = [MockOpenAIEventStreamURLProtocol.self]

        let model = OpenAILanguageModel(
            baseURL: baseURL,
            apiKey: "test-key",
            model: "gpt-test",
            apiVariant: .chatCompletions,
            session: URLSession(configuration: config)
        )
        let session = LanguageModelSession(model: model, tools: [WeatherTool()])

        var snapshots: [LanguageModelSession.ResponseStream<String>.Snapshot] = []
        for try await snapshot in session.streamResponse(to: "What's the weather?") {
            snapshots.append(snapshot)
        }

        // At least two HTTP round trips: tool-call request + final answer.
        #expect(chatCallCount >= 2)
    }
}
/// A `URLProtocol` stub that replays a canned HTTP response plus a sequence of
/// data chunks (e.g. SSE frames) for any request it intercepts.
///
/// Install a closure with `Handler.set(_:)` before the request fires and call
/// `Handler.clear()` afterwards; with no handler installed, loading fails with
/// `URLError(.badServerResponse)`.
private final class MockOpenAIEventStreamURLProtocol: URLProtocol {
    /// Lock-protected storage for the single request handler shared by all
    /// protocol instances.
    enum Handler {
        nonisolated(unsafe) private static var handler: ((URLRequest) -> (HTTPURLResponse, [Data]))?
        private static let lock = NSLock()

        /// Installs `handler` as the responder for all intercepted requests.
        static func set(_ handler: @escaping (URLRequest) -> (HTTPURLResponse, [Data])) {
            lock.lock()
            defer { lock.unlock() }
            Self.handler = handler
        }

        /// Removes the installed handler, if any.
        static func clear() {
            lock.lock()
            defer { lock.unlock() }
            handler = nil
        }

        /// Invokes the installed handler for `request`, or returns `nil` when
        /// none is installed. The handler runs while the lock is held.
        static func handle(_ request: URLRequest) -> (HTTPURLResponse, [Data])? {
            lock.lock()
            defer { lock.unlock() }
            return handler?(request)
        }
    }

    // Intercept every request unconditionally.
    override class func canInit(with request: URLRequest) -> Bool { true }

    override class func canInit(with task: URLSessionTask) -> Bool {
        guard let request = task.currentRequest else { return false }
        return canInit(with: request)
    }

    // Requests are served verbatim; no canonicalization needed.
    override class func canonicalRequest(for request: URLRequest) -> URLRequest { request }

    override func startLoading() {
        guard let (response, chunks) = Handler.handle(request) else {
            client?.urlProtocol(self, didFailWithError: URLError(.badServerResponse))
            return
        }
        client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed)
        chunks.forEach { client?.urlProtocol(self, didLoad: $0) }
        client?.urlProtocolDidFinishLoading(self)
    }

    // Nothing to cancel: all data is delivered synchronously in startLoading().
    override func stopLoading() {}
}