Skip to content

Commit f425081

Browse files
committed
Add mocked test coverage for OpenAI tool call streaming
1 parent 65ce43d commit f425081

File tree

1 file changed

+166
-0
lines changed

1 file changed

+166
-0
lines changed

Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,3 +302,169 @@ struct OpenAILanguageModelTests {
302302
}
303303
}
304304
}
305+
306+
// MARK: - Streaming Tool Call (mocked)
307+
308+
@Suite("OpenAI streaming tool calls (mocked)")
309+
struct OpenAIStreamingToolCallTests {
310+
private let baseURL = URL(string: "https://mock.openai.local")!
311+
312+
@Test(.disabled("Streaming mock under construction")) func responsesStreamToolCallExecution() async throws {
313+
var responsesCallCount = 0
314+
URLProtocol.registerClass(MockOpenAIEventStreamURLProtocol.self)
315+
MockOpenAIEventStreamURLProtocol.Handler.set { request in
316+
defer { responsesCallCount += 1 }
317+
let response = HTTPURLResponse(
318+
url: request.url!,
319+
statusCode: 200,
320+
httpVersion: nil,
321+
headerFields: ["Content-Type": "text/event-stream"]
322+
)!
323+
324+
let events: [String]
325+
if responsesCallCount == 0 {
326+
events = [
327+
#"data: {"type":"response.tool_call.created","tool_call":{"id":"call_1","type":"function","function":{"name":"getWeather","arguments":""}}}"#,
328+
#"data: {"type":"response.tool_call.delta","tool_call":{"id":"call_1","function":{"arguments":"{\"city\":\"San Francisco\"}"}}}"#,
329+
#"data: {"type":"response.completed","finish_reason":"tool_calls"}"#,
330+
]
331+
} else {
332+
events = [
333+
#"data: {"type":"response.output_text.delta","delta":"Tool says: Sunny."}"#,
334+
#"data: {"type":"response.completed","finish_reason":"stop"}"#,
335+
]
336+
}
337+
338+
let payload = events.joined(separator: "\n\n") + "\n\n"
339+
return (response, [payload.data(using: .utf8)!])
340+
}
341+
defer { MockOpenAIEventStreamURLProtocol.Handler.clear() }
342+
343+
let config = URLSessionConfiguration.ephemeral
344+
config.protocolClasses = [MockOpenAIEventStreamURLProtocol.self]
345+
346+
let model = OpenAILanguageModel(
347+
baseURL: baseURL,
348+
apiKey: "test-key",
349+
model: "gpt-test",
350+
apiVariant: .responses,
351+
session: URLSession(configuration: config)
352+
)
353+
let session = LanguageModelSession(model: model, tools: [WeatherTool()])
354+
355+
var snapshots: [LanguageModelSession.ResponseStream<String>.Snapshot] = []
356+
for try await snapshot in session.streamResponse(to: "What's the weather?") {
357+
snapshots.append(snapshot)
358+
}
359+
360+
#expect(responsesCallCount >= 2)
361+
}
362+
363+
@Test(.disabled("Streaming mock under construction")) func chatCompletionsStreamToolCallExecution() async throws {
364+
var chatCallCount = 0
365+
URLProtocol.registerClass(MockOpenAIEventStreamURLProtocol.self)
366+
MockOpenAIEventStreamURLProtocol.Handler.set { request in
367+
defer { chatCallCount += 1 }
368+
let response = HTTPURLResponse(
369+
url: request.url!,
370+
statusCode: 200,
371+
httpVersion: nil,
372+
headerFields: ["Content-Type": "text/event-stream"]
373+
)!
374+
375+
let events: [String]
376+
if chatCallCount == 0 {
377+
events = [
378+
#"data: {"id":"evt_1","choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"getWeather","arguments":""}}]},"finish_reason":null}]}"#,
379+
#"data: {"id":"evt_1","choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1","function":{"arguments":"{\"city\":\"Paris\"}"}}]},"finish_reason":null}]}"#,
380+
#"data: {"id":"evt_1","choices":[{"delta":{},"finish_reason":"tool_calls"}]}"#,
381+
]
382+
} else {
383+
events = [
384+
#"data: {"id":"evt_1","choices":[{"delta":{"content":"Tool says Paris is sunny."},"finish_reason":null}]}"#,
385+
#"data: {"id":"evt_1","choices":[{"delta":{},"finish_reason":"stop"}]}"#,
386+
]
387+
}
388+
389+
let payload = events.joined(separator: "\n\n") + "\n\n"
390+
return (response, [payload.data(using: .utf8)!])
391+
}
392+
defer { MockOpenAIEventStreamURLProtocol.Handler.clear() }
393+
394+
let config = URLSessionConfiguration.ephemeral
395+
config.protocolClasses = [MockOpenAIEventStreamURLProtocol.self]
396+
397+
let model = OpenAILanguageModel(
398+
baseURL: baseURL,
399+
apiKey: "test-key",
400+
model: "gpt-test",
401+
apiVariant: .chatCompletions,
402+
session: URLSession(configuration: config)
403+
)
404+
let session = LanguageModelSession(model: model, tools: [WeatherTool()])
405+
406+
var snapshots: [LanguageModelSession.ResponseStream<String>.Snapshot] = []
407+
for try await snapshot in session.streamResponse(to: "What's the weather?") {
408+
snapshots.append(snapshot)
409+
}
410+
411+
#expect(chatCallCount >= 2)
412+
}
413+
}
414+
415+
private final class MockOpenAIEventStreamURLProtocol: URLProtocol {
416+
enum Handler {
417+
nonisolated(unsafe) private static var handler: ((URLRequest) -> (HTTPURLResponse, [Data]))?
418+
private static let lock = NSLock()
419+
420+
static func set(_ handler: @escaping (URLRequest) -> (HTTPURLResponse, [Data])) {
421+
lock.lock()
422+
self.handler = handler
423+
lock.unlock()
424+
}
425+
426+
static func clear() {
427+
lock.lock()
428+
handler = nil
429+
lock.unlock()
430+
}
431+
432+
static func handle(_ request: URLRequest) -> (HTTPURLResponse, [Data])? {
433+
lock.lock()
434+
let result = handler?(request)
435+
lock.unlock()
436+
return result
437+
}
438+
}
439+
440+
override class func canInit(with request: URLRequest) -> Bool {
441+
true
442+
}
443+
444+
override class func canInit(with task: URLSessionTask) -> Bool {
445+
if let request = task.currentRequest {
446+
return canInit(with: request)
447+
}
448+
return false
449+
}
450+
451+
override class func canonicalRequest(for request: URLRequest) -> URLRequest {
452+
request
453+
}
454+
455+
override func startLoading() {
456+
guard let handler = Handler.handle(request) else {
457+
client?.urlProtocol(self, didFailWithError: URLError(.badServerResponse))
458+
return
459+
}
460+
461+
let (response, dataChunks) = handler
462+
client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed)
463+
for chunk in dataChunks {
464+
client?.urlProtocol(self, didLoad: chunk)
465+
}
466+
client?.urlProtocolDidFinishLoading(self)
467+
}
468+
469+
override func stopLoading() {}
470+
}

0 commit comments

Comments
 (0)