Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions FirebaseAI/Tests/Unit/GenerativeModelGoogleAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,53 @@ final class GenerativeModelGoogleAITests: XCTestCase {
XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 160)
}

func testGenerateContent_success_urlContext() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-url-context",
withExtension: "json",
subdirectory: googleAISubdirectory
)

let response = try await model.generateContent(testPrompt)

XCTAssertEqual(response.candidates.count, 1)
let candidate = try XCTUnwrap(response.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1)
let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first)
let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL)
XCTAssertEqual(
retrievedURL,
URL(string: "https://berkshirehathaway.com")
)
XCTAssertEqual(urlMetadata.retrievalStatus, .success)
let usageMetadata = try XCTUnwrap(response.usageMetadata)
XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 424)
}

func testGenerateContent_success_urlContext_mixedValidity() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-url-context-mixed-validity",
withExtension: "json",
subdirectory: googleAISubdirectory
)

let response = try await model.generateContent(testPrompt)

let candidate = try XCTUnwrap(response.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 3)

let paywallURLMetadata = urlContextMetadata.urlMetadata[0]
XCTAssertEqual(paywallURLMetadata.retrievalStatus, .error)

let successURLMetadata = urlContextMetadata.urlMetadata[1]
XCTAssertEqual(successURLMetadata.retrievalStatus, .success)

let errorURLMetadata = urlContextMetadata.urlMetadata[2]
XCTAssertEqual(errorURLMetadata.retrievalStatus, .error)
}

func testGenerateContent_failure_invalidAPIKey() async throws {
let expectedStatusCode = 400
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
Expand Down Expand Up @@ -644,4 +691,27 @@ final class GenerativeModelGoogleAITests: XCTestCase {
let lastResponse = try XCTUnwrap(responses.last)
XCTAssertEqual(lastResponse.text, "text8")
}

func testGenerateContentStream_success_urlContext() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "streaming-success-url-context",
withExtension: "txt",
subdirectory: googleAISubdirectory
)

var responses = [GenerateContentResponse]()
let stream = try model.generateContentStream(testPrompt)
for try await response in stream {
responses.append(response)
}

let firstResponse = try XCTUnwrap(responses.first)
let candidate = try XCTUnwrap(firstResponse.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1)
let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first)
let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL)
XCTAssertEqual(retrievedURL, URL(string: "https://google.com"))
XCTAssertEqual(urlMetadata.retrievalStatus, .success)
}
}
88 changes: 88 additions & 0 deletions FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,71 @@ final class GenerativeModelVertexAITests: XCTestCase {
XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 371)
}

func testGenerateContent_success_urlContext() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-url-context",
withExtension: "json",
subdirectory: vertexSubdirectory
)

let response = try await model.generateContent(testPrompt)

XCTAssertEqual(response.candidates.count, 1)
let candidate = try XCTUnwrap(response.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1)
let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first)
let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL)
XCTAssertEqual(
retrievedURL,
URL(string: "https://berkshirehathaway.com")
)
XCTAssertEqual(urlMetadata.retrievalStatus, .success)
let usageMetadata = try XCTUnwrap(response.usageMetadata)
XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 34)
XCTAssertEqual(usageMetadata.thoughtsTokenCount, 36)
}

func testGenerateContent_success_urlContext_mixedValidity() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-url-context-mixed-validity",
withExtension: "json",
subdirectory: vertexSubdirectory
)

let response = try await model.generateContent(testPrompt)

let candidate = try XCTUnwrap(response.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 3)

let paywallURLMetadata = urlContextMetadata.urlMetadata[0]
XCTAssertEqual(paywallURLMetadata.retrievalStatus, .error)

let successURLMetadata = urlContextMetadata.urlMetadata[1]
XCTAssertEqual(successURLMetadata.retrievalStatus, .success)

let errorURLMetadata = urlContextMetadata.urlMetadata[2]
XCTAssertEqual(errorURLMetadata.retrievalStatus, .error)
}

func testGenerateContent_success_urlContext_retrievedURLPresentOnErrorStatus() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-url-context-missing-retrievedurl",
withExtension: "json",
subdirectory: vertexSubdirectory
)

let response = try await model.generateContent(testPrompt)

let candidate = try XCTUnwrap(response.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first)
let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL)
XCTAssertEqual(retrievedURL.absoluteString, "https://example.com/8")
XCTAssertEqual(urlMetadata.retrievalStatus, .error)
}

func testGenerateContent_success_image_invalidSafetyRatingsIgnored() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-image-invalid-safety-ratings",
Expand Down Expand Up @@ -1720,6 +1785,29 @@ final class GenerativeModelVertexAITests: XCTestCase {
XCTAssertEqual(responses, 1)
}

func testGenerateContentStream_success_urlContext() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "streaming-success-url-context",
withExtension: "txt",
subdirectory: vertexSubdirectory
)

var responses = [GenerateContentResponse]()
let stream = try model.generateContentStream(testPrompt)
for try await response in stream {
responses.append(response)
}

let firstResponse = try XCTUnwrap(responses.first)
let candidate = try XCTUnwrap(firstResponse.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1)
let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first)
let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL)
XCTAssertEqual(retrievedURL, URL(string: "https://google.com"))
XCTAssertEqual(urlMetadata.retrievalStatus, .success)
}

// MARK: - Count Tokens

func testCountTokens_succeeds() async throws {
Expand Down
8 changes: 5 additions & 3 deletions FirebaseAI/Tests/Unit/MockURLProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class MockURLProtocol: URLProtocol, @unchecked Sendable {
URLResponse,
AsyncLineSequence<URL.AsyncBytes>?
))?

override class func canInit(with request: URLRequest) -> Bool {
#if os(watchOS)
print("MockURLProtocol cannot be used on watchOS.")
Expand All @@ -33,13 +34,14 @@ class MockURLProtocol: URLProtocol, @unchecked Sendable {
override class func canonicalRequest(for request: URLRequest) -> URLRequest { return request }

override func startLoading() {
guard let requestHandler = MockURLProtocol.requestHandler else {
fatalError("`requestHandler` is nil.")
}
guard let client = client else {
fatalError("`client` is nil.")
}

guard let requestHandler = MockURLProtocol.requestHandler else {
fatalError("No request handler set.")
}

Task {
let (response, stream) = try requestHandler(self.request)
client.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed)
Expand Down
51 changes: 51 additions & 0 deletions FirebaseAI/Tests/Unit/Types/GenerateContentResponseTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import XCTest

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class GenerateContentResponseTests: XCTestCase {
let jsonDecoder = JSONDecoder()

// MARK: - GenerateContentResponse Computed Properties

func testGenerateContentResponse_inlineDataParts_success() throws {
Expand Down Expand Up @@ -106,4 +108,53 @@ final class GenerateContentResponseTests: XCTestCase {
"functionCalls should be empty when there are no candidates."
)
}

// MARK: - Decoding Tests

func testDecodeCandidate_emptyURLMetadata_urlContextMetadataIsNil() throws {
let json = """
{
"content": { "role": "model", "parts": [ { "text": "Some text." } ] },
"finishReason": "STOP",
"urlContextMetadata": { "urlMetadata": [] }
}
"""
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let candidate = try jsonDecoder.decode(Candidate.self, from: jsonData)

XCTAssertNil(
candidate.urlContextMetadata,
"urlContextMetadata should be nil if the `urlMetadata` array is empty in the candidate."
)
XCTAssertEqual(candidate.content.role, "model")
let part = try XCTUnwrap(candidate.content.parts.first)
let textPart = try XCTUnwrap(part as? TextPart)
XCTAssertEqual(textPart.text, "Some text.")
XCTAssertFalse(textPart.isThought)
XCTAssertEqual(candidate.finishReason, .stop)
}

func testDecodeCandidate_missingURLMetadata_urlContextMetadataIsNil() throws {
let json = """
{
"content": { "role": "model", "parts": [ { "text": "Some text." } ] },
"finishReason": "STOP"
}
"""
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let candidate = try jsonDecoder.decode(Candidate.self, from: jsonData)

XCTAssertNil(
candidate.urlContextMetadata,
"urlContextMetadata should be nil if `urlMetadata` is not provided in the candidate."
)
XCTAssertEqual(candidate.content.role, "model")
let part = try XCTUnwrap(candidate.content.parts.first)
let textPart = try XCTUnwrap(part as? TextPart)
XCTAssertEqual(textPart.text, "Some text.")
XCTAssertFalse(textPart.isThought)
XCTAssertEqual(candidate.finishReason, .stop)
}
}
Loading