Skip to content

Commit 11fe4a0

Browse files
authored
[Vertex AI] Add countTokens support for Developer API via VinF (#14644)
1 parent 0b6091c commit 11fe4a0

File tree

8 files changed

+71
-20
lines changed

8 files changed

+71
-20
lines changed

FirebaseVertexAI/Sources/GenerativeModel.swift

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ public final class GenerativeModel: Sendable {
2323
/// Model name prefix to identify Gemini models.
2424
static let geminiModelNamePrefix = "gemini-"
2525

26-
/// The resource name of the model in the backend; has the format "models/model-name".
26+
/// The name of the model, for example "gemini-2.0-flash".
27+
let modelName: String
28+
29+
/// The model resource name corresponding with `modelName` in the backend.
2730
let modelResourceName: String
2831

2932
/// Configuration for the backend API used by this model.
@@ -53,8 +56,13 @@ public final class GenerativeModel: Sendable {
5356
/// Initializes a new remote model with the given parameters.
5457
///
5558
/// - Parameters:
56-
/// - modelResourceName: The resource name of the model to use, for example
57-
/// `"projects/{project-id}/locations/{location-id}/publishers/google/models/{model-name}"`.
59+
/// - modelName: The name of the model, for example "gemini-2.0-flash".
60+
/// - modelResourceName: The model resource name corresponding with `modelName` in the backend.
61+
/// The form depends on the backend and will be one of:
62+
/// - Vertex AI via Vertex AI in Firebase:
63+
/// `"projects/{projectID}/locations/{locationID}/publishers/google/models/{modelName}"`
64+
/// - Developer API via Vertex AI in Firebase: `"projects/{projectID}/models/{modelName}"`
65+
/// - Developer API via Generative Language: `"models/{modelName}"`
5866
/// - firebaseInfo: Firebase data used by the SDK, including project ID and API key.
5967
/// - apiConfig: Configuration for the backend API used by this model.
6068
/// - generationConfig: The content generation parameters your model should use.
@@ -65,7 +73,8 @@ public final class GenerativeModel: Sendable {
6573
/// only text content is supported.
6674
/// - requestOptions: Configuration parameters for sending requests to the backend.
6775
/// - urlSession: The `URLSession` to use for requests; defaults to `URLSession.shared`.
68-
init(modelResourceName: String,
76+
init(modelName: String,
77+
modelResourceName: String,
6978
firebaseInfo: FirebaseInfo,
7079
apiConfig: APIConfig,
7180
generationConfig: GenerationConfig? = nil,
@@ -75,6 +84,7 @@ public final class GenerativeModel: Sendable {
7584
systemInstruction: ModelContent? = nil,
7685
requestOptions: RequestOptions,
7786
urlSession: URLSession = .shared) {
87+
self.modelName = modelName
7888
self.modelResourceName = modelResourceName
7989
self.apiConfig = apiConfig
8090
generativeAIService = GenerativeAIService(
@@ -275,8 +285,20 @@ public final class GenerativeModel: Sendable {
275285
content.map { ModelContent(role: nil, parts: $0.parts) }
276286
}
277287

288+
// When using the Developer API via the Firebase backend, the model name of the
289+
// `GenerateContentRequest` nested in the `CountTokensRequest` must be of the form
290+
// "models/model-name". This field is unaltered by the Firebase backend before forwarding the
291+
// request to the Generative Language backend, which expects the form "models/model-name".
292+
let generateContentRequestModelResourceName = switch apiConfig.service {
293+
case .vertexAI, .developer(endpoint: .generativeLanguage):
294+
modelResourceName
295+
case .developer(endpoint: .firebaseVertexAIProd),
296+
.developer(endpoint: .firebaseVertexAIStaging):
297+
"models/\(modelName)"
298+
}
299+
278300
let generateContentRequest = GenerateContentRequest(
279-
model: modelResourceName,
301+
model: generateContentRequestModelResourceName,
280302
contents: requestContent,
281303
generationConfig: generationConfig,
282304
safetySettings: safetySettings,
@@ -287,7 +309,9 @@ public final class GenerativeModel: Sendable {
287309
apiMethod: .countTokens,
288310
options: requestOptions
289311
)
290-
let countTokensRequest = CountTokensRequest(generateContentRequest: generateContentRequest)
312+
let countTokensRequest = CountTokensRequest(
313+
modelResourceName: modelResourceName, generateContentRequest: generateContentRequest
314+
)
291315

292316
return try await generativeAIService.loadRequest(request: countTokensRequest)
293317
}

FirebaseVertexAI/Sources/Types/Internal/Requests/CountTokensRequest.swift

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ import Foundation
1616

1717
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
1818
struct CountTokensRequest {
19+
let modelResourceName: String
20+
1921
let generateContentRequest: GenerateContentRequest
2022
}
2123

@@ -30,7 +32,7 @@ extension CountTokensRequest: GenerativeAIRequest {
3032
var url: URL {
3133
let version = apiConfig.version.rawValue
3234
let endpoint = apiConfig.service.endpoint.rawValue
33-
return URL(string: "\(endpoint)/\(version)/\(generateContentRequest.model):countTokens")!
35+
return URL(string: "\(endpoint)/\(version)/\(modelResourceName):countTokens")!
3436
}
3537
}
3638

FirebaseVertexAI/Sources/VertexAI.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ public class VertexAI {
8080
}
8181

8282
return GenerativeModel(
83+
modelName: modelName,
8384
modelResourceName: modelResourceName(modelName: modelName),
8485
firebaseInfo: firebaseInfo,
8586
apiConfig: apiConfig,
@@ -240,13 +241,11 @@ public class VertexAI {
240241

241242
private func developerModelResourceName(modelName: String) -> String {
242243
switch apiConfig.service.endpoint {
243-
case .firebaseVertexAIStaging:
244+
case .firebaseVertexAIStaging, .firebaseVertexAIProd:
244245
let projectID = firebaseInfo.projectID
245246
return "projects/\(projectID)/models/\(modelName)"
246247
case .generativeLanguage:
247248
return "models/\(modelName)"
248-
default:
249-
fatalError("The Developer API is not supported on '\(apiConfig.service.endpoint)'.")
250249
}
251250
}
252251

FirebaseVertexAI/Tests/TestApp/Tests/Integration/CountTokensIntegrationTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ struct CountTokensIntegrationTests {
102102

103103
@Test(arguments: [
104104
/* System instructions are not supported on the v1 Developer API. */
105-
InstanceConfig.developerV1,
105+
InstanceConfig.developerV1Spark,
106106
])
107107
func countTokens_text_systemInstruction_unsupported(_ config: InstanceConfig) async throws {
108108
let model = VertexAI.componentInstance(config).generativeModel(

FirebaseVertexAI/Tests/TestApp/Tests/Utilities/InstanceConfig.swift

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,14 @@ struct InstanceConfig {
3232
static let vertexV1BetaStaging = InstanceConfig(
3333
apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseVertexAIStaging), version: .v1beta)
3434
)
35-
static let developerV1 = InstanceConfig(
35+
static let developerV1Beta = InstanceConfig(
36+
apiConfig: APIConfig(service: .developer(endpoint: .firebaseVertexAIProd), version: .v1beta)
37+
)
38+
static let developerV1Spark = InstanceConfig(
3639
appName: FirebaseAppNames.spark,
3740
apiConfig: APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1)
3841
)
39-
static let developerV1Beta = InstanceConfig(
42+
static let developerV1BetaSpark = InstanceConfig(
4043
appName: FirebaseAppNames.spark,
4144
apiConfig: APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1beta)
4245
)
@@ -45,8 +48,9 @@ struct InstanceConfig {
4548
vertexV1Staging,
4649
vertexV1Beta,
4750
vertexV1BetaStaging,
48-
developerV1,
4951
developerV1Beta,
52+
developerV1Spark,
53+
developerV1BetaSpark,
5054
]
5155

5256
static let vertexV1AppCheckNotConfigured = InstanceConfig(

FirebaseVertexAI/Tests/Unit/ChatTests.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ import FirebaseCore
2020

2121
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
2222
final class ChatTests: XCTestCase {
23+
let modelName = "test-model-name"
24+
let modelResourceName = "projects/my-project/locations/us-central1/models/test-model-name"
25+
2326
var urlSession: URLSession!
2427

2528
override func setUp() {
@@ -59,7 +62,8 @@ final class ChatTests: XCTestCase {
5962
options: FirebaseOptions(googleAppID: "ignore",
6063
gcmSenderID: "ignore"))
6164
let model = GenerativeModel(
62-
modelResourceName: "my-model",
65+
modelName: modelName,
66+
modelResourceName: modelResourceName,
6367
firebaseInfo: FirebaseInfo(
6468
projectID: "my-project-id",
6569
apiKey: "API_KEY",

FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ final class GenerativeModelTests: XCTestCase {
5656
blocked: false
5757
),
5858
].sorted()
59+
let testModelName = "test-model"
5960
let testModelResourceName =
6061
"projects/test-project-id/locations/test-location/publishers/google/models/test-model"
6162
let apiConfig = VertexAI.defaultVertexAIAPIConfig
@@ -70,6 +71,7 @@ final class GenerativeModelTests: XCTestCase {
7071
configuration.protocolClasses = [MockURLProtocol.self]
7172
urlSession = try XCTUnwrap(URLSession(configuration: configuration))
7273
model = GenerativeModel(
74+
modelName: testModelName,
7375
modelResourceName: testModelResourceName,
7476
firebaseInfo: testFirebaseInfo(),
7577
apiConfig: apiConfig,
@@ -275,8 +277,8 @@ final class GenerativeModelTests: XCTestCase {
275277
subdirectory: vertexSubdirectory
276278
)
277279
let model = GenerativeModel(
278-
// Model name is prefixed with "models/".
279-
modelResourceName: "models/test-model",
280+
modelName: testModelName,
281+
modelResourceName: testModelResourceName,
280282
firebaseInfo: testFirebaseInfo(),
281283
apiConfig: apiConfig,
282284
tools: nil,
@@ -399,6 +401,7 @@ final class GenerativeModelTests: XCTestCase {
399401
func testGenerateContent_appCheck_validToken() async throws {
400402
let appCheckToken = "test-valid-token"
401403
model = GenerativeModel(
404+
modelName: testModelName,
402405
modelResourceName: testModelResourceName,
403406
firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(token: appCheckToken)),
404407
apiConfig: apiConfig,
@@ -420,6 +423,7 @@ final class GenerativeModelTests: XCTestCase {
420423
func testGenerateContent_dataCollectionOff() async throws {
421424
let appCheckToken = "test-valid-token"
422425
model = GenerativeModel(
426+
modelName: testModelName,
423427
modelResourceName: testModelResourceName,
424428
firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(token: appCheckToken),
425429
privateAppID: true),
@@ -442,6 +446,7 @@ final class GenerativeModelTests: XCTestCase {
442446

443447
func testGenerateContent_appCheck_tokenRefreshError() async throws {
444448
model = GenerativeModel(
449+
modelName: testModelName,
445450
modelResourceName: testModelResourceName,
446451
firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(error: AppCheckErrorFake())),
447452
apiConfig: apiConfig,
@@ -463,6 +468,7 @@ final class GenerativeModelTests: XCTestCase {
463468
func testGenerateContent_auth_validAuthToken() async throws {
464469
let authToken = "test-valid-token"
465470
model = GenerativeModel(
471+
modelName: testModelName,
466472
modelResourceName: testModelResourceName,
467473
firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(token: authToken)),
468474
apiConfig: apiConfig,
@@ -483,6 +489,7 @@ final class GenerativeModelTests: XCTestCase {
483489

484490
func testGenerateContent_auth_nilAuthToken() async throws {
485491
model = GenerativeModel(
492+
modelName: testModelName,
486493
modelResourceName: testModelResourceName,
487494
firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(token: nil)),
488495
apiConfig: apiConfig,
@@ -503,7 +510,8 @@ final class GenerativeModelTests: XCTestCase {
503510

504511
func testGenerateContent_auth_authTokenRefreshError() async throws {
505512
model = GenerativeModel(
506-
modelResourceName: "my-model",
513+
modelName: testModelName,
514+
modelResourceName: testModelResourceName,
507515
firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(error: AuthErrorFake())),
508516
apiConfig: apiConfig,
509517
tools: nil,
@@ -900,6 +908,7 @@ final class GenerativeModelTests: XCTestCase {
900908
)
901909
let requestOptions = RequestOptions(timeout: expectedTimeout)
902910
model = GenerativeModel(
911+
modelName: testModelName,
903912
modelResourceName: testModelResourceName,
904913
firebaseInfo: testFirebaseInfo(),
905914
apiConfig: apiConfig,
@@ -1204,6 +1213,7 @@ final class GenerativeModelTests: XCTestCase {
12041213
func testGenerateContentStream_appCheck_validToken() async throws {
12051214
let appCheckToken = "test-valid-token"
12061215
model = GenerativeModel(
1216+
modelName: testModelName,
12071217
modelResourceName: testModelResourceName,
12081218
firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(token: appCheckToken)),
12091219
apiConfig: apiConfig,
@@ -1225,6 +1235,7 @@ final class GenerativeModelTests: XCTestCase {
12251235

12261236
func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
12271237
model = GenerativeModel(
1238+
modelName: testModelName,
12281239
modelResourceName: testModelResourceName,
12291240
firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(error: AppCheckErrorFake())),
12301241
apiConfig: apiConfig,
@@ -1375,6 +1386,7 @@ final class GenerativeModelTests: XCTestCase {
13751386
)
13761387
let requestOptions = RequestOptions(timeout: expectedTimeout)
13771388
model = GenerativeModel(
1389+
modelName: testModelName,
13781390
modelResourceName: testModelResourceName,
13791391
firebaseInfo: testFirebaseInfo(),
13801392
apiConfig: apiConfig,
@@ -1451,6 +1463,7 @@ final class GenerativeModelTests: XCTestCase {
14511463
parts: "You are a calculator. Use the provided tools."
14521464
)
14531465
model = GenerativeModel(
1466+
modelName: testModelName,
14541467
modelResourceName: testModelResourceName,
14551468
firebaseInfo: testFirebaseInfo(),
14561469
apiConfig: apiConfig,
@@ -1511,6 +1524,7 @@ final class GenerativeModelTests: XCTestCase {
15111524
)
15121525
let requestOptions = RequestOptions(timeout: expectedTimeout)
15131526
model = GenerativeModel(
1527+
modelName: testModelName,
15141528
modelResourceName: testModelResourceName,
15151529
firebaseInfo: testFirebaseInfo(),
15161530
apiConfig: apiConfig,

FirebaseVertexAI/Tests/Unit/Types/Internal/Requests/CountTokensRequestTests.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ final class CountTokensRequestTests: XCTestCase {
5252
apiMethod: .countTokens,
5353
options: requestOptions
5454
)
55-
let request = CountTokensRequest(generateContentRequest: generateContentRequest)
55+
let request = CountTokensRequest(
56+
modelResourceName: modelResourceName, generateContentRequest: generateContentRequest
57+
)
5658

5759
let jsonData = try encoder.encode(request)
5860

@@ -86,7 +88,9 @@ final class CountTokensRequestTests: XCTestCase {
8688
apiMethod: .countTokens,
8789
options: requestOptions
8890
)
89-
let request = CountTokensRequest(generateContentRequest: generateContentRequest)
91+
let request = CountTokensRequest(
92+
modelResourceName: modelResourceName, generateContentRequest: generateContentRequest
93+
)
9094

9195
let jsonData = try encoder.encode(request)
9296

0 commit comments

Comments
 (0)