Skip to content

Commit bd3fd78

Browse files
committed
Adding hybrid count token implementation - First Pass
1 parent e206713 commit bd3fd78

File tree

4 files changed

+38
-6
lines changed

4 files changed

+38
-6
lines changed

packages/vertexai/src/methods/chrome-adapter.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import { isChrome } from '@firebase/util';
1919
import {
2020
Content,
21+
CountTokensRequest,
2122
GenerateContentRequest,
2223
InferenceMode,
2324
Role
@@ -102,6 +103,18 @@ export class ChromeAdapter {
102103
const stream = await session.promptStreaming(prompt.content);
103104
return ChromeAdapter.toStreamResponse(stream);
104105
}
106+
async countTokens(request: CountTokensRequest): Promise<Response> {
107+
const options = this.onDeviceParams || {};
108+
const prompts = ChromeAdapter.toInitialPrompts(request.contents);
109+
const session = await this.session(options);
110+
const tokenCount = await session.countPromptTokens(prompts);
111+
return {
112+
json: async () => ({
113+
totalTokens: tokenCount,
114+
totalBillableCharacters: 0,
115+
})
116+
} as Response;
117+
}
105118
private static isOnDeviceRequest(request: GenerateContentRequest): boolean {
106119
// Returns false if the prompt is empty.
107120
if (request.contents.length === 0) {

packages/vertexai/src/methods/count-tokens.test.ts

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import { countTokens } from './count-tokens';
2525
import { CountTokensRequest } from '../types';
2626
import { ApiSettings } from '../types/internal';
2727
import { Task } from '../requests/request';
28+
import { ChromeAdapter } from './chrome-adapter';
2829

2930
use(sinonChai);
3031
use(chaiAsPromised);
@@ -52,7 +53,8 @@ describe('countTokens()', () => {
5253
const result = await countTokens(
5354
fakeApiSettings,
5455
'model',
55-
fakeRequestParams
56+
fakeRequestParams,
57+
new ChromeAdapter()
5658
);
5759
expect(result.totalTokens).to.equal(6);
5860
expect(result.totalBillableCharacters).to.equal(16);
@@ -77,7 +79,8 @@ describe('countTokens()', () => {
7779
const result = await countTokens(
7880
fakeApiSettings,
7981
'model',
80-
fakeRequestParams
82+
fakeRequestParams,
83+
new ChromeAdapter()
8184
);
8285
expect(result.totalTokens).to.equal(1837);
8386
expect(result.totalBillableCharacters).to.equal(117);
@@ -104,7 +107,8 @@ describe('countTokens()', () => {
104107
const result = await countTokens(
105108
fakeApiSettings,
106109
'model',
107-
fakeRequestParams
110+
fakeRequestParams,
111+
new ChromeAdapter()
108112
);
109113
expect(result.totalTokens).to.equal(258);
110114
expect(result).to.not.have.property('totalBillableCharacters');
@@ -127,7 +131,7 @@ describe('countTokens()', () => {
127131
json: mockResponse.json
128132
} as Response);
129133
await expect(
130-
countTokens(fakeApiSettings, 'model', fakeRequestParams)
134+
countTokens(fakeApiSettings, 'model', fakeRequestParams, new ChromeAdapter())
131135
).to.be.rejectedWith(/404.*not found/);
132136
expect(mockFetch).to.be.called;
133137
});

packages/vertexai/src/methods/count-tokens.ts

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ import {
2222
} from '../types';
2323
import { Task, makeRequest } from '../requests/request';
2424
import { ApiSettings } from '../types/internal';
25+
import { ChromeAdapter } from './chrome-adapter';
2526

26-
export async function countTokens(
27+
export async function countTokensOnCloud(
2728
apiSettings: ApiSettings,
2829
model: string,
2930
params: CountTokensRequest,
@@ -39,3 +40,17 @@ export async function countTokens(
3940
);
4041
return response.json();
4142
}
43+
44+
export async function countTokens(
45+
apiSettings: ApiSettings,
46+
model: string,
47+
params: CountTokensRequest,
48+
chromeAdapter: ChromeAdapter,
49+
requestOptions?: RequestOptions
50+
): Promise<CountTokensResponse> {
51+
if (await chromeAdapter.isAvailable(params)) {
52+
return (await chromeAdapter.countTokens(params)).json();
53+
} else {
54+
return countTokensOnCloud(apiSettings, model, params, requestOptions);
55+
}
56+
}

packages/vertexai/src/models/generative-model.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,6 @@ export class GenerativeModel extends VertexAIModel {
154154
request: CountTokensRequest | string | Array<string | Part>
155155
): Promise<CountTokensResponse> {
156156
const formattedParams = formatGenerateContentInput(request);
157-
return countTokens(this._apiSettings, this.model, formattedParams);
157+
return countTokens(this._apiSettings, this.model, formattedParams, this.chromeAdapter);
158158
}
159159
}

0 commit comments

Comments
 (0)