Skip to content

Commit 0fcadb9

Browse files
authored
[Vertex AI] Parameterize integration tests for Vertex and Dev API (#14540)
1 parent 52bae69 commit 0fcadb9

File tree

5 files changed

+163
-65
lines changed

5 files changed

+163
-65
lines changed

FirebaseVertexAI/Sources/VertexAI.swift

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -25,35 +25,18 @@ import Foundation
2525
public class VertexAI {
2626
// MARK: - Public APIs
2727

28-
/// The default `VertexAI` instance.
29-
///
30-
/// - Parameter location: The region identifier, defaulting to `us-central1`; see [Vertex AI
31-
/// regions
32-
/// ](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions)
33-
/// for a list of supported regions.
34-
/// - Returns: An instance of `VertexAI`, configured with the default `FirebaseApp`.
35-
public static func vertexAI(location: String = "us-central1") -> VertexAI {
36-
guard let app = FirebaseApp.app() else {
37-
fatalError("No instance of the default Firebase app was found.")
38-
}
39-
let vertexInstance = vertexAI(app: app, location: location)
40-
assert(vertexInstance.apiConfig.service == .vertexAI)
41-
assert(vertexInstance.apiConfig.service.endpoint == .firebaseVertexAIProd)
42-
assert(vertexInstance.apiConfig.version == .v1beta)
43-
44-
return vertexInstance
45-
}
46-
47-
/// Creates an instance of `VertexAI` configured with a custom `FirebaseApp`.
28+
/// Creates an instance of `VertexAI`.
4829
///
4930
/// - Parameters:
50-
/// - app: The custom `FirebaseApp` used for initialization.
31+
/// - app: A custom `FirebaseApp` used for initialization; if not specified, uses the default
32+
/// ``FirebaseApp``.
5133
/// - location: The region identifier, defaulting to `us-central1`; see
5234
/// [Vertex AI locations]
5335
/// (https://firebase.google.com/docs/vertex-ai/locations?platform=ios#available-locations)
5436
/// for a list of supported locations.
5537
/// - Returns: A `VertexAI` instance, configured with the custom `FirebaseApp`.
56-
public static func vertexAI(app: FirebaseApp, location: String = "us-central1") -> VertexAI {
38+
public static func vertexAI(app: FirebaseApp? = nil,
39+
location: String = "us-central1") -> VertexAI {
5740
let vertexInstance = vertexAI(app: app, location: location, apiConfig: defaultVertexAIAPIConfig)
5841
assert(vertexInstance.apiConfig.service == .vertexAI)
5942
assert(vertexInstance.apiConfig.service.endpoint == .firebaseVertexAIProd)
@@ -160,25 +143,12 @@ public class VertexAI {
160143
let location: String?
161144

162145
static let defaultVertexAIAPIConfig = APIConfig(service: .vertexAI, version: .v1beta)
163-
static let defaultDeveloperAPIConfig = APIConfig(
164-
service: .developer(endpoint: .generativeLanguage),
165-
version: .v1beta
166-
)
167146

168-
static func developerAPI(apiConfig: APIConfig = defaultDeveloperAPIConfig) -> VertexAI {
169-
guard let app = FirebaseApp.app() else {
147+
static func vertexAI(app: FirebaseApp?, location: String?, apiConfig: APIConfig) -> VertexAI {
148+
guard let app = app ?? FirebaseApp.app() else {
170149
fatalError("No instance of the default Firebase app was found.")
171150
}
172151

173-
return developerAPI(app: app, apiConfig: apiConfig)
174-
}
175-
176-
static func developerAPI(app: FirebaseApp,
177-
apiConfig: APIConfig = defaultDeveloperAPIConfig) -> VertexAI {
178-
return vertexAI(app: app, location: nil, apiConfig: apiConfig)
179-
}
180-
181-
static func vertexAI(app: FirebaseApp, location: String?, apiConfig: APIConfig) -> VertexAI {
182152
os_unfair_lock_lock(&instancesLock)
183153

184154
// Unlock before the function returns.
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import FirebaseAuth
16+
import FirebaseCore
17+
import FirebaseStorage
18+
import FirebaseVertexAI
19+
import Testing
20+
import VertexAITestApp
21+
22+
@testable import struct FirebaseVertexAI.APIConfig
23+
24+
@Suite(.serialized)
25+
struct GenerateContentIntegrationTests {
26+
static let vertexV1Config = APIConfig(service: .vertexAI, version: .v1)
27+
static let vertexV1BetaConfig = APIConfig(service: .vertexAI, version: .v1beta)
28+
static let developerV1BetaConfig = APIConfig(
29+
service: .developer(endpoint: .generativeLanguage),
30+
version: .v1beta
31+
)
32+
33+
// Set temperature, topP and topK to lowest allowed values to make responses more deterministic.
34+
static let generationConfig = GenerationConfig(
35+
temperature: 0.0,
36+
topP: 0.0,
37+
topK: 1,
38+
responseMIMEType: "text/plain"
39+
)
40+
static let systemInstruction = ModelContent(
41+
role: "system",
42+
parts: "You are a friendly and helpful assistant."
43+
)
44+
static let safetySettings = [
45+
SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove),
46+
SafetySetting(harmCategory: .hateSpeech, threshold: .blockLowAndAbove),
47+
SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockLowAndAbove),
48+
SafetySetting(harmCategory: .dangerousContent, threshold: .blockLowAndAbove),
49+
SafetySetting(harmCategory: .civicIntegrity, threshold: .blockLowAndAbove),
50+
]
51+
// Candidates and total token counts may differ slightly between runs due to whitespace tokens.
52+
let tokenCountAccuracy = 1
53+
54+
let storage: Storage
55+
let userID1: String
56+
57+
init() async throws {
58+
let authResult = try await Auth.auth().signIn(
59+
withEmail: Credentials.emailAddress1,
60+
password: Credentials.emailPassword1
61+
)
62+
userID1 = authResult.user.uid
63+
64+
storage = Storage.storage()
65+
}
66+
67+
@Test(arguments: [vertexV1Config, vertexV1BetaConfig, developerV1BetaConfig])
68+
func generateContent(_ apiConfig: APIConfig) async throws {
69+
let model = GenerateContentIntegrationTests.model(apiConfig: apiConfig)
70+
let prompt = "Where is Google headquarters located? Answer with the city name only."
71+
72+
let response = try await model.generateContent(prompt)
73+
74+
let text = try #require(response.text).trimmingCharacters(in: .whitespacesAndNewlines)
75+
#expect(text == "Mountain View")
76+
77+
let usageMetadata = try #require(response.usageMetadata)
78+
#expect(usageMetadata.promptTokenCount == 21)
79+
#expect(usageMetadata.candidatesTokenCount.isEqual(to: 3, accuracy: tokenCountAccuracy))
80+
#expect(usageMetadata.totalTokenCount.isEqual(to: 24, accuracy: tokenCountAccuracy))
81+
#expect(usageMetadata.promptTokensDetails.count == 1)
82+
let promptTokensDetails = try #require(usageMetadata.promptTokensDetails.first)
83+
#expect(promptTokensDetails.modality == .text)
84+
#expect(promptTokensDetails.tokenCount == usageMetadata.promptTokenCount)
85+
#expect(usageMetadata.candidatesTokensDetails.count == 1)
86+
let candidatesTokensDetails = try #require(usageMetadata.candidatesTokensDetails.first)
87+
#expect(candidatesTokensDetails.modality == .text)
88+
#expect(candidatesTokensDetails.tokenCount == usageMetadata.candidatesTokenCount)
89+
}
90+
91+
static func model(apiConfig: APIConfig) -> GenerativeModel {
92+
return instance(apiConfig: apiConfig).generativeModel(
93+
modelName: "gemini-2.0-flash",
94+
generationConfig: generationConfig,
95+
safetySettings: safetySettings,
96+
tools: [],
97+
toolConfig: .init(functionCallingConfig: .none()),
98+
systemInstruction: systemInstruction
99+
)
100+
}
101+
102+
// TODO(andrewheard): Move this helper to a file in the Utilities folder.
103+
static func instance(apiConfig: APIConfig) -> VertexAI {
104+
switch apiConfig.service {
105+
case .vertexAI:
106+
return VertexAI.vertexAI(app: nil, location: "us-central1", apiConfig: apiConfig)
107+
case .developer:
108+
return VertexAI.vertexAI(app: nil, location: nil, apiConfig: apiConfig)
109+
}
110+
}
111+
}
112+
113+
// TODO(andrewheard): Move this extension to a file in the Utilities folder.
114+
extension Numeric where Self: Strideable, Self.Stride.Magnitude: Comparable {
115+
func isEqual(to other: Self, accuracy: Self.Stride) -> Bool {
116+
return distance(to: other).magnitude < accuracy.magnitude
117+
}
118+
}

FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -69,27 +69,6 @@ final class IntegrationTests: XCTestCase {
6969

7070
// MARK: - Generate Content
7171

72-
func testGenerateContent() async throws {
73-
let prompt = "Where is Google headquarters located? Answer with the city name only."
74-
75-
let response = try await model.generateContent(prompt)
76-
77-
let text = try XCTUnwrap(response.text).trimmingCharacters(in: .whitespacesAndNewlines)
78-
XCTAssertEqual(text, "Mountain View")
79-
let usageMetadata = try XCTUnwrap(response.usageMetadata)
80-
XCTAssertEqual(usageMetadata.promptTokenCount, 21)
81-
XCTAssertEqual(usageMetadata.candidatesTokenCount, 3, accuracy: tokenCountAccuracy)
82-
XCTAssertEqual(usageMetadata.totalTokenCount, 24, accuracy: tokenCountAccuracy)
83-
XCTAssertEqual(usageMetadata.promptTokensDetails.count, 1)
84-
let promptTokensDetails = try XCTUnwrap(usageMetadata.promptTokensDetails.first)
85-
XCTAssertEqual(promptTokensDetails.modality, .text)
86-
XCTAssertEqual(promptTokensDetails.tokenCount, usageMetadata.promptTokenCount)
87-
XCTAssertEqual(usageMetadata.candidatesTokensDetails.count, 1)
88-
let candidatesTokensDetails = try XCTUnwrap(usageMetadata.candidatesTokensDetails.first)
89-
XCTAssertEqual(candidatesTokensDetails.modality, .text)
90-
XCTAssertEqual(candidatesTokensDetails.tokenCount, usageMetadata.candidatesTokenCount)
91-
}
92-
9372
func testGenerateContentStream() async throws {
9473
let expectedText = """
9574
1. Mercury

FirebaseVertexAI/Tests/TestApp/VertexAITestApp.xcodeproj/project.pbxproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
8692F29E2CC9477800539E8F /* FirebaseVertexAI in Frameworks */ = {isa = PBXBuildFile; productRef = 8692F29D2CC9477800539E8F /* FirebaseVertexAI */; };
2323
8698D7462CD3CF3600ABA833 /* FirebaseAppTestUtils.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8698D7452CD3CF2F00ABA833 /* FirebaseAppTestUtils.swift */; };
2424
8698D7482CD4332B00ABA833 /* TestAppCheckProviderFactory.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8698D7472CD4332B00ABA833 /* TestAppCheckProviderFactory.swift */; };
25+
86D77DFC2D7A5340003D155D /* GenerateContentIntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 86D77DFB2D7A5340003D155D /* GenerateContentIntegrationTests.swift */; };
2526
/* End PBXBuildFile section */
2627

2728
/* Begin PBXContainerItemProxy section */
@@ -49,6 +50,7 @@
4950
868A7C552CCC271300E449DD /* TestApp.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = TestApp.entitlements; sourceTree = "<group>"; };
5051
8698D7452CD3CF2F00ABA833 /* FirebaseAppTestUtils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FirebaseAppTestUtils.swift; sourceTree = "<group>"; };
5152
8698D7472CD4332B00ABA833 /* TestAppCheckProviderFactory.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TestAppCheckProviderFactory.swift; sourceTree = "<group>"; };
53+
86D77DFB2D7A5340003D155D /* GenerateContentIntegrationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerateContentIntegrationTests.swift; sourceTree = "<group>"; };
5254
/* End PBXFileReference section */
5355

5456
/* Begin PBXFrameworksBuildPhase section */
@@ -126,6 +128,7 @@
126128
children = (
127129
868A7C4D2CCC1F4700E449DD /* Credentials.swift */,
128130
8661386D2CC943DE00F4B78E /* IntegrationTests.swift */,
131+
86D77DFB2D7A5340003D155D /* GenerateContentIntegrationTests.swift */,
129132
864F8F702D4980D60002EA7E /* ImagenIntegrationTests.swift */,
130133
862218802D04E08D007ED2D4 /* IntegrationTestUtils.swift */,
131134
);
@@ -273,6 +276,7 @@
273276
868A7C4F2CCC229F00E449DD /* Credentials.swift in Sources */,
274277
864F8F712D4980DD0002EA7E /* ImagenIntegrationTests.swift in Sources */,
275278
862218812D04E098007ED2D4 /* IntegrationTestUtils.swift in Sources */,
279+
86D77DFC2D7A5340003D155D /* GenerateContentIntegrationTests.swift in Sources */,
276280
8661386E2CC943DE00F4B78E /* IntegrationTests.swift in Sources */,
277281
);
278282
runOnlyForDeploymentPostprocessing = 0;

FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,22 @@ class VertexComponentTests: XCTestCase {
5151
XCTAssertNotNil(NSClassFromString("FIRVertexAIComponent"))
5252
}
5353

54-
/// Tests that a vertex instance can be created properly using the default Firebase pp.
54+
/// Tests that a vertex instance can be created properly using the default Firebase app.
5555
func testVertexInstanceCreation_defaultApp() throws {
56+
let vertex = VertexAI.vertexAI()
57+
58+
XCTAssertNotNil(vertex)
59+
XCTAssertEqual(vertex.firebaseInfo.projectID, VertexComponentTests.projectID)
60+
XCTAssertEqual(vertex.firebaseInfo.apiKey, VertexComponentTests.apiKey)
61+
XCTAssertEqual(vertex.location, "us-central1")
62+
XCTAssertEqual(vertex.apiConfig.service, .vertexAI)
63+
XCTAssertEqual(vertex.apiConfig.service.endpoint, .firebaseVertexAIProd)
64+
XCTAssertEqual(vertex.apiConfig.version, .v1beta)
65+
}
66+
67+
/// Tests that a vertex instance can be created properly using the default Firebase app and custom
68+
/// location.
69+
func testVertexInstanceCreation_defaultApp_customLocation() throws {
5670
let vertex = VertexAI.vertexAI(location: location)
5771

5872
XCTAssertNotNil(vertex)
@@ -121,8 +135,16 @@ class VertexComponentTests: XCTestCase {
121135
}
122136

123137
func testSameAppAndDifferentAPI_newInstanceCreated() throws {
124-
let vertex1 = VertexAI.vertexAI(app: VertexComponentTests.app)
125-
let vertex2 = VertexAI.developerAPI(app: VertexComponentTests.app)
138+
let vertex1 = VertexAI.vertexAI(
139+
app: VertexComponentTests.app,
140+
location: location,
141+
apiConfig: APIConfig(service: .vertexAI, version: .v1beta)
142+
)
143+
let vertex2 = VertexAI.vertexAI(
144+
app: VertexComponentTests.app,
145+
location: location,
146+
apiConfig: APIConfig(service: .vertexAI, version: .v1)
147+
)
126148

127149
// Ensure they are different instances.
128150
XCTAssert(vertex1 !== vertex2)
@@ -168,7 +190,8 @@ class VertexComponentTests: XCTestCase {
168190

169191
func testModelResourceName_developerAPI_generativeLanguage() throws {
170192
let app = try XCTUnwrap(VertexComponentTests.app)
171-
let vertex = VertexAI.developerAPI(app: app)
193+
let apiConfig = APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1beta)
194+
let vertex = VertexAI.vertexAI(app: app, location: nil, apiConfig: apiConfig)
172195
let model = "test-model-name"
173196

174197
let modelResourceName = vertex.modelResourceName(modelName: model)
@@ -182,7 +205,7 @@ class VertexComponentTests: XCTestCase {
182205
service: .developer(endpoint: .firebaseVertexAIStaging),
183206
version: .v1beta
184207
)
185-
let vertex = VertexAI.developerAPI(app: app, apiConfig: apiConfig)
208+
let vertex = VertexAI.vertexAI(app: app, location: nil, apiConfig: apiConfig)
186209
let model = "test-model-name"
187210
let projectID = vertex.firebaseInfo.projectID
188211

@@ -208,7 +231,11 @@ class VertexComponentTests: XCTestCase {
208231

209232
func testGenerativeModel_developerAPI() async throws {
210233
let app = try XCTUnwrap(VertexComponentTests.app)
211-
let vertex = VertexAI.developerAPI(app: app)
234+
let apiConfig = APIConfig(
235+
service: .developer(endpoint: .firebaseVertexAIStaging),
236+
version: .v1beta
237+
)
238+
let vertex = VertexAI.vertexAI(app: app, location: nil, apiConfig: apiConfig)
212239
let modelResourceName = vertex.modelResourceName(modelName: modelName)
213240

214241
let generativeModel = vertex.generativeModel(
@@ -218,6 +245,6 @@ class VertexComponentTests: XCTestCase {
218245

219246
XCTAssertEqual(generativeModel.modelResourceName, modelResourceName)
220247
XCTAssertEqual(generativeModel.systemInstruction, systemInstruction)
221-
XCTAssertEqual(generativeModel.apiConfig, VertexAI.defaultDeveloperAPIConfig)
248+
XCTAssertEqual(generativeModel.apiConfig, apiConfig)
222249
}
223250
}

0 commit comments

Comments
 (0)