Skip to content
Merged
70 changes: 70 additions & 0 deletions FirebaseAI/Tests/Unit/GenerativeModelGoogleAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,53 @@ final class GenerativeModelGoogleAITests: XCTestCase {
XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 160)
}

func testGenerateContent_success_urlContext() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-url-context",
withExtension: "json",
subdirectory: googleAISubdirectory
)

let response = try await model.generateContent(testPrompt)

XCTAssertEqual(response.candidates.count, 1)
let candidate = try XCTUnwrap(response.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1)
let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first)
let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL)
XCTAssertEqual(
retrievedURL,
URL(string: "https://berkshirehathaway.com")
)
XCTAssertEqual(urlMetadata.retrievalStatus, .success)
let usageMetadata = try XCTUnwrap(response.usageMetadata)
XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 424)
}

func testGenerateContent_success_urlContext_mixedValidity() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-url-context-mixed-validity",
withExtension: "json",
subdirectory: googleAISubdirectory
)

let response = try await model.generateContent(testPrompt)

let candidate = try XCTUnwrap(response.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 3)

let paywallURLMetadata = urlContextMetadata.urlMetadata[0]
XCTAssertEqual(paywallURLMetadata.retrievalStatus, .error)

let successURLMetadata = urlContextMetadata.urlMetadata[1]
XCTAssertEqual(successURLMetadata.retrievalStatus, .success)

let errorURLMetadata = urlContextMetadata.urlMetadata[2]
XCTAssertEqual(errorURLMetadata.retrievalStatus, .error)
}

func testGenerateContent_failure_invalidAPIKey() async throws {
let expectedStatusCode = 400
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
Expand Down Expand Up @@ -644,4 +691,27 @@ final class GenerativeModelGoogleAITests: XCTestCase {
let lastResponse = try XCTUnwrap(responses.last)
XCTAssertEqual(lastResponse.text, "text8")
}

func testGenerateContentStream_success_urlContext() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "streaming-success-url-context",
withExtension: "txt",
subdirectory: googleAISubdirectory
)

var responses = [GenerateContentResponse]()
let stream = try model.generateContentStream(testPrompt)
for try await response in stream {
responses.append(response)
}

let firstResponse = try XCTUnwrap(responses.first)
let candidate = try XCTUnwrap(firstResponse.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1)
let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first)
let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL)
XCTAssertEqual(retrievedURL, URL(string: "https://google.com"))
XCTAssertEqual(urlMetadata.retrievalStatus, .success)
}
}
87 changes: 87 additions & 0 deletions FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,71 @@ final class GenerativeModelVertexAITests: XCTestCase {
XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 371)
}

func testGenerateContent_success_urlContext() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-url-context",
withExtension: "json",
subdirectory: vertexSubdirectory
)

let response = try await model.generateContent(testPrompt)

XCTAssertEqual(response.candidates.count, 1)
let candidate = try XCTUnwrap(response.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1)
let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first)
let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL)
XCTAssertEqual(
retrievedURL,
URL(string: "https://berkshirehathaway.com")
)
XCTAssertEqual(urlMetadata.retrievalStatus, .success)
let usageMetadata = try XCTUnwrap(response.usageMetadata)
XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 34)
XCTAssertEqual(usageMetadata.thoughtsTokenCount, 36)
}

func testGenerateContent_success_urlContext_mixedValidity() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-url-context-mixed-validity",
withExtension: "json",
subdirectory: vertexSubdirectory
)

let response = try await model.generateContent(testPrompt)

let candidate = try XCTUnwrap(response.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 3)

let paywallURLMetadata = urlContextMetadata.urlMetadata[0]
XCTAssertEqual(paywallURLMetadata.retrievalStatus, .error)

let successURLMetadata = urlContextMetadata.urlMetadata[1]
XCTAssertEqual(successURLMetadata.retrievalStatus, .success)

let errorURLMetadata = urlContextMetadata.urlMetadata[2]
XCTAssertEqual(errorURLMetadata.retrievalStatus, .error)
}

func testGenerateContent_success_urlContext_retrievedURLPresentOnErrorStatus() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-url-context-missing-retrievedurl",
withExtension: "json",
subdirectory: vertexSubdirectory
)

let response = try await model.generateContent(testPrompt)

