Skip to content

Commit b57874d

Browse files
test(ai): count tokens unit tests
1 parent 0f8b8b2 commit b57874d

File tree

1 file changed

+82
-6
lines changed

1 file changed

+82
-6
lines changed

packages/ai/__tests__/count-tokens.test.ts

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
import { describe, expect, it, afterEach, jest } from '@jest/globals';
18-
import { getMockResponse } from './test-utils/mock-response';
17+
import { describe, expect, it, afterEach, jest, beforeEach } from '@jest/globals';
18+
import { BackendName, getMockResponse } from './test-utils/mock-response';
1919
import * as request from '../lib/requests/request';
2020
import { countTokens } from '../lib/methods/count-tokens';
21-
import { CountTokensRequest } from '../lib/types';
21+
import { CountTokensRequest, RequestOptions } from '../lib/types';
2222
import { ApiSettings } from '../lib/types/internal';
2323
import { Task } from '../lib/requests/request';
2424
import { GoogleAIBackend } from '../lib/backend';
25+
import { SpiedFunction } from 'jest-mock';
26+
import { mapCountTokensRequest } from '../lib/googleai-mappers';
2527

2628
const fakeApiSettings: ApiSettings = {
2729
apiKey: 'key',
@@ -31,6 +33,14 @@ const fakeApiSettings: ApiSettings = {
3133
backend: new GoogleAIBackend(),
3234
};
3335

36+
const fakeGoogleAIApiSettings: ApiSettings = {
37+
apiKey: 'key',
38+
project: 'my-project',
39+
appId: 'my-appid',
40+
location: '',
41+
backend: new GoogleAIBackend(),
42+
};
43+
3444
const fakeRequestParams: CountTokensRequest = {
3545
contents: [{ parts: [{ text: 'hello' }], role: 'user' }],
3646
};
@@ -41,7 +51,7 @@ describe('countTokens()', () => {
4151
});
4252

4353
it('total tokens', async () => {
44-
const mockResponse = getMockResponse('unary-success-total-tokens.json');
54+
const mockResponse = getMockResponse(BackendName.VertexAI, 'unary-success-total-tokens.json');
4555
const makeRequestStub = jest
4656
.spyOn(request, 'makeRequest')
4757
.mockResolvedValue(mockResponse as Response);
@@ -58,8 +68,35 @@ describe('countTokens()', () => {
5868
);
5969
});
6070

71+
it('total tokens with modality details', async () => {
72+
const mockResponse = getMockResponse(
73+
BackendName.VertexAI,
74+
'unary-success-detailed-token-response.json',
75+
);
76+
const makeRequestStub = jest
77+
.spyOn(request, 'makeRequest')
78+
.mockResolvedValue(mockResponse as Response);
79+
const result = await countTokens(fakeApiSettings, 'model', fakeRequestParams);
80+
81+
expect(result.totalTokens).toBe(1837);
82+
expect(result.totalBillableCharacters).toBe(117);
83+
expect(result.promptTokensDetails?.[0]?.modality).toBe('IMAGE');
84+
expect(result.promptTokensDetails?.[0]?.tokenCount).toBe(1806);
85+
expect(makeRequestStub).toHaveBeenCalledWith(
86+
'model',
87+
Task.COUNT_TOKENS,
88+
fakeApiSettings,
89+
false,
90+
expect.stringContaining('contents'),
91+
undefined,
92+
);
93+
});
94+
6195
it('total tokens no billable characters', async () => {
62-
const mockResponse = getMockResponse('unary-success-no-billable-characters.json');
96+
const mockResponse = getMockResponse(
97+
BackendName.VertexAI,
98+
'unary-success-no-billable-characters.json',
99+
);
63100
const makeRequestStub = jest
64101
.spyOn(request, 'makeRequest')
65102
.mockResolvedValue(mockResponse as Response);
@@ -77,7 +114,10 @@ describe('countTokens()', () => {
77114
});
78115

79116
it('model not found', async () => {
80-
const mockResponse = getMockResponse('unary-failure-model-not-found.json');
117+
const mockResponse = getMockResponse(
118+
BackendName.VertexAI,
119+
'unary-failure-model-not-found.json',
120+
);
81121
const mockFetch = jest.spyOn(globalThis, 'fetch').mockResolvedValue({
82122
ok: false,
83123
status: 404,
@@ -88,4 +128,40 @@ describe('countTokens()', () => {
88128
);
89129
expect(mockFetch).toHaveBeenCalled();
90130
});
131+
132+
describe('googleAI', () => {
133+
let makeRequestStub: SpiedFunction<
134+
(
135+
model: string,
136+
task: Task,
137+
apiSettings: ApiSettings,
138+
stream: boolean,
139+
body: string,
140+
requestOptions?: RequestOptions,
141+
) => Promise<Response>
142+
>;
143+
144+
beforeEach(() => {
145+
makeRequestStub = jest.spyOn(request, 'makeRequest');
146+
});
147+
148+
afterEach(() => {
149+
jest.restoreAllMocks();
150+
});
151+
152+
it('maps request to GoogleAI format', async () => {
153+
makeRequestStub.mockResolvedValue({ ok: true, json: () => {} } as Response); // Unused
154+
155+
await countTokens(fakeGoogleAIApiSettings, 'model', fakeRequestParams);
156+
157+
expect(makeRequestStub).toHaveBeenCalledWith(
158+
'model',
159+
Task.COUNT_TOKENS,
160+
fakeGoogleAIApiSettings,
161+
false,
162+
JSON.stringify(mapCountTokensRequest(fakeRequestParams, 'model')),
163+
undefined,
164+
);
165+
});
166+
});
91167
});

0 commit comments

Comments
 (0)