
Commit 9ba8fc0

[Vertex AI] Add Developer API encoding CountTokensRequest (#14512)
1 parent dd4a403 commit 9ba8fc0

9 files changed: +368 -45 lines changed


FirebaseVertexAI/Sources/GenerateContentRequest.swift

Lines changed: 19 additions & 0 deletions
@@ -18,12 +18,14 @@ import Foundation
 struct GenerateContentRequest: Sendable {
   /// Model name.
   let model: String
+
   let contents: [ModelContent]
   let generationConfig: GenerationConfig?
   let safetySettings: [SafetySetting]?
   let tools: [Tool]?
   let toolConfig: ToolConfig?
   let systemInstruction: ModelContent?
+
   let apiConfig: APIConfig
   let apiMethod: APIMethod
   let options: RequestOptions
@@ -32,13 +34,30 @@ struct GenerateContentRequest: Sendable {
 @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
 extension GenerateContentRequest: Encodable {
   enum CodingKeys: String, CodingKey {
+    case model
     case contents
     case generationConfig
     case safetySettings
     case tools
     case toolConfig
     case systemInstruction
   }
+
+  func encode(to encoder: any Encoder) throws {
+    var container = encoder.container(keyedBy: CodingKeys.self)
+    // The model name only needs to be encoded when this `GenerateContentRequest` instance is used
+    // in a `CountTokensRequest` (calling `countTokens`). When calling `generateContent` or
+    // `generateContentStream`, the `model` field is populated in the backend from the `url`.
+    if apiMethod == .countTokens {
+      try container.encode(model, forKey: .model)
+    }
+    try container.encode(contents, forKey: .contents)
+    try container.encodeIfPresent(generationConfig, forKey: .generationConfig)
+    try container.encodeIfPresent(safetySettings, forKey: .safetySettings)
+    try container.encodeIfPresent(tools, forKey: .tools)
+    try container.encodeIfPresent(toolConfig, forKey: .toolConfig)
+    try container.encodeIfPresent(systemInstruction, forKey: .systemInstruction)
+  }
 }
 
 @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
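
To make the conditional encoding above concrete, here is a minimal, self-contained sketch of the same pattern. `APIMethod` and `RequestSketch` are simplified stand-ins for illustration, not the SDK's internal types.

import Foundation

enum APIMethod {
  case generateContent
  case countTokens
}

struct RequestSketch: Encodable {
  let model: String
  let contents: [String]
  let apiMethod: APIMethod

  enum CodingKeys: String, CodingKey {
    case model
    case contents
  }

  func encode(to encoder: any Encoder) throws {
    var container = encoder.container(keyedBy: CodingKeys.self)
    // Only `countTokens` payloads carry the model name in the body; for
    // `generateContent` the backend derives the model from the request URL.
    if apiMethod == .countTokens {
      try container.encode(model, forKey: .model)
    }
    try container.encode(contents, forKey: .contents)
  }
}

let encoder = JSONEncoder()
encoder.outputFormatting = [.sortedKeys]

let countTokens = RequestSketch(model: "models/example-model", contents: ["Hi"], apiMethod: .countTokens)
let generate = RequestSketch(model: "models/example-model", contents: ["Hi"], apiMethod: .generateContent)

print(String(data: try! encoder.encode(countTokens), encoding: .utf8)!)
// {"contents":["Hi"],"model":"models\/example-model"}
print(String(data: try! encoder.encode(generate), encoding: .utf8)!)
// {"contents":["Hi"]}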

FirebaseVertexAI/Sources/GenerativeModel.swift

Lines changed: 20 additions & 4 deletions
@@ -260,15 +260,31 @@ public final class GenerativeModel: Sendable {
   /// - Returns: The results of running the model's tokenizer on the input; contains
   ///   ``CountTokensResponse/totalTokens``.
   public func countTokens(_ content: [ModelContent]) async throws -> CountTokensResponse {
-    let countTokensRequest = CountTokensRequest(
+    let requestContent = switch apiConfig.service {
+    case .vertexAI:
+      content
+    case .developer:
+      // The `role` defaults to "user" but is ignored in `countTokens`. However, it is erroneously
+      // counted towards the prompt and total token count when using the Developer API backend;
+      // set it to `nil` to avoid token count discrepancies between `countTokens` and
+      // `generateContent`, and between the two backend APIs.
+      content.map { ModelContent(role: nil, parts: $0.parts) }
+    }
+
+    let generateContentRequest = GenerateContentRequest(
       model: modelResourceName,
-      contents: content,
-      systemInstruction: systemInstruction,
-      tools: tools,
+      contents: requestContent,
       generationConfig: generationConfig,
+      safetySettings: safetySettings,
+      tools: tools,
+      toolConfig: toolConfig,
+      systemInstruction: systemInstruction,
       apiConfig: apiConfig,
+      apiMethod: .countTokens,
       options: requestOptions
     )
+    let countTokensRequest = CountTokensRequest(generateContentRequest: generateContentRequest)
+
     return try await generativeAIService.loadRequest(request: countTokensRequest)
   }
 
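
For context, this is roughly how the public API exercised by this change is called. A minimal usage sketch, assuming `FirebaseApp.configure()` has already run and that the example model name is available to the project:

import FirebaseCore
import FirebaseVertexAI

func printTokenCounts() async throws {
  let model = VertexAI.vertexAI().generativeModel(modelName: "gemini-2.0-flash-001")
  let response = try await model.countTokens("Why is the sky blue?")
  print("Total tokens: \(response.totalTokens)")
  // `totalBillableCharacters` is only populated by the Vertex AI backend;
  // the Developer API backend leaves it nil.
  if let billableCharacters = response.totalBillableCharacters {
    print("Total billable characters: \(billableCharacters)")
  }
}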

FirebaseVertexAI/Sources/CountTokensRequest.swift renamed to FirebaseVertexAI/Sources/Types/Internal/Requests/CountTokensRequest.swift

Lines changed: 39 additions & 12 deletions
@@ -16,24 +16,21 @@ import Foundation
 
 @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
 struct CountTokensRequest {
-  let model: String
-
-  let contents: [ModelContent]
-  let systemInstruction: ModelContent?
-  let tools: [Tool]?
-  let generationConfig: GenerationConfig?
-
-  let apiConfig: APIConfig
-  let options: RequestOptions
+  let generateContentRequest: GenerateContentRequest
 }
 
 @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
 extension CountTokensRequest: GenerativeAIRequest {
   typealias Response = CountTokensResponse
 
+  var options: RequestOptions { generateContentRequest.options }
+
+  var apiConfig: APIConfig { generateContentRequest.apiConfig }
+
   var url: URL {
-    URL(string:
-      "\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):countTokens")!
+    let version = apiConfig.version.rawValue
+    let endpoint = apiConfig.service.endpoint.rawValue
+    return URL(string: "\(endpoint)/\(version)/\(generateContentRequest.model):countTokens")!
   }
 }
 
@@ -57,12 +54,42 @@ public struct CountTokensResponse {
 
 @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
 extension CountTokensRequest: Encodable {
-  enum CodingKeys: CodingKey {
+  enum VertexCodingKeys: CodingKey {
     case contents
     case systemInstruction
     case tools
     case generationConfig
   }
+
+  enum DeveloperCodingKeys: CodingKey {
+    case generateContentRequest
+  }
+
+  func encode(to encoder: any Encoder) throws {
+    switch apiConfig.service {
+    case .vertexAI:
+      try encodeForVertexAI(to: encoder)
+    case .developer:
+      try encodeForDeveloper(to: encoder)
+    }
+  }
+
+  private func encodeForVertexAI(to encoder: any Encoder) throws {
+    var container = encoder.container(keyedBy: VertexCodingKeys.self)
+    try container.encode(generateContentRequest.contents, forKey: .contents)
+    try container.encodeIfPresent(
+      generateContentRequest.systemInstruction, forKey: .systemInstruction
+    )
+    try container.encodeIfPresent(generateContentRequest.tools, forKey: .tools)
+    try container.encodeIfPresent(
+      generateContentRequest.generationConfig, forKey: .generationConfig
+    )
+  }
+
+  private func encodeForDeveloper(to encoder: any Encoder) throws {
+    var container = encoder.container(keyedBy: DeveloperCodingKeys.self)
+    try container.encode(generateContentRequest, forKey: .generateContentRequest)
+  }
 }
 
 @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
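
Illustrative only: the two encoders above should produce request bodies shaped roughly as follows (field names follow the coding keys; the prompt, generation config, and model resource name are hypothetical examples):

// Vertex AI backend: a flat body; the model is identified by the request URL.
let vertexAICountTokensBody = """
{
  "contents": [
    { "role": "user", "parts": [{ "text": "Why is the sky blue?" }] }
  ],
  "generationConfig": { "temperature": 0.0 }
}
"""

// Developer API backend: the body nests a full `generateContentRequest`,
// which is why the nested request now also encodes its `model` field.
let developerCountTokensBody = """
{
  "generateContentRequest": {
    "model": "models/gemini-2.0-flash-001",
    "contents": [
      { "parts": [{ "text": "Why is the sky blue?" }] }
    ],
    "generationConfig": { "temperature": 0.0 }
  }
}
"""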

FirebaseVertexAI/Tests/TestApp/Sources/Constants.swift

Lines changed: 1 addition & 0 deletions
@@ -21,5 +21,6 @@ public enum FirebaseAppNames {
 }
 
 public enum ModelNames {
+  public static let gemini2Flash = "gemini-2.0-flash-001"
   public static let gemini2FlashLite = "gemini-2.0-flash-lite-001"
 }
New file: CountTokensIntegrationTests (integration test suite)

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseAuth
import FirebaseCore
import FirebaseStorage
import FirebaseVertexAI
import Testing
import VertexAITestApp

@testable import struct FirebaseVertexAI.APIConfig

@Suite(.serialized)
struct CountTokensIntegrationTests {
  let generationConfig = GenerationConfig(
    temperature: 1.2,
    topP: 0.95,
    topK: 32,
    candidateCount: 1,
    maxOutputTokens: 8192,
    presencePenalty: 1.5,
    frequencyPenalty: 1.75,
    stopSequences: ["cat", "dog", "bird"]
  )
  let safetySettings = [
    SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove),
    SafetySetting(harmCategory: .hateSpeech, threshold: .blockLowAndAbove),
    SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockLowAndAbove),
    SafetySetting(harmCategory: .dangerousContent, threshold: .blockLowAndAbove),
    SafetySetting(harmCategory: .civicIntegrity, threshold: .blockLowAndAbove),
  ]
  let systemInstruction = ModelContent(
    role: "system",
    parts: "You are a friendly and helpful assistant."
  )

  @Test(arguments: InstanceConfig.allConfigs)
  func countTokens_text(_ config: InstanceConfig) async throws {
    let prompt = "Why is the sky blue?"
    let model = VertexAI.componentInstance(config).generativeModel(
      modelName: ModelNames.gemini2Flash,
      generationConfig: generationConfig,
      safetySettings: safetySettings
    )

    let response = try await model.countTokens(prompt)

    #expect(response.totalTokens == 6)
    switch config.apiConfig.service {
    case .vertexAI:
      #expect(response.totalBillableCharacters == 16)
    case .developer:
      #expect(response.totalBillableCharacters == nil)
    }
    #expect(response.promptTokensDetails.count == 1)
    let promptTokensDetails = try #require(response.promptTokensDetails.first)
    #expect(promptTokensDetails.modality == .text)
    #expect(promptTokensDetails.tokenCount == response.totalTokens)
  }

  @Test(arguments: [
    InstanceConfig.vertexV1,
    InstanceConfig.vertexV1Beta,
    /* System instructions are not supported on the v1 Developer API. */
    InstanceConfig.developerV1Beta,
  ])
  func countTokens_text_systemInstruction(_ config: InstanceConfig) async throws {
    let model = VertexAI.componentInstance(config).generativeModel(
      modelName: ModelNames.gemini2Flash,
      generationConfig: generationConfig,
      safetySettings: safetySettings,
      systemInstruction: systemInstruction // Not supported on the v1 Developer API.
    )

    let response = try await model.countTokens("What is your favourite colour?")

    #expect(response.totalTokens == 14)
    switch config.apiConfig.service {
    case .vertexAI:
      #expect(response.totalBillableCharacters == 61)
    case .developer:
      #expect(response.totalBillableCharacters == nil)
    }
    #expect(response.promptTokensDetails.count == 1)
    let promptTokensDetails = try #require(response.promptTokensDetails.first)
    #expect(promptTokensDetails.modality == .text)
    #expect(promptTokensDetails.tokenCount == response.totalTokens)
  }

  @Test(arguments: [
    /* System instructions are not supported on the v1 Developer API. */
    InstanceConfig.developerV1,
  ])
  func countTokens_text_systemInstruction_unsupported(_ config: InstanceConfig) async throws {
    let model = VertexAI.componentInstance(config).generativeModel(
      modelName: ModelNames.gemini2Flash,
      systemInstruction: systemInstruction // Not supported on the v1 Developer API.
    )

    try await #require(
      throws: BackendError.self,
      """
      If this test fails (i.e., `countTokens` succeeds), remove \(config) from this test and add it
      to `countTokens_text_systemInstruction`.
      """,
      performing: {
        try await model.countTokens("What is your favourite colour?")
      }
    )
  }
}

FirebaseVertexAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift

Lines changed: 4 additions & 25 deletions
@@ -19,29 +19,8 @@ import FirebaseVertexAI
 import Testing
 import VertexAITestApp
 
-@testable import struct FirebaseVertexAI.APIConfig
-
 @Suite(.serialized)
 struct GenerateContentIntegrationTests {
-  static let vertexV1Config =
-    InstanceConfig(apiConfig: APIConfig(service: .vertexAI, version: .v1))
-  static let vertexV1BetaConfig =
-    InstanceConfig(apiConfig: APIConfig(service: .vertexAI, version: .v1beta))
-  static let developerV1Config = InstanceConfig(
-    appName: FirebaseAppNames.spark,
-    apiConfig: APIConfig(
-      service: .developer(endpoint: .generativeLanguage), version: .v1
-    )
-  )
-  static let developerV1BetaConfig = InstanceConfig(
-    appName: FirebaseAppNames.spark,
-    apiConfig: APIConfig(
-      service: .developer(endpoint: .generativeLanguage), version: .v1beta
-    )
-  )
-  static let allConfigs =
-    [vertexV1Config, vertexV1BetaConfig, developerV1Config, developerV1BetaConfig]
-
   // Set temperature, topP and topK to lowest allowed values to make responses more deterministic.
   let generationConfig = GenerationConfig(temperature: 0.0, topP: 0.0, topK: 1)
   let safetySettings = [
@@ -67,7 +46,7 @@ struct GenerateContentIntegrationTests {
     storage = Storage.storage()
   }
 
-  @Test(arguments: allConfigs)
+  @Test(arguments: InstanceConfig.allConfigs)
  func generateContent(_ config: InstanceConfig) async throws {
     let model = VertexAI.componentInstance(config).generativeModel(
       modelName: ModelNames.gemini2FlashLite,
@@ -98,10 +77,10 @@
   @Test(
     "Generate an enum and provide a system instruction",
     arguments: [
-      vertexV1Config,
-      vertexV1BetaConfig,
+      InstanceConfig.vertexV1,
+      InstanceConfig.vertexV1Beta,
       /* System instructions are not supported on the v1 Developer API. */
-      developerV1BetaConfig,
+      InstanceConfig.developerV1Beta,
     ]
   )
   func generateContentEnum(_ config: InstanceConfig) async throws {
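
The `InstanceConfig.vertexV1`, `.vertexV1Beta`, `.developerV1`, `.developerV1Beta`, and `.allConfigs` values referenced above and in the new test suite are defined in a shared file that is not shown in this excerpt. A plausible sketch of that shared definition, mirroring the static properties removed here (the extension's location and exact imports are assumptions):

import FirebaseVertexAI
import VertexAITestApp

@testable import struct FirebaseVertexAI.APIConfig

extension InstanceConfig {
  // Mirrors the per-backend configurations formerly defined in
  // GenerateContentIntegrationTests, shared so other suites can reuse them.
  static let vertexV1 = InstanceConfig(
    apiConfig: APIConfig(service: .vertexAI, version: .v1)
  )
  static let vertexV1Beta = InstanceConfig(
    apiConfig: APIConfig(service: .vertexAI, version: .v1beta)
  )
  static let developerV1 = InstanceConfig(
    appName: FirebaseAppNames.spark,
    apiConfig: APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1)
  )
  static let developerV1Beta = InstanceConfig(
    appName: FirebaseAppNames.spark,
    apiConfig: APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1beta)
  )

  static let allConfigs = [vertexV1, vertexV1Beta, developerV1, developerV1Beta]
}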

0 commit comments
