diff --git a/FirebaseAI/Tests/Unit/GenerativeModelGoogleAITests.swift b/FirebaseAI/Tests/Unit/GenerativeModelGoogleAITests.swift index dedd1223adb..59e1581a638 100644 --- a/FirebaseAI/Tests/Unit/GenerativeModelGoogleAITests.swift +++ b/FirebaseAI/Tests/Unit/GenerativeModelGoogleAITests.swift @@ -337,6 +337,53 @@ final class GenerativeModelGoogleAITests: XCTestCase { XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 160) } + func testGenerateContent_success_urlContext() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-url-context", + withExtension: "json", + subdirectory: googleAISubdirectory + ) + + let response = try await model.generateContent(testPrompt) + + XCTAssertEqual(response.candidates.count, 1) + let candidate = try XCTUnwrap(response.candidates.first) + let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata) + XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1) + let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first) + let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL) + XCTAssertEqual( + retrievedURL, + URL(string: "https://berkshirehathaway.com") + ) + XCTAssertEqual(urlMetadata.retrievalStatus, .success) + let usageMetadata = try XCTUnwrap(response.usageMetadata) + XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 424) + } + + func testGenerateContent_success_urlContext_mixedValidity() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-url-context-mixed-validity", + withExtension: "json", + subdirectory: googleAISubdirectory + ) + + let response = try await model.generateContent(testPrompt) + + let candidate = try XCTUnwrap(response.candidates.first) + let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata) + XCTAssertEqual(urlContextMetadata.urlMetadata.count, 3) + + let paywallURLMetadata = urlContextMetadata.urlMetadata[0] + XCTAssertEqual(paywallURLMetadata.retrievalStatus, .error) + + let successURLMetadata = urlContextMetadata.urlMetadata[1] + XCTAssertEqual(successURLMetadata.retrievalStatus, .success) + + let errorURLMetadata = urlContextMetadata.urlMetadata[2] + XCTAssertEqual(errorURLMetadata.retrievalStatus, .error) + } + func testGenerateContent_failure_invalidAPIKey() async throws { let expectedStatusCode = 400 MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( @@ -644,4 +691,27 @@ final class GenerativeModelGoogleAITests: XCTestCase { let lastResponse = try XCTUnwrap(responses.last) XCTAssertEqual(lastResponse.text, "text8") } + + func testGenerateContentStream_success_urlContext() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "streaming-success-url-context", + withExtension: "txt", + subdirectory: googleAISubdirectory + ) + + var responses = [GenerateContentResponse]() + let stream = try model.generateContentStream(testPrompt) + for try await response in stream { + responses.append(response) + } + + let firstResponse = try XCTUnwrap(responses.first) + let candidate = try XCTUnwrap(firstResponse.candidates.first) + let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata) + XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1) + let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first) + let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL) + XCTAssertEqual(retrievedURL, URL(string: "https://google.com")) + XCTAssertEqual(urlMetadata.retrievalStatus, .success) + } } diff --git a/FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift b/FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift index 304545ca5f9..1d2498f07e5 100644 --- a/FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift +++ b/FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift @@ -491,6 +491,71 @@ final class GenerativeModelVertexAITests: XCTestCase { XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 371) } + func testGenerateContent_success_urlContext() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-url-context", + withExtension: "json", + subdirectory: vertexSubdirectory + ) + + let response = try await model.generateContent(testPrompt) + + XCTAssertEqual(response.candidates.count, 1) + let candidate = try XCTUnwrap(response.candidates.first) + let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata) + XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1) + let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first) + let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL) + XCTAssertEqual( + retrievedURL, + URL(string: "https://berkshirehathaway.com") + ) + XCTAssertEqual(urlMetadata.retrievalStatus, .success) + let usageMetadata = try XCTUnwrap(response.usageMetadata) + XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 34) + XCTAssertEqual(usageMetadata.thoughtsTokenCount, 36) + } + + func testGenerateContent_success_urlContext_mixedValidity() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-url-context-mixed-validity", + withExtension: "json", + subdirectory: vertexSubdirectory + ) + + let response = try await model.generateContent(testPrompt) + + let candidate = try XCTUnwrap(response.candidates.first) + let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata) + XCTAssertEqual(urlContextMetadata.urlMetadata.count, 3) + + let paywallURLMetadata = urlContextMetadata.urlMetadata[0] + XCTAssertEqual(paywallURLMetadata.retrievalStatus, .error) + + let successURLMetadata = urlContextMetadata.urlMetadata[1] + XCTAssertEqual(successURLMetadata.retrievalStatus, .success) + + let errorURLMetadata = urlContextMetadata.urlMetadata[2] + XCTAssertEqual(errorURLMetadata.retrievalStatus, .error) + } + + func testGenerateContent_success_urlContext_retrievedURLPresentOnErrorStatus() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-url-context-missing-retrievedurl", + withExtension: "json", + subdirectory: vertexSubdirectory + ) + + let response = try await model.generateContent(testPrompt) + + let candidate = try XCTUnwrap(response.candidates.first) + let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata) + let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first) + let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL) + XCTAssertEqual(retrievedURL.absoluteString, "https://example.com/8") + XCTAssertEqual(urlMetadata.retrievalStatus, .error) + } + func testGenerateContent_success_image_invalidSafetyRatingsIgnored() async throws { MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( forResource: "unary-success-image-invalid-safety-ratings", @@ -1720,6 +1785,29 @@ final class GenerativeModelVertexAITests: XCTestCase { XCTAssertEqual(responses, 1) } + func testGenerateContentStream_success_urlContext() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "streaming-success-url-context", + withExtension: "txt", + subdirectory: vertexSubdirectory + ) + + var responses = [GenerateContentResponse]() + let stream = try model.generateContentStream(testPrompt) + for try await response in stream { + responses.append(response) + } + + let firstResponse = try XCTUnwrap(responses.first) + let candidate = try XCTUnwrap(firstResponse.candidates.first) + let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata) + XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1) + let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first) + let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL) + XCTAssertEqual(retrievedURL, URL(string: "https://google.com")) + XCTAssertEqual(urlMetadata.retrievalStatus, .success) + } + // MARK: - Count Tokens func testCountTokens_succeeds() async throws { diff --git a/FirebaseAI/Tests/Unit/MockURLProtocol.swift b/FirebaseAI/Tests/Unit/MockURLProtocol.swift index 5385b164015..6db227d5cfb 100644 --- a/FirebaseAI/Tests/Unit/MockURLProtocol.swift +++ b/FirebaseAI/Tests/Unit/MockURLProtocol.swift @@ -21,6 +21,7 @@ class MockURLProtocol: URLProtocol, @unchecked Sendable { URLResponse, AsyncLineSequence? ))? + override class func canInit(with request: URLRequest) -> Bool { #if os(watchOS) print("MockURLProtocol cannot be used on watchOS.") @@ -33,13 +34,14 @@ class MockURLProtocol: URLProtocol, @unchecked Sendable { override class func canonicalRequest(for request: URLRequest) -> URLRequest { return request } override func startLoading() { - guard let requestHandler = MockURLProtocol.requestHandler else { - fatalError("`requestHandler` is nil.") - } guard let client = client else { fatalError("`client` is nil.") } + guard let requestHandler = MockURLProtocol.requestHandler else { + fatalError("No request handler set.") + } + Task { let (response, stream) = try requestHandler(self.request) client.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) diff --git a/FirebaseAI/Tests/Unit/Types/GenerateContentResponseTests.swift b/FirebaseAI/Tests/Unit/Types/GenerateContentResponseTests.swift index a53d215359f..276308f63aa 100644 --- a/FirebaseAI/Tests/Unit/Types/GenerateContentResponseTests.swift +++ b/FirebaseAI/Tests/Unit/Types/GenerateContentResponseTests.swift @@ -17,6 +17,8 @@ import XCTest @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) final class GenerateContentResponseTests: XCTestCase { + let jsonDecoder = JSONDecoder() + // MARK: - GenerateContentResponse Computed Properties func testGenerateContentResponse_inlineDataParts_success() throws { @@ -106,4 +108,53 @@ final class GenerateContentResponseTests: XCTestCase { "functionCalls should be empty when there are no candidates." ) } + + // MARK: - Decoding Tests + + func testDecodeCandidate_emptyURLMetadata_urlContextMetadataIsNil() throws { + let json = """ + { + "content": { "role": "model", "parts": [ { "text": "Some text." } ] }, + "finishReason": "STOP", + "urlContextMetadata": { "urlMetadata": [] } + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let candidate = try jsonDecoder.decode(Candidate.self, from: jsonData) + + XCTAssertNil( + candidate.urlContextMetadata, + "urlContextMetadata should be nil if the `urlMetadata` array is empty in the candidate." + ) + XCTAssertEqual(candidate.content.role, "model") + let part = try XCTUnwrap(candidate.content.parts.first) + let textPart = try XCTUnwrap(part as? TextPart) + XCTAssertEqual(textPart.text, "Some text.") + XCTAssertFalse(textPart.isThought) + XCTAssertEqual(candidate.finishReason, .stop) + } + + func testDecodeCandidate_missingURLMetadata_urlContextMetadataIsNil() throws { + let json = """ + { + "content": { "role": "model", "parts": [ { "text": "Some text." } ] }, + "finishReason": "STOP" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let candidate = try jsonDecoder.decode(Candidate.self, from: jsonData) + + XCTAssertNil( + candidate.urlContextMetadata, + "urlContextMetadata should be nil if `urlMetadata` is not provided in the candidate." + ) + XCTAssertEqual(candidate.content.role, "model") + let part = try XCTUnwrap(candidate.content.parts.first) + let textPart = try XCTUnwrap(part as? TextPart) + XCTAssertEqual(textPart.text, "Some text.") + XCTAssertFalse(textPart.isThought) + XCTAssertEqual(candidate.finishReason, .stop) + } }