let candidate = try XCTUnwrap(response.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first)
let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL)
XCTAssertEqual(retrievedURL.absoluteString, "https://example.com/8")
XCTAssertEqual(urlMetadata.retrievalStatus, .error)
}

func testGenerateContent_success_image_invalidSafetyRatingsIgnored() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-image-invalid-safety-ratings",
Expand Down Expand Up @@ -1720,6 +1785,28 @@ final class GenerativeModelVertexAITests: XCTestCase {
XCTAssertEqual(responses, 1)
}

func testGenerateContentStream_success_urlContext() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "streaming-success-url-context",
withExtension: "txt",
subdirectory: vertexSubdirectory
)

var responses = [GenerateContentResponse]()
let stream = try model.generateContentStream(testPrompt)
for try await response in stream {
responses.append(response)
}

let firstResponse = try XCTUnwrap(responses.first)
let candidate = try XCTUnwrap(firstResponse.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1)
let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first)
XCTAssertEqual(urlMetadata.retrievedURL?.absoluteString, "https://google.com")
XCTAssertEqual(urlMetadata.retrievalStatus, .success)
}

// MARK: - Count Tokens

func testCountTokens_succeeds() async throws {
Expand Down
8 changes: 5 additions & 3 deletions FirebaseAI/Tests/Unit/MockURLProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class MockURLProtocol: URLProtocol, @unchecked Sendable {
URLResponse,
AsyncLineSequence<URL.AsyncBytes>?
))?

override class func canInit(with request: URLRequest) -> Bool {
#if os(watchOS)
print("MockURLProtocol cannot be used on watchOS.")
Expand All @@ -33,13 +34,14 @@ class MockURLProtocol: URLProtocol, @unchecked Sendable {
override class func canonicalRequest(for request: URLRequest) -> URLRequest { return request }

override func startLoading() {
guard let requestHandler = MockURLProtocol.requestHandler else {
fatalError("`requestHandler` is nil.")
}
guard let client = client else {
fatalError("`client` is nil.")
}

guard let requestHandler = MockURLProtocol.requestHandler else {
fatalError("No request handler set.")
}

Task {
let (response, stream) = try requestHandler(self.request)
client.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed)
Expand Down
51 changes: 51 additions & 0 deletions FirebaseAI/Tests/Unit/Types/GenerateContentResponseTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import XCTest

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class GenerateContentResponseTests: XCTestCase {
let jsonDecoder = JSONDecoder()

// MARK: - GenerateContentResponse Computed Properties

func testGenerateContentResponse_inlineDataParts_success() throws {
Expand Down Expand Up @@ -106,4 +108,53 @@ final class GenerateContentResponseTests: XCTestCase {
"functionCalls should be empty when there are no candidates."
)
}

// MARK: - Decoding Tests

func testDecodeCandidate_emptyURLMetadata_urlContextMetadataIsNil() throws {
let json = """
{
"content": { "role": "model", "parts": [ { "text": "Some text." } ] },
"finishReason": "STOP",
"urlContextMetadata": { "urlMetadata": [] }
}
"""
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let candidate = try jsonDecoder.decode(Candidate.self, from: jsonData)

XCTAssertNil(
candidate.urlContextMetadata,
"urlContextMetadata should be nil if the `urlMetadata` array is empty in the candidate."
)
XCTAssertEqual(candidate.content.role, "model")
let part = try XCTUnwrap(candidate.content.parts.first)
let textPart = try XCTUnwrap(part as? TextPart)
XCTAssertEqual(textPart.text, "Some text.")
XCTAssertFalse(textPart.isThought)
XCTAssertEqual(candidate.finishReason, .stop)
}

func testDecodeCandidate_missingURLMetadata_urlContextMetadataIsNil() throws {
let json = """
{
"content": { "role": "model", "parts": [ { "text": "Some text." } ] },
"finishReason": "STOP"
}
"""
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let candidate = try jsonDecoder.decode(Candidate.self, from: jsonData)

XCTAssertNil(
candidate.urlContextMetadata,
"urlContextMetadata should be nil if `urlMetadata` is not provided in the candidate."
)
XCTAssertEqual(candidate.content.role, "model")
let part = try XCTUnwrap(candidate.content.parts.first)
let textPart = try XCTUnwrap(part as? TextPart)
XCTAssertEqual(textPart.text, "Some text.")
XCTAssertFalse(textPart.isThought)
XCTAssertEqual(candidate.finishReason, .stop)
}
}
Loading