Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 85 additions & 2 deletions packages/vertexai/src/methods/chrome-adapter.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ describe('ChromeAdapter', () => {
Promise.resolve({
available: 'after-download'
}),
create: () => {}
create: () => { }
}
} as AI;
const downloadPromise = new Promise<AILanguageModel>(() => {
Expand All @@ -182,7 +182,7 @@ describe('ChromeAdapter', () => {
Promise.resolve({
available: 'after-download'
}),
create: () => {}
create: () => { }
}
} as AI;
let resolveDownload;
Expand Down Expand Up @@ -298,4 +298,87 @@ describe('ChromeAdapter', () => {
});
});
});
describe('countTokens', () => {
it('With no initial prompts', async () => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: describing this as a test for a singular input, eg it('counts tokens from a singular input', would be more self-explanatory from my perspective.

const aiProvider = {
languageModel: {
create: () => Promise.resolve({})
}
} as AI;
const inputText = "first";
const expectedCount = 10;
const model = {
countPromptTokens: _s => Promise.resolve(123),
} as AILanguageModel;
// overrides impl with stub method
const countPromptTokensStub = stub(model, 'countPromptTokens').resolves(expectedCount);
const factoryStub = stub(aiProvider.languageModel, 'create').resolves(model);
const adapter = new ChromeAdapter(
aiProvider,
'prefer_on_device',
);
const response = await adapter.countTokens({
contents: [
{ role: 'user', parts: [{ text: inputText }] }
]
});
expect(factoryStub).to.have.been.calledOnceWith({
// initialPrompts must be empty
initialPrompts: []
});
// validate count tokens gets called with the last entry from the input
expect(countPromptTokensStub).to.have.been.calledOnceWith({
role: 'user',
content: inputText
});
expect(await response.json()).to.deep.equal({
totalTokens: expectedCount
});
});
it('Extracts initial prompts and then does counts tokens', async () => {
const aiProvider = {
languageModel: {
create: () => Promise.resolve({})
}
} as AI;
const expectedCount = 10;
const model = {
countPromptTokens: _s => Promise.resolve(123),
} as AILanguageModel;
// overrides impl with stub method
const countPromptTokensStub = stub(model, 'countPromptTokens').resolves(expectedCount);
const factoryStub = stub(aiProvider.languageModel, 'create').resolves(model);
const text = ['first', 'second', 'third'];
const onDeviceParams = {
initialPrompts: [{ role: 'user', content: text[0] }]
} as AILanguageModelCreateOptionsWithSystemPrompt;
const adapter = new ChromeAdapter(
aiProvider,
'prefer_on_device',
onDeviceParams
);
const response = await adapter.countTokens({
contents: [
{ role: 'model', parts: [{ text: text[1] }] },
{ role: 'user', parts: [{ text: text[2] }] }
]
});
expect(factoryStub).to.have.been.calledOnceWith({
initialPrompts: [
{ role: 'user', content: text[0] },
// Asserts tail is passed as initial prompts, and
// role is normalized from model to assistant.
{ role: 'assistant', content: text[1] }
]
});
// validate count tokens gets called with the last entry from the input
expect(countPromptTokensStub).to.have.been.calledOnceWith({
role: 'user',
content: text[2]
});
expect(await response.json()).to.deep.equal({
totalTokens: expectedCount
});
});
});
});
15 changes: 15 additions & 0 deletions packages/vertexai/src/methods/chrome-adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import { isChrome } from '@firebase/util';
import {
Content,
CountTokensRequest,
GenerateContentRequest,
InferenceMode,
Role
Expand Down Expand Up @@ -102,6 +103,20 @@ export class ChromeAdapter {
const stream = await session.promptStreaming(prompt.content);
return ChromeAdapter.toStreamResponse(stream);
}
/**
 * Counts tokens for the request using the on-device model.
 *
 * The last entry of `request.contents` is the prompt whose tokens are
 * counted; any preceding entries are appended to the session's initial
 * prompts. Returns a Response-shaped object whose `json()` resolves to
 * `{ totalTokens }`, mirroring the cloud endpoint's surface.
 */
async countTokens(request: CountTokensRequest): Promise<Response> {
  // Build a per-call copy of the creation params. The original code pushed
  // into this.onDeviceParams.initialPrompts directly, so every call
  // permanently accumulated the extracted prompts on the shared object.
  const options = { ...(this.onDeviceParams || {}) };
  const initialPrompts = [...(options.initialPrompts ?? [])];
  const extractedInitialPrompts = ChromeAdapter.toInitialPrompts(
    request.contents
  );
  // Non-null assertion assumes contents is non-empty.
  // TODO(review): confirm callers never pass an empty contents array.
  const currentPrompt = extractedInitialPrompts.pop()!;
  initialPrompts.push(...extractedInitialPrompts);
  options.initialPrompts = initialPrompts;
  const session = await this.session(options);
  const tokenCount = await session.countPromptTokens(currentPrompt);
  return {
    json: async () => ({
      totalTokens: tokenCount
    })
  } as Response;
}
private static isOnDeviceRequest(request: GenerateContentRequest): boolean {
// Returns false if the prompt is empty.
if (request.contents.length === 0) {
Expand Down
12 changes: 8 additions & 4 deletions packages/vertexai/src/methods/count-tokens.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import { countTokens } from './count-tokens';
import { CountTokensRequest } from '../types';
import { ApiSettings } from '../types/internal';
import { Task } from '../requests/request';
import { ChromeAdapter } from './chrome-adapter';

use(sinonChai);
use(chaiAsPromised);
Expand Down Expand Up @@ -52,7 +53,8 @@ describe('countTokens()', () => {
const result = await countTokens(
fakeApiSettings,
'model',
fakeRequestParams
fakeRequestParams,
new ChromeAdapter()
);
expect(result.totalTokens).to.equal(6);
expect(result.totalBillableCharacters).to.equal(16);
Expand All @@ -77,7 +79,8 @@ describe('countTokens()', () => {
const result = await countTokens(
fakeApiSettings,
'model',
fakeRequestParams
fakeRequestParams,
new ChromeAdapter()
);
expect(result.totalTokens).to.equal(1837);
expect(result.totalBillableCharacters).to.equal(117);
Expand All @@ -104,7 +107,8 @@ describe('countTokens()', () => {
const result = await countTokens(
fakeApiSettings,
'model',
fakeRequestParams
fakeRequestParams,
new ChromeAdapter()
);
expect(result.totalTokens).to.equal(258);
expect(result).to.not.have.property('totalBillableCharacters');
Expand All @@ -127,7 +131,7 @@ describe('countTokens()', () => {
json: mockResponse.json
} as Response);
await expect(
countTokens(fakeApiSettings, 'model', fakeRequestParams)
countTokens(fakeApiSettings, 'model', fakeRequestParams, new ChromeAdapter())
).to.be.rejectedWith(/404.*not found/);
expect(mockFetch).to.be.called;
});
Expand Down
17 changes: 16 additions & 1 deletion packages/vertexai/src/methods/count-tokens.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ import {
} from '../types';
import { Task, makeRequest } from '../requests/request';
import { ApiSettings } from '../types/internal';
import { ChromeAdapter } from './chrome-adapter';

export async function countTokens(
export async function countTokensOnCloud(
apiSettings: ApiSettings,
model: string,
params: CountTokensRequest,
Expand All @@ -39,3 +40,17 @@ export async function countTokens(
);
return response.json();
}

export async function countTokens(
apiSettings: ApiSettings,
model: string,
params: CountTokensRequest,
chromeAdapter: ChromeAdapter,
requestOptions?: RequestOptions
): Promise<CountTokensResponse> {
if (await chromeAdapter.isAvailable(params)) {
return (await chromeAdapter.countTokens(params)).json();
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: unnecessary else, given the return statement above

return countTokensOnCloud(apiSettings, model, params, requestOptions);
}
}
2 changes: 1 addition & 1 deletion packages/vertexai/src/models/generative-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,6 @@ export class GenerativeModel extends VertexAIModel {
request: CountTokensRequest | string | Array<string | Part>
): Promise<CountTokensResponse> {
const formattedParams = formatGenerateContentInput(request);
return countTokens(this._apiSettings, this.model, formattedParams);
return countTokens(this._apiSettings, this.model, formattedParams, this.chromeAdapter);
}
}