Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions FirebaseVertexAI/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
- [changed] The default request timeout is now 180 seconds instead of the
platform-default value of 60 seconds for a `URLRequest`; this timeout may
still be customized in `RequestOptions`. (#13722)
- [changed] The response from `GenerativeModel.countTokens(...)` now includes
`systemInstruction`, `tools` and `generationConfig` in the `totalTokens` and
`totalBillableCharacters` counts, where applicable. (#13813)

# 11.3.0
- [added] Added `Decodable` conformance for `FunctionResponse`. (#13606)
Expand Down
8 changes: 8 additions & 0 deletions FirebaseVertexAI/Sources/CountTokensRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ import Foundation
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
/// Request payload for the `countTokens` endpoint.
///
/// Mirrors the inputs of a generate-content call so that the token count
/// reflects everything that would be billed: the prompt `contents` plus the
/// optional `systemInstruction`, `tools`, and `generationConfig`.
struct CountTokensRequest {
// Fully-qualified model resource name (e.g. "projects/.../models/...");
// used for routing the request, not encoded in the JSON body.
let model: String

// The prompt content whose tokens are being counted.
let contents: [ModelContent]
// Optional system instruction included in the count, where applicable.
let systemInstruction: ModelContent?
// Optional tool (function) declarations included in the count.
let tools: [Tool]?
// Optional generation configuration included in the count.
let generationConfig: GenerationConfig?

// Transport-level options (e.g. timeout); not part of the encoded body.
let options: RequestOptions
}

Expand Down Expand Up @@ -49,6 +54,9 @@ public struct CountTokensResponse {
// MARK: - Encodable

extension CountTokensRequest: Encodable {
  /// Keys serialized into the request body. `model` and `options` are
  /// intentionally omitted: they drive routing and transport, not payload.
  enum CodingKeys: CodingKey {
    case contents, systemInstruction, tools, generationConfig
  }
}

Expand Down
3 changes: 3 additions & 0 deletions FirebaseVertexAI/Sources/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,9 @@ public final class GenerativeModel {
let countTokensRequest = try CountTokensRequest(
model: modelResourceName,
contents: content(),
systemInstruction: systemInstruction,
tools: tools,
generationConfig: generationConfig,
options: requestOptions
)
return try await generativeAIService.loadRequest(request: countTokensRequest)
Expand Down
19 changes: 15 additions & 4 deletions FirebaseVertexAI/Tests/Integration/IntegrationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,16 @@ import XCTest
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class IntegrationTests: XCTestCase {
// Set temperature, topP and topK to lowest allowed values to make responses more deterministic.
let generationConfig = GenerationConfig(temperature: 0.0, topP: 0.0, topK: 1)
let generationConfig = GenerationConfig(
temperature: 0.0,
topP: 0.0,
topK: 1,
responseMIMEType: "text/plain"
)
let systemInstruction = ModelContent(
role: "system",
parts: "You are a friendly and helpful assistant."
)

var vertex: VertexAI!
var model: GenerativeModel!
Expand All @@ -40,7 +49,9 @@ final class IntegrationTests: XCTestCase {
vertex = VertexAI.vertexAI()
model = vertex.generativeModel(
modelName: "gemini-1.5-flash",
generationConfig: generationConfig
generationConfig: generationConfig,
tools: [],
systemInstruction: systemInstruction
)
}

Expand Down Expand Up @@ -68,7 +79,7 @@ final class IntegrationTests: XCTestCase {

let response = try await model.countTokens(prompt)

XCTAssertEqual(response.totalTokens, 6)
XCTAssertEqual(response.totalBillableCharacters, 16)
XCTAssertEqual(response.totalTokens, 14)
XCTAssertEqual(response.totalBillableCharacters, 51)
}
}
42 changes: 42 additions & 0 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,48 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(response.totalBillableCharacters, 16)
}

/// Verifies `countTokens` succeeds when every optional model input —
/// generation config, tools, and a system instruction — is supplied.
func testCountTokens_succeeds_allOptions() async throws {
  MockURLProtocol.requestHandler = try httpRequestHandler(
    forResource: "unary-success-total-tokens",
    withExtension: "json"
  )

  // Arrange: a fully-configured model exercising each optional field that
  // the count-tokens request serializes.
  let addFunction = FunctionDeclaration(
    name: "sum",
    description: "Add two integers.",
    parameters: ["x": .integer(), "y": .integer()]
  )
  let calculatorInstruction = ModelContent(
    role: "system",
    parts: "You are a calculator. Use the provided tools."
  )
  let config = GenerationConfig(
    temperature: 0.5,
    topP: 0.9,
    topK: 3,
    candidateCount: 1,
    maxOutputTokens: 1024,
    stopSequences: ["test-stop"],
    responseMIMEType: "text/plain"
  )
  model = GenerativeModel(
    name: testModelResourceName,
    projectID: "my-project-id",
    apiKey: "API_KEY",
    generationConfig: config,
    tools: [Tool(functionDeclarations: [addFunction])],
    systemInstruction: calculatorInstruction,
    requestOptions: RequestOptions(),
    appCheck: nil,
    auth: nil,
    urlSession: urlSession
  )

  // Act, then assert the counts coming from the mocked response payload.
  let response = try await model.countTokens("Why is the sky blue?")

  XCTAssertEqual(response.totalTokens, 6)
  XCTAssertEqual(response.totalBillableCharacters, 16)
}

func testCountTokens_succeeds_noBillableCharacters() async throws {
MockURLProtocol.requestHandler = try httpRequestHandler(
forResource: "unary-success-no-billable-characters",
Expand Down
Loading