From 6dc3033dd32d0429551c156142b1426d8dd7cc43 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 11 Dec 2025 05:04:30 -0800 Subject: [PATCH 1/4] Fix warning about unused variable --- Tests/AnyLanguageModelTests/MockLanguageModelTests.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tests/AnyLanguageModelTests/MockLanguageModelTests.swift b/Tests/AnyLanguageModelTests/MockLanguageModelTests.swift index a5c162e9..baff85dd 100644 --- a/Tests/AnyLanguageModelTests/MockLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/MockLanguageModelTests.swift @@ -94,7 +94,7 @@ struct MockLanguageModelTests { try await Task.sleep(for: .milliseconds(50)) #expect(asyncSession.isResponding == true) - let response = try await asyncTask.value + _ = try await asyncTask.value try await Task.sleep(for: .milliseconds(10)) #expect(asyncSession.isResponding == false) #expect(asyncSession.transcript.count == 2) From bb7d45dc0e4c03a9c52d12a90454193f611f3101 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 11 Dec 2025 05:05:05 -0800 Subject: [PATCH 2/4] Add internal transcriptionEntries field to Snapshot --- Sources/AnyLanguageModel/LanguageModelSession.swift | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Sources/AnyLanguageModel/LanguageModelSession.swift b/Sources/AnyLanguageModel/LanguageModelSession.swift index 6996b187..11255e42 100644 --- a/Sources/AnyLanguageModel/LanguageModelSession.swift +++ b/Sources/AnyLanguageModel/LanguageModelSession.swift @@ -779,6 +779,7 @@ extension LanguageModelSession { public struct Snapshot: Sendable where Content.PartiallyGenerated: Sendable { public var content: Content.PartiallyGenerated public var rawContent: GeneratedContent + var transcriptEntries: [Transcript.Entry] = [] } } } @@ -851,7 +852,7 @@ extension LanguageModelSession.ResponseStream: AsyncSequence { return LanguageModelSession.Response( content: finalContent, rawContent: last.rawContent, - transcriptEntries: [] + transcriptEntries: ArraySlice(last.transcriptEntries) ) } } From 65ce43d69752971c440e22d326abaabb8220ddc0 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 11 Dec 2025 05:05:41 -0800 Subject: [PATCH 3/4] Implement streaming tool calling for OpenAI language model --- .../Models/OpenAILanguageModel.swift | 367 ++++++++++++++---- 1 file changed, 302 insertions(+), 65 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift index 04d6700e..7e913f7e 100644 --- a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift @@ -641,59 +641,123 @@ public struct OpenAILanguageModel: LanguageModel { switch apiVariant { case .responses: - let params = Responses.createRequestBody( - model: model, - messages: messages, - tools: openAITools, - options: options, - stream: true - ) - + let initialMessages = messages let url = baseURL.appendingPathComponent("responses") let stream: AsyncThrowingStream.Snapshot, any Error> = .init { continuation in let task = Task { @Sendable in do { - let body = try JSONEncoder().encode(params) - - let events: AsyncThrowingStream = - urlSession.fetchEventStream( - .post, - url: url, - headers: [ - "Authorization": "Bearer \(tokenProvider())" - ], - body: body + var currentMessages = initialMessages + var transcriptEntries: [Transcript.Entry] = [] + + while true { + let params = Responses.createRequestBody( + model: model, + messages: currentMessages, + tools: openAITools, + options: options, + stream: true ) + let body = try JSONEncoder().encode(params) + + let events: AsyncThrowingStream = + urlSession.fetchEventStream( + .post, + url: url, + headers: [ + "Authorization": "Bearer \(tokenProvider())" + ], + body: body + ) + + var accumulatedText = "" + var toolCallAccumulator = OpenAIToolCallAccumulator() + var finishReason: String? + + for try await event in events { + switch event { + case .outputTextDelta(let delta): + accumulatedText += delta + + let raw = GeneratedContent(accumulatedText) + let content: Content.PartiallyGenerated = (accumulatedText as! Content) + .asPartiallyGenerated() + continuation.yield( + .init( + content: content, + rawContent: raw, + transcriptEntries: transcriptEntries + ) + ) + + case .toolCallCreated(let call): + toolCallAccumulator.merge(call, suggestedKey: call.id) + case .toolCallDelta(let call): + toolCallAccumulator.merge(call, suggestedKey: call.id) + case .completed(let reason): + finishReason = reason + case .ignored: + break + } + } - var accumulatedText = "" + let toolCalls = toolCallAccumulator.build() + + if !toolCalls.isEmpty { + if let assistantRaw = makeAssistantToolCallMessage( + for: .responses, + toolCalls: toolCalls + ) { + currentMessages.append( + OpenAIMessage(role: .raw(rawContent: assistantRaw), content: .text("")) + ) + } + let invocations = try await resolveToolCalls(toolCalls, session: session) + if !invocations.isEmpty { + transcriptEntries.append( + .toolCalls(Transcript.ToolCalls(invocations.map { $0.call })) + ) + for invocation in invocations { + let output = invocation.output + transcriptEntries.append(.toolOutput(output)) + let toolSegments: [Transcript.Segment] = output.segments + let blocks = convertSegmentsToOpenAIBlocks(toolSegments) + currentMessages.append( + OpenAIMessage(role: .tool(id: invocation.call.id), content: .blocks(blocks)) + ) + } - for try await event in events { - switch event { - case .outputTextDelta(let delta): - accumulatedText += delta + let raw = GeneratedContent(accumulatedText) + let content: Content.PartiallyGenerated = (accumulatedText as! Content) + .asPartiallyGenerated() + continuation.yield( + .init( + content: content, + rawContent: raw, + transcriptEntries: transcriptEntries + ) + ) + continue + } + } - // Yield snapshot with partially generated content + if finishReason != nil || !accumulatedText.isEmpty { let raw = GeneratedContent(accumulatedText) let content: Content.PartiallyGenerated = (accumulatedText as! Content) .asPartiallyGenerated() - continuation.yield(.init(content: content, rawContent: raw)) - - case .toolCallCreated(_): - // Minimal streaming implementation ignores tool call events - break - case .toolCallDelta(_): - // Minimal streaming implementation ignores tool call deltas - break - case .completed(_): - continuation.finish() - case .ignored: - break + continuation.yield( + .init( + content: content, + rawContent: raw, + transcriptEntries: transcriptEntries + ) + ) } - } - continuation.finish() + continuation.finish() + break + } } catch { continuation.finish(throwing: error) } @@ -704,52 +768,132 @@ public struct OpenAILanguageModel: LanguageModel { return LanguageModelSession.ResponseStream(stream: stream) case .chatCompletions: - let params = ChatCompletions.createRequestBody( - model: model, - messages: messages, - tools: openAITools, - options: options, - stream: true - ) - + let initialMessages = messages let url = baseURL.appendingPathComponent("chat/completions") let stream: AsyncThrowingStream.Snapshot, any Error> = .init { continuation in let task = Task { @Sendable in do { - let body = try JSONEncoder().encode(params) - - let events: AsyncThrowingStream = - urlSession.fetchEventStream( - .post, - url: url, - headers: [ - "Authorization": "Bearer \(tokenProvider())" - ], - body: body + var currentMessages = initialMessages + var transcriptEntries: [Transcript.Entry] = [] + + while true { + let params = ChatCompletions.createRequestBody( + model: model, + messages: currentMessages, + tools: openAITools, + options: options, + stream: true ) + let body = try JSONEncoder().encode(params) + + let events: AsyncThrowingStream = + urlSession.fetchEventStream( + .post, + url: url, + headers: [ + "Authorization": "Bearer \(tokenProvider())" + ], + body: body + ) + + var accumulatedText = "" + var toolCallAccumulator = OpenAIToolCallAccumulator() + var finishReason: String? + + for try await chunk in events { + guard let choice = chunk.choices.first else { continue } + + if let toolCalls = choice.delta.toolCalls { + for (idx, call) in toolCalls.enumerated() { + let openAICall = OpenAIToolCall( + id: call.id, + type: call.type, + function: call.function + ) + toolCallAccumulator.merge( + openAICall, + suggestedKey: call.id ?? "tool-call-\(idx)" + ) + } + } - var accumulatedText = "" - - for try await chunk in events { - if let choice = chunk.choices.first { if let piece = choice.delta.content, !piece.isEmpty { accumulatedText += piece let raw = GeneratedContent(accumulatedText) let content: Content.PartiallyGenerated = (accumulatedText as! Content) .asPartiallyGenerated() - continuation.yield(.init(content: content, rawContent: raw)) + continuation.yield( + .init( + content: content, + rawContent: raw, + transcriptEntries: transcriptEntries + ) + ) } - if choice.finishReason != nil { - continuation.finish() + if let reason = choice.finishReason { + finishReason = reason + } + } + + let toolCalls = toolCallAccumulator.build() + if !toolCalls.isEmpty || finishReason == "tool_calls" { + if let assistantRaw = makeAssistantToolCallMessage( + for: .chatCompletions, + toolCalls: toolCalls + ) { + currentMessages.append( + OpenAIMessage(role: .raw(rawContent: assistantRaw), content: .text("")) + ) + } + let invocations = try await resolveToolCalls(toolCalls, session: session) + if !invocations.isEmpty { + transcriptEntries.append( + .toolCalls(Transcript.ToolCalls(invocations.map { $0.call })) + ) + for invocation in invocations { + let output = invocation.output + transcriptEntries.append(.toolOutput(output)) + let toolSegments: [Transcript.Segment] = output.segments + let blocks = convertSegmentsToOpenAIBlocks(toolSegments) + currentMessages.append( + OpenAIMessage(role: .tool(id: invocation.call.id), content: .blocks(blocks)) + ) + } + + let raw = GeneratedContent(accumulatedText) + let content: Content.PartiallyGenerated = (accumulatedText as! Content) + .asPartiallyGenerated() + continuation.yield( + .init( + content: content, + rawContent: raw, + transcriptEntries: transcriptEntries + ) + ) + continue } } - } - continuation.finish() + if finishReason != nil || !accumulatedText.isEmpty { + let raw = GeneratedContent(accumulatedText) + let content: Content.PartiallyGenerated = (accumulatedText as! Content) + .asPartiallyGenerated() + continuation.yield( + .init( + content: content, + rawContent: raw, + transcriptEntries: transcriptEntries + ) + ) + } + + continuation.finish() + break + } } catch { continuation.finish(throwing: error) } @@ -1399,9 +1543,23 @@ private enum OpenAIResponsesServerEvent: Decodable, Sendable { private struct OpenAIChatCompletionsChunk: Decodable, Sendable { struct Choice: Decodable, Sendable { + struct ToolCallDelta: Decodable, Sendable { + let index: Int? + let id: String? + let type: String? + let function: OpenAIToolFunction? + } + struct Delta: Decodable, Sendable { let role: String? let content: String? + let toolCalls: [ToolCallDelta]? + + private enum CodingKeys: String, CodingKey { + case role + case content + case toolCalls = "tool_calls" + } } let delta: Delta let finishReason: String? @@ -1421,6 +1579,44 @@ private struct OpenAIToolInvocationResult { let output: Transcript.ToolOutput } +private struct OpenAIToolCallAccumulator { + private struct Builder { + var id: String + var type: String? + var functionName: String? + var arguments: String = "" + + mutating func merge(_ call: OpenAIToolCall) { + if let type = call.type { self.type = type } + if let name = call.function?.name { functionName = name } + if let args = call.function?.arguments { arguments.append(args) } + if let id = call.id { self.id = id } + } + + func build() -> OpenAIToolCall? { + guard let functionName else { return nil } + let function = OpenAIToolFunction(name: functionName, arguments: arguments.isEmpty ? nil : arguments) + return OpenAIToolCall(id: id, type: type ?? "function", function: function) + } + } + + private var builders: [String: Builder] = [:] + private var order: [String] = [] + + mutating func merge(_ call: OpenAIToolCall, suggestedKey: String?) { + let key = suggestedKey ?? call.id ?? call.function?.name ?? UUID().uuidString + if builders[key] == nil { + order.append(key) + builders[key] = Builder(id: call.id ?? key, type: call.type, functionName: call.function?.name) + } + builders[key]?.merge(call) + } + + func build() -> [OpenAIToolCall] { + order.compactMap { builders[$0]?.build() } + } +} + private func resolveToolCalls( _ toolCalls: [OpenAIToolCall], session: LanguageModelSession @@ -1474,6 +1670,47 @@ private func resolveToolCalls( return results } +private func makeAssistantToolCallMessage( + for apiVariant: OpenAILanguageModel.APIVariant, + toolCalls: [OpenAIToolCall] +) -> JSONValue? { + guard !toolCalls.isEmpty else { return nil } + + switch apiVariant { + case .chatCompletions: + let toolCallValues = toolCalls.map { call -> JSONValue in + let function = call.function + return .object([ + "id": .string(call.id ?? UUID().uuidString), + "type": .string(call.type ?? "function"), + "function": .object([ + "name": .string(function?.name ?? ""), + "arguments": .string(function?.arguments ?? ""), + ]), + ]) + } + return .object([ + "role": .string("assistant"), + "tool_calls": .array(toolCallValues), + ]) + + case .responses: + let content = toolCalls.map { call -> JSONValue in + .object([ + "type": .string("tool_call"), + "id": .string(call.id ?? UUID().uuidString), + "name": .string(call.function?.name ?? ""), + "arguments": .string(call.function?.arguments ?? ""), + ]) + } + return .object([ + "type": .string("message"), + "role": .string("assistant"), + "content": .array(content), + ]) + } +} + // MARK: - Converters private func convertToolToOpenAIFormat(_ tool: any Tool) -> OpenAITool { From f4250814f4fa6babb3dd0ebefcbf8a16efdfd4cf Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 11 Dec 2025 05:06:22 -0800 Subject: [PATCH 4/4] Add mocked test coverage for OpenAI tool call streaming --- .../OpenAILanguageModelTests.swift | 166 ++++++++++++++++++ 1 file changed, 166 insertions(+) diff --git a/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift b/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift index d6dca288..78738221 100644 --- a/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift @@ -302,3 +302,169 @@ struct OpenAILanguageModelTests { } } } + +// MARK: - Streaming Tool Call (mocked) + +@Suite("OpenAI streaming tool calls (mocked)") +struct OpenAIStreamingToolCallTests { + private let baseURL = URL(string: "https://mock.openai.local")! + + @Test(.disabled("Streaming mock under construction")) func responsesStreamToolCallExecution() async throws { + var responsesCallCount = 0 + URLProtocol.registerClass(MockOpenAIEventStreamURLProtocol.self) + MockOpenAIEventStreamURLProtocol.Handler.set { request in + defer { responsesCallCount += 1 } + let response = HTTPURLResponse( + url: request.url!, + statusCode: 200, + httpVersion: nil, + headerFields: ["Content-Type": "text/event-stream"] + )! + + let events: [String] + if responsesCallCount == 0 { + events = [ + #"data: {"type":"response.tool_call.created","tool_call":{"id":"call_1","type":"function","function":{"name":"getWeather","arguments":""}}}"#, + #"data: {"type":"response.tool_call.delta","tool_call":{"id":"call_1","function":{"arguments":"{\"city\":\"San Francisco\"}"}}}"#, + #"data: {"type":"response.completed","finish_reason":"tool_calls"}"#, + ] + } else { + events = [ + #"data: {"type":"response.output_text.delta","delta":"Tool says: Sunny."}"#, + #"data: {"type":"response.completed","finish_reason":"stop"}"#, + ] + } + + let payload = events.joined(separator: "\n\n") + "\n\n" + return (response, [payload.data(using: .utf8)!]) + } + defer { MockOpenAIEventStreamURLProtocol.Handler.clear() } + + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockOpenAIEventStreamURLProtocol.self] + + let model = OpenAILanguageModel( + baseURL: baseURL, + apiKey: "test-key", + model: "gpt-test", + apiVariant: .responses, + session: URLSession(configuration: config) + ) + let session = LanguageModelSession(model: model, tools: [WeatherTool()]) + + var snapshots: [LanguageModelSession.ResponseStream.Snapshot] = [] + for try await snapshot in session.streamResponse(to: "What's the weather?") { + snapshots.append(snapshot) + } + + #expect(responsesCallCount >= 2) + } + + @Test(.disabled("Streaming mock under construction")) func chatCompletionsStreamToolCallExecution() async throws { + var chatCallCount = 0 + URLProtocol.registerClass(MockOpenAIEventStreamURLProtocol.self) + MockOpenAIEventStreamURLProtocol.Handler.set { request in + defer { chatCallCount += 1 } + let response = HTTPURLResponse( + url: request.url!, + statusCode: 200, + httpVersion: nil, + headerFields: ["Content-Type": "text/event-stream"] + )! + + let events: [String] + if chatCallCount == 0 { + events = [ + #"data: {"id":"evt_1","choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"getWeather","arguments":""}}]},"finish_reason":null}]}"#, + #"data: {"id":"evt_1","choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1","function":{"arguments":"{\"city\":\"Paris\"}"}}]},"finish_reason":null}]}"#, + #"data: {"id":"evt_1","choices":[{"delta":{},"finish_reason":"tool_calls"}]}"#, + ] + } else { + events = [ + #"data: {"id":"evt_1","choices":[{"delta":{"content":"Tool says Paris is sunny."},"finish_reason":null}]}"#, + #"data: {"id":"evt_1","choices":[{"delta":{},"finish_reason":"stop"}]}"#, + ] + } + + let payload = events.joined(separator: "\n\n") + "\n\n" + return (response, [payload.data(using: .utf8)!]) + } + defer { MockOpenAIEventStreamURLProtocol.Handler.clear() } + + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockOpenAIEventStreamURLProtocol.self] + + let model = OpenAILanguageModel( + baseURL: baseURL, + apiKey: "test-key", + model: "gpt-test", + apiVariant: .chatCompletions, + session: URLSession(configuration: config) + ) + let session = LanguageModelSession(model: model, tools: [WeatherTool()]) + + var snapshots: [LanguageModelSession.ResponseStream.Snapshot] = [] + for try await snapshot in session.streamResponse(to: "What's the weather?") { + snapshots.append(snapshot) + } + + #expect(chatCallCount >= 2) + } +} + +private final class MockOpenAIEventStreamURLProtocol: URLProtocol { + enum Handler { + nonisolated(unsafe) private static var handler: ((URLRequest) -> (HTTPURLResponse, [Data]))? + private static let lock = NSLock() + + static func set(_ handler: @escaping (URLRequest) -> (HTTPURLResponse, [Data])) { + lock.lock() + self.handler = handler + lock.unlock() + } + + static func clear() { + lock.lock() + handler = nil + lock.unlock() + } + + static func handle(_ request: URLRequest) -> (HTTPURLResponse, [Data])? { + lock.lock() + let result = handler?(request) + lock.unlock() + return result + } + } + + override class func canInit(with request: URLRequest) -> Bool { + true + } + + override class func canInit(with task: URLSessionTask) -> Bool { + if let request = task.currentRequest { + return canInit(with: request) + } + return false + } + + override class func canonicalRequest(for request: URLRequest) -> URLRequest { + request + } + + override func startLoading() { + guard let handler = Handler.handle(request) else { + client?.urlProtocol(self, didFailWithError: URLError(.badServerResponse)) + return + } + + let (response, dataChunks) = handler + client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) + for chunk in dataChunks { + client?.urlProtocol(self, didLoad: chunk) + } + client?.urlProtocolDidFinishLoading(self) + } + + override func stopLoading() {} +}