Skip to content
Merged
89 changes: 89 additions & 0 deletions FirebaseAI/Tests/Unit/GenerativeModelGoogleAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,74 @@ final class GenerativeModelGoogleAITests: XCTestCase {
XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 160)
}

func testGenerateContent_success_urlContext() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-url-context",
withExtension: "json",
subdirectory: googleAISubdirectory
)

let response = try await model.generateContent(testPrompt)

XCTAssertEqual(response.candidates.count, 1)
let candidate = try XCTUnwrap(response.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1)
let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first)
let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL)
XCTAssertEqual(
retrievedURL,
URL(string: "https://berkshirehathaway.com")
)
XCTAssertEqual(urlMetadata.retrievalStatus, .success)
let usageMetadata = try XCTUnwrap(response.usageMetadata)
XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 424)
}

func testGenerateContent_success_urlContext_mixedValidity() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-url-context-mixed-validity",
withExtension: "json",
subdirectory: googleAISubdirectory
)

let response = try await model.generateContent(testPrompt)

let candidate = try XCTUnwrap(response.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 3)

let paywallURLMetadata = urlContextMetadata.urlMetadata[0]
XCTAssertEqual(paywallURLMetadata.retrievalStatus, .error)

let successURLMetadata = urlContextMetadata.urlMetadata[1]
XCTAssertEqual(successURLMetadata.retrievalStatus, .success)

let errorURLMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata[2])
XCTAssertEqual(errorURLMetadata.retrievalStatus, .error)
}

func testGenerateContent_success_urlContext_emptyURLMetadata() async throws {
let json = """
{
"candidates": [
{
"content": { "role": "model", "parts": [ { "text": "Some text." } ] },
"finishReason": "STOP",
"urlContextMetadata": { "urlMetadata": [] }
}
]
}
"""
MockURLProtocol.requestHandler = nil
MockURLProtocol.dataRequestHandler = try GenerativeModelTestUtil.httpRequestHandler(json: json)

let response = try await model.generateContent(testPrompt)

let candidate = try XCTUnwrap(response.candidates.first)
XCTAssertNil(candidate.urlContextMetadata)
}

func testGenerateContent_failure_invalidAPIKey() async throws {
let expectedStatusCode = 400
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
Expand Down Expand Up @@ -644,4 +712,25 @@ final class GenerativeModelGoogleAITests: XCTestCase {
let lastResponse = try XCTUnwrap(responses.last)
XCTAssertEqual(lastResponse.text, "text8")
}

func testGenerateContentStream_success_urlContext() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "streaming-success-url-context",
withExtension: "txt",
subdirectory: googleAISubdirectory
)

var responses = [GenerateContentResponse]()
let stream = try model.generateContentStream(testPrompt)
for try await response in stream {
responses.append(response)
}

let firstResponse = try XCTUnwrap(responses.first)
let candidate = try XCTUnwrap(firstResponse.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1)
let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first)
XCTAssertEqual(urlMetadata.retrievalStatus, .success)
}
}
107 changes: 107 additions & 0 deletions FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,91 @@ final class GenerativeModelVertexAITests: XCTestCase {
XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 371)
}

func testGenerateContent_success_urlContext() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-url-context",
withExtension: "json",
subdirectory: vertexSubdirectory
)

let response = try await model.generateContent(testPrompt)

XCTAssertEqual(response.candidates.count, 1)
let candidate = try XCTUnwrap(response.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1)
let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first)
let retrievedURL = try XCTUnwrap(urlMetadata.retrievedURL)
XCTAssertEqual(
retrievedURL,
URL(string: "https://berkshirehathaway.com")
)
XCTAssertEqual(urlMetadata.retrievalStatus, .success)
let usageMetadata = try XCTUnwrap(response.usageMetadata)
XCTAssertEqual(usageMetadata.toolUsePromptTokenCount, 34)
XCTAssertEqual(usageMetadata.thoughtsTokenCount, 36)
}

