diff --git a/FirebaseVertexAI/CHANGELOG.md b/FirebaseVertexAI/CHANGELOG.md index 25f3f3181e0..9a7d5c70e7c 100644 --- a/FirebaseVertexAI/CHANGELOG.md +++ b/FirebaseVertexAI/CHANGELOG.md @@ -31,6 +31,9 @@ - [changed] The default request timeout is now 180 seconds instead of the platform-default value of 60 seconds for a `URLRequest`; this timeout may still be customized in `RequestOptions`. (#13722) +- [changed] The response from `GenerativeModel.countTokens(...)` now includes + `systemInstruction`, `tools` and `generationConfig` in the `totalTokens` and + `totalBillableCharacters` counts, where applicable. (#13813) # 11.3.0 - [added] Added `Decodable` conformance for `FunctionResponse`. (#13606) diff --git a/FirebaseVertexAI/Sources/CountTokensRequest.swift b/FirebaseVertexAI/Sources/CountTokensRequest.swift index 6b052da19e6..128cb3b8ce6 100644 --- a/FirebaseVertexAI/Sources/CountTokensRequest.swift +++ b/FirebaseVertexAI/Sources/CountTokensRequest.swift @@ -17,7 +17,12 @@ import Foundation @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) struct CountTokensRequest { let model: String + let contents: [ModelContent] + let systemInstruction: ModelContent? + let tools: [Tool]? + let generationConfig: GenerationConfig? + let options: RequestOptions } @@ -49,6 +54,9 @@ public struct CountTokensResponse { extension CountTokensRequest: Encodable { enum CodingKeys: CodingKey { case contents + case systemInstruction + case tools + case generationConfig } } diff --git a/FirebaseVertexAI/Sources/GenerativeModel.swift b/FirebaseVertexAI/Sources/GenerativeModel.swift index a5a8933e435..dc069d88d03 100644 --- a/FirebaseVertexAI/Sources/GenerativeModel.swift +++ b/FirebaseVertexAI/Sources/GenerativeModel.swift @@ -272,6 +272,9 @@ public final class GenerativeModel { let countTokensRequest = try CountTokensRequest( model: modelResourceName, contents: content(), + systemInstruction: systemInstruction, + tools: tools, + generationConfig: generationConfig, options: requestOptions ) return try await generativeAIService.loadRequest(request: countTokensRequest) diff --git a/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift b/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift index 44eac05be22..0ccbb98e83a 100644 --- a/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift +++ b/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift @@ -19,7 +19,16 @@ import XCTest @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) final class IntegrationTests: XCTestCase { // Set temperature, topP and topK to lowest allowed values to make responses more deterministic. - let generationConfig = GenerationConfig(temperature: 0.0, topP: 0.0, topK: 1) + let generationConfig = GenerationConfig( + temperature: 0.0, + topP: 0.0, + topK: 1, + responseMIMEType: "text/plain" + ) + let systemInstruction = ModelContent( + role: "system", + parts: "You are a friendly and helpful assistant." + ) var vertex: VertexAI! var model: GenerativeModel! @@ -40,7 +49,9 @@ final class IntegrationTests: XCTestCase { vertex = VertexAI.vertexAI() model = vertex.generativeModel( modelName: "gemini-1.5-flash", - generationConfig: generationConfig + generationConfig: generationConfig, + tools: [], + systemInstruction: systemInstruction ) } @@ -68,7 +79,7 @@ final class IntegrationTests: XCTestCase { let response = try await model.countTokens(prompt) - XCTAssertEqual(response.totalTokens, 6) - XCTAssertEqual(response.totalBillableCharacters, 16) + XCTAssertEqual(response.totalTokens, 14) + XCTAssertEqual(response.totalBillableCharacters, 51) } } diff --git a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift index 6956160b072..c5e8332d2b8 100644 --- a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift +++ b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift @@ -1234,6 +1234,48 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(response.totalBillableCharacters, 16) } + func testCountTokens_succeeds_allOptions() async throws { + MockURLProtocol.requestHandler = try httpRequestHandler( + forResource: "unary-success-total-tokens", + withExtension: "json" + ) + let generationConfig = GenerationConfig( + temperature: 0.5, + topP: 0.9, + topK: 3, + candidateCount: 1, + maxOutputTokens: 1024, + stopSequences: ["test-stop"], + responseMIMEType: "text/plain" + ) + let sumFunction = FunctionDeclaration( + name: "sum", + description: "Add two integers.", + parameters: ["x": .integer(), "y": .integer()] + ) + let systemInstruction = ModelContent( + role: "system", + parts: "You are a calculator. Use the provided tools." + ) + model = GenerativeModel( + name: testModelResourceName, + projectID: "my-project-id", + apiKey: "API_KEY", + generationConfig: generationConfig, + tools: [Tool(functionDeclarations: [sumFunction])], + systemInstruction: systemInstruction, + requestOptions: RequestOptions(), + appCheck: nil, + auth: nil, + urlSession: urlSession + ) + + let response = try await model.countTokens("Why is the sky blue?") + + XCTAssertEqual(response.totalTokens, 6) + XCTAssertEqual(response.totalBillableCharacters, 16) + } + func testCountTokens_succeeds_noBillableCharacters() async throws { MockURLProtocol.requestHandler = try httpRequestHandler( forResource: "unary-success-no-billable-characters",