Skip to content

Commit 1e96338

Browse files
committed
[Vertex AI] Add test app for integration tests
1 parent 7a10b3b commit 1e96338

File tree

9 files changed

+858
-0
lines changed

9 files changed

+858
-0
lines changed
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import FirebaseCore
16+
import FirebaseVertexAI
17+
import XCTest
18+
19+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
20+
final class IntegrationTests: XCTestCase {
21+
// Set temperature, topP and topK to lowest allowed values to make responses more deterministic.
22+
let generationConfig = GenerationConfig(
23+
temperature: 0.0,
24+
topP: 0.0,
25+
topK: 1,
26+
responseMIMEType: "text/plain"
27+
)
28+
let systemInstruction = ModelContent(
29+
role: "system",
30+
parts: "You are a friendly and helpful assistant."
31+
)
32+
let safetySettings = [
33+
SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove, method: .probability),
34+
SafetySetting(harmCategory: .hateSpeech, threshold: .blockLowAndAbove, method: .severity),
35+
SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockLowAndAbove),
36+
SafetySetting(harmCategory: .dangerousContent, threshold: .blockLowAndAbove),
37+
SafetySetting(harmCategory: .civicIntegrity, threshold: .blockLowAndAbove),
38+
]
39+
40+
var vertex: VertexAI!
41+
var model: GenerativeModel!
42+
43+
override func setUp() async throws {
44+
try XCTSkipIf(ProcessInfo.processInfo.environment["VertexAIRunIntegrationTests"] == nil, """
45+
Vertex AI integration tests skipped; to enable them, set the VertexAIRunIntegrationTests \
46+
environment variable in Xcode or CI jobs.
47+
""")
48+
49+
vertex = VertexAI.vertexAI()
50+
model = vertex.generativeModel(
51+
modelName: "gemini-1.5-flash",
52+
generationConfig: generationConfig,
53+
safetySettings: safetySettings,
54+
tools: [],
55+
toolConfig: .init(functionCallingConfig: .none()),
56+
systemInstruction: systemInstruction
57+
)
58+
}
59+
60+
// MARK: - Generate Content
61+
62+
func testGenerateContent() async throws {
63+
let prompt = "Where is Google headquarters located? Answer with the city name only."
64+
65+
let response = try await model.generateContent(prompt)
66+
67+
let text = try XCTUnwrap(response.text).trimmingCharacters(in: .whitespacesAndNewlines)
68+
XCTAssertEqual(text, "Mountain View")
69+
}
70+
71+
// MARK: - Count Tokens
72+
73+
func testCountTokens_text() async throws {
74+
let prompt = "Why is the sky blue?"
75+
model = vertex.generativeModel(
76+
modelName: "gemini-1.5-pro",
77+
generationConfig: generationConfig,
78+
safetySettings: [
79+
SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove, method: .severity),
80+
SafetySetting(harmCategory: .hateSpeech, threshold: .blockMediumAndAbove),
81+
SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockOnlyHigh),
82+
SafetySetting(harmCategory: .dangerousContent, threshold: .blockNone),
83+
SafetySetting(harmCategory: .civicIntegrity, threshold: .off, method: .probability),
84+
],
85+
toolConfig: .init(functionCallingConfig: .auto()),
86+
systemInstruction: systemInstruction
87+
)
88+
89+
let response = try await model.countTokens(prompt)
90+
91+
XCTAssertEqual(response.totalTokens, 14)
92+
XCTAssertEqual(response.totalBillableCharacters, 51)
93+
}
94+
95+
#if canImport(UIKit)
96+
func testCountTokens_image_inlineData() async throws {
97+
guard let image = UIImage(systemName: "cloud") else {
98+
XCTFail("Image not found.")
99+
return
100+
}
101+
102+
let response = try await model.countTokens(image)
103+
104+
XCTAssertEqual(response.totalTokens, 266)
105+
XCTAssertEqual(response.totalBillableCharacters, 35)
106+
}
107+
#endif // canImport(UIKit)
108+
109+
func testCountTokens_image_fileData() async throws {
110+
let fileData = FileDataPart(
111+
uri: "gs://ios-opensource-samples.appspot.com/ios/public/blank.jpg",
112+
mimeType: "image/jpeg"
113+
)
114+
115+
let response = try await model.countTokens(fileData)
116+
117+
XCTAssertEqual(response.totalTokens, 266)
118+
XCTAssertEqual(response.totalBillableCharacters, 35)
119+
}
120+
121+
func testCountTokens_functionCalling() async throws {
122+
let sumDeclaration = FunctionDeclaration(
123+
name: "sum",
124+
description: "Adds two integers.",
125+
parameters: ["x": .integer(), "y": .integer()]
126+
)
127+
model = vertex.generativeModel(
128+
modelName: "gemini-1.5-flash",
129+
tools: [.functionDeclarations([sumDeclaration])],
130+
toolConfig: .init(functionCallingConfig: .any(allowedFunctionNames: ["sum"]))
131+
)
132+
let prompt = "What is 10 + 32?"
133+
let sumCall = FunctionCallPart(name: "sum", args: ["x": .number(10), "y": .number(32)])
134+
let sumResponse = FunctionResponsePart(name: "sum", response: ["result": .number(42)])
135+
136+
let response = try await model.countTokens([
137+
ModelContent(role: "user", parts: prompt),
138+
ModelContent(role: "model", parts: sumCall),
139+
ModelContent(role: "function", parts: sumResponse),
140+
])
141+
142+
XCTAssertEqual(response.totalTokens, 24)
143+
XCTAssertEqual(response.totalBillableCharacters, 71)
144+
}
145+
146+
func testCountTokens_jsonSchema() async throws {
147+
model = vertex.generativeModel(
148+
modelName: "gemini-1.5-flash",
149+
generationConfig: GenerationConfig(
150+
responseMIMEType: "application/json",
151+
responseSchema: Schema.object(properties: [
152+
"startDate": .string(format: .custom("date")),
153+
"yearsSince": .integer(format: .custom("int16")),
154+
"hoursSince": .integer(format: .int32),
155+
"minutesSince": .integer(format: .int64),
156+
])
157+
)
158+
)
159+
let prompt = "It is 2050-01-01, how many years, hours and minutes since 2000-01-01?"
160+
161+
let response = try await model.countTokens(prompt)
162+
163+
XCTAssertEqual(response.totalTokens, 34)
164+
XCTAssertEqual(response.totalBillableCharacters, 59)
165+
}
166+
}

0 commit comments

Comments
 (0)