func testGenerateContent_success_urlContext_mixedValidity() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-url-context-mixed-validity",
withExtension: "json",
subdirectory: vertexSubdirectory
)

let response = try await model.generateContent(testPrompt)

let candidate = try XCTUnwrap(response.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 3)

let paywallURLMetadata = urlContextMetadata.urlMetadata[0]
XCTAssertEqual(paywallURLMetadata.retrievalStatus, .error)

let successURLMetadata = urlContextMetadata.urlMetadata[1]
XCTAssertEqual(successURLMetadata.retrievalStatus, .success)

let errorURLMetadata = urlContextMetadata.urlMetadata[2]
XCTAssertEqual(errorURLMetadata.retrievalStatus, .error)
}

func testGenerateContent_success_urlContext_nonexistentRetrievedURL() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-url-context-missing-retrievedurl",
withExtension: "json",
subdirectory: vertexSubdirectory
)

let response = try await model.generateContent(testPrompt)

let candidate = try XCTUnwrap(response.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first)
XCTAssertEqual(urlMetadata.retrievedURL?.absoluteString, "https://example.com/8")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There seems to be a contradiction in this test. The test name testGenerateContent_success_urlContext_missingRetrievedURL and the mock resource name unary-success-url-context-missing-retrievedurl suggest that the retrievedURL should be missing or nil. However, the assertion checks if retrievedURL?.absoluteString is equal to a specific string, which implies retrievedURL is not nil.

If the intention is to test for a missing URL, the assertion should probably be XCTAssertNil(urlMetadata.retrievedURL). If the current assertion is correct, consider renaming the test to avoid confusion.

Suggested change
XCTAssertEqual(urlMetadata.retrievedURL?.absoluteString, "https://example.com/8")
XCTAssertNil(urlMetadata.retrievedURL)

Copy link
Member Author

@paulb777 paulb777 Sep 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed function to testGenerateContent_success_urlContext_retrievedURLPresentOnErrorStatus cc: @dlarocque @andrewheard

XCTAssertEqual(urlMetadata.retrievalStatus, .error)
}

func testGenerateContent_success_urlContext_emptyURLMetadata() async throws {
let json = """
{
"candidates": [
{
"content": { "role": "model", "parts": [ { "text": "Some text." } ] },
"finishReason": "STOP",
"urlContextMetadata": { "urlMetadata": [] }
}
]
}
"""
MockURLProtocol.requestHandler = nil
MockURLProtocol.dataRequestHandler = try GenerativeModelTestUtil.httpRequestHandler(json: json)

let response = try await model.generateContent(testPrompt)

let candidate = try XCTUnwrap(response.candidates.first)
XCTAssertNil(candidate.urlContextMetadata)
}

func testGenerateContent_success_image_invalidSafetyRatingsIgnored() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-image-invalid-safety-ratings",
Expand Down Expand Up @@ -1720,6 +1805,28 @@ final class GenerativeModelVertexAITests: XCTestCase {
XCTAssertEqual(responses, 1)
}

func testGenerateContentStream_success_urlContext() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "streaming-success-url-context",
withExtension: "txt",
subdirectory: vertexSubdirectory
)

var responses = [GenerateContentResponse]()
let stream = try model.generateContentStream(testPrompt)
for try await response in stream {
responses.append(response)
}

let firstResponse = try XCTUnwrap(responses.first)
let candidate = try XCTUnwrap(firstResponse.candidates.first)
let urlContextMetadata = try XCTUnwrap(candidate.urlContextMetadata)
XCTAssertEqual(urlContextMetadata.urlMetadata.count, 1)
let urlMetadata = try XCTUnwrap(urlContextMetadata.urlMetadata.first)
XCTAssertEqual(urlMetadata.retrievedURL?.absoluteString, "https://google.com")
XCTAssertEqual(urlMetadata.retrievalStatus, .success)
}

// MARK: - Count Tokens

func testCountTokens_succeeds() async throws {
Expand Down
56 changes: 36 additions & 20 deletions FirebaseAI/Tests/Unit/MockURLProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ class MockURLProtocol: URLProtocol, @unchecked Sendable {
URLResponse,
AsyncLineSequence<URL.AsyncBytes>?
))?

nonisolated(unsafe) static var dataRequestHandler: ((URLRequest) throws -> (
URLResponse,
Data?
))?

override class func canInit(with request: URLRequest) -> Bool {
#if os(watchOS)
print("MockURLProtocol cannot be used on watchOS.")
Expand All @@ -33,34 +39,44 @@ class MockURLProtocol: URLProtocol, @unchecked Sendable {
override class func canonicalRequest(for request: URLRequest) -> URLRequest { return request }

override func startLoading() {
guard let requestHandler = MockURLProtocol.requestHandler else {
fatalError("`requestHandler` is nil.")
}
guard let client = client else {
fatalError("`client` is nil.")
}

Task {
let (response, stream) = try requestHandler(self.request)
client.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed)
if let stream = stream {
do {
for try await line in stream {
guard let data = line.data(using: .utf8) else {
fatalError("Failed to convert \"\(line)\" to UTF8 data.")
if let requestHandler = MockURLProtocol.requestHandler {
Task {
let (response, stream) = try requestHandler(self.request)
client.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed)
if let stream = stream {
do {
for try await line in stream {
guard let data = line.data(using: .utf8) else {
fatalError("Failed to convert \"\(line)\" to UTF8 data.")
}
client.urlProtocol(self, didLoad: data)
// Add a newline character since AsyncLineSequence strips them when reading line by
// line;
// without the following, the whole file is delivered as a single line.
client.urlProtocol(self, didLoad: "\n".data(using: .utf8)!)
}
client.urlProtocol(self, didLoad: data)
// Add a newline character since AsyncLineSequence strips them when reading line by
// line;
// without the following, the whole file is delivered as a single line.
client.urlProtocol(self, didLoad: "\n".data(using: .utf8)!)
} catch {
client.urlProtocol(self, didFailWithError: error)
XCTFail("Unexpected failure reading lines from stream: \(error.localizedDescription)")
}
} catch {
client.urlProtocol(self, didFailWithError: error)
XCTFail("Unexpected failure reading lines from stream: \(error.localizedDescription)")
}
client.urlProtocolDidFinishLoading(self)
}
} else if let dataRequestHandler = MockURLProtocol.dataRequestHandler {
Task {
let (response, data) = try dataRequestHandler(self.request)
client.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed)
if let data = data {
client.urlProtocol(self, didLoad: data)
}
client.urlProtocolDidFinishLoading(self)
}
client.urlProtocolDidFinishLoading(self)
} else {
fatalError("No request handler set.")
}
}

Expand Down
25 changes: 25 additions & 0 deletions FirebaseAI/Tests/Unit/TestUtilities/GenerativeModelTestUtil.swift
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,31 @@ enum GenerativeModelTestUtil {
#endif // os(watchOS)
}

/// Returns an HTTP request handler that returns the given `json` string as a response.
static func httpRequestHandler(json: String,
statusCode: Int = 200) throws -> ((URLRequest) throws -> (
URLResponse,
Data?
)) {
// Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
// https://developer.apple.com/documentation/foundation/urlprotocol for details.
#if os(watchOS)
throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
#else // os(watchOS)
let data = try XCTUnwrap(json.data(using: .utf8))
return { request in
let requestURL = try XCTUnwrap(request.url)
let response = try XCTUnwrap(HTTPURLResponse(
url: requestURL,
statusCode: statusCode,
httpVersion: nil,
headerFields: nil
))
return (response, data)
}
#endif // os(watchOS)
}

static func nonHTTPRequestHandler() throws -> ((URLRequest) -> (
URLResponse,
AsyncLineSequence<URL.AsyncBytes>?
Expand Down
Loading