Skip to content

Commit cf9e6cd

Browse files
russellwheatleymikehardy
authored andcommitted
test(ai): update unit test for firebase ai
1 parent c0b4e9f commit cf9e6cd

11 files changed

+258
-114
lines changed

packages/ai/__tests__/api.test.ts

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,19 @@
1515
* limitations under the License.
1616
*/
1717
import { describe, expect, it } from '@jest/globals';
18-
import { getApp, type ReactNativeFirebase } from '../../app/lib';
18+
import { type ReactNativeFirebase } from '../../app/lib';
1919

20-
import { ModelParams, VertexAIErrorCode } from '../lib/types';
21-
import { VertexAIError } from '../lib/errors';
22-
import { getGenerativeModel, getVertexAI } from '../lib/index';
20+
import { ModelParams, AIErrorCode } from '../lib/types';
21+
import { AIError } from '../lib/errors';
22+
import { getGenerativeModel } from '../lib/index';
2323

24-
import { VertexAI } from '../lib/public-types';
24+
import { AI } from '../lib/public-types';
2525
import { GenerativeModel } from '../lib/models/generative-model';
2626

27-
import '../../auth/lib';
28-
import '../../app-check/lib';
29-
import { getAuth } from '../../auth/lib';
27+
import { AI_TYPE } from '../lib/constants';
28+
import { VertexAIBackend } from '../lib/backend';
3029

31-
const fakeVertexAI: VertexAI = {
30+
const fakeAI: AI = {
3231
app: {
3332
name: 'DEFAULT',
3433
options: {
@@ -37,66 +36,76 @@ const fakeVertexAI: VertexAI = {
3736
projectId: 'my-project',
3837
},
3938
} as ReactNativeFirebase.FirebaseApp,
39+
backend: new VertexAIBackend('us-central1'),
4040
location: 'us-central1',
4141
};
4242

4343
describe('Top level API', () => {
44-
it('should allow auth and app check instances to be passed in', () => {
45-
const app = getApp();
46-
const auth = getAuth();
47-
const appCheck = app.appCheck();
48-
49-
getVertexAI(app, { appCheck, auth });
50-
});
51-
5244
it('getGenerativeModel throws if no model is provided', () => {
5345
try {
54-
getGenerativeModel(fakeVertexAI, {} as ModelParams);
46+
getGenerativeModel(fakeAI, {} as ModelParams);
5547
} catch (e) {
56-
expect((e as VertexAIError).code).toContain(VertexAIErrorCode.NO_MODEL);
57-
expect((e as VertexAIError).message).toContain(
48+
expect((e as AIError).code).toContain(AIErrorCode.NO_MODEL);
49+
expect((e as AIError).message).toContain(
5850
`VertexAI: Must provide a model name. Example: ` +
59-
`getGenerativeModel({ model: 'my-model-name' }) (vertexAI/${VertexAIErrorCode.NO_MODEL})`,
51+
`getGenerativeModel({ model: 'my-model-name' }) (vertexAI/${AIErrorCode.NO_MODEL})`,
6052
);
6153
}
6254
});
6355

6456
it('getGenerativeModel throws if no apiKey is provided', () => {
6557
const fakeVertexNoApiKey = {
66-
...fakeVertexAI,
67-
app: { options: { projectId: 'my-project' } },
68-
} as VertexAI;
58+
...fakeAI,
59+
app: { options: { projectId: 'my-project', appId: 'my-appid' } },
60+
} as AI;
6961
try {
7062
getGenerativeModel(fakeVertexNoApiKey, { model: 'my-model' });
7163
} catch (e) {
72-
expect((e as VertexAIError).code).toContain(VertexAIErrorCode.NO_API_KEY);
73-
expect((e as VertexAIError).message).toBe(
64+
expect((e as AIError).code).toContain(AIErrorCode.NO_API_KEY);
65+
expect((e as AIError).message).toBe(
7466
`VertexAI: The "apiKey" field is empty in the local ` +
7567
`Firebase config. Firebase VertexAI requires this field to` +
76-
` contain a valid API key. (vertexAI/${VertexAIErrorCode.NO_API_KEY})`,
68+
` contain a valid API key. (vertexAI/${AIErrorCode.NO_API_KEY})`,
7769
);
7870
}
7971
});
8072

8173
it('getGenerativeModel throws if no projectId is provided', () => {
8274
const fakeVertexNoProject = {
83-
...fakeVertexAI,
75+
...fakeAI,
8476
app: { options: { apiKey: 'my-key' } },
85-
} as VertexAI;
77+
} as AI;
8678
try {
8779
getGenerativeModel(fakeVertexNoProject, { model: 'my-model' });
8880
} catch (e) {
89-
expect((e as VertexAIError).code).toContain(VertexAIErrorCode.NO_PROJECT_ID);
90-
expect((e as VertexAIError).message).toBe(
81+
expect((e as AIError).code).toContain(AIErrorCode.NO_PROJECT_ID);
82+
expect((e as AIError).message).toBe(
9183
`VertexAI: The "projectId" field is empty in the local` +
9284
` Firebase config. Firebase VertexAI requires this field ` +
93-
`to contain a valid project ID. (vertexAI/${VertexAIErrorCode.NO_PROJECT_ID})`,
85+
`to contain a valid project ID. (vertexAI/${AIErrorCode.NO_PROJECT_ID})`,
86+
);
87+
}
88+
});
89+
90+
it('getGenerativeModel throws if no appId is provided', () => {
91+
const fakeVertexNoProject = {
92+
...fakeAI,
93+
app: { options: { apiKey: 'my-key', projectId: 'my-projectid' } },
94+
} as AI;
95+
try {
96+
getGenerativeModel(fakeVertexNoProject, { model: 'my-model' });
97+
} catch (e) {
98+
expect((e as AIError).code).toContain(AIErrorCode.NO_APP_ID);
99+
expect((e as AIError).message).toBe(
100+
`AI: The "appId" field is empty in the local` +
101+
` Firebase config. Firebase AI requires this field ` +
102+
`to contain a valid app ID. (${AI_TYPE}/${AIErrorCode.NO_APP_ID})`,
94103
);
95104
}
96105
});
97106

98107
it('getGenerativeModel gets a GenerativeModel', () => {
99-
const genModel = getGenerativeModel(fakeVertexAI, { model: 'my-model' });
108+
const genModel = getGenerativeModel(fakeAI, { model: 'my-model' });
100109
expect(genModel).toBeInstanceOf(GenerativeModel);
101110
expect(genModel.model).toBe('publishers/google/models/my-model');
102111
});

packages/ai/__tests__/backend.test.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
*/
1717
import { describe, it, expect } from '@jest/globals';
1818
import { GoogleAIBackend, VertexAIBackend } from '../lib/backend';
19-
import { BackendType } from 'lib/public-types';
20-
import { DEFAULT_LOCATION } from 'lib/constants';
19+
import { BackendType } from '../lib/public-types';
20+
import { DEFAULT_LOCATION } from '../lib/constants';
21+
2122
describe('Backend', () => {
2223
describe('GoogleAIBackend', () => {
2324
it('sets backendType to GOOGLE_AI', () => {

packages/ai/__tests__/chat-session.test.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@ import { GenerateContentStreamResult } from '../lib/types';
2121
import { ChatSession } from '../lib/methods/chat-session';
2222
import { ApiSettings } from '../lib/types/internal';
2323
import { RequestOptions } from '../lib/types/requests';
24+
import { VertexAIBackend } from '../lib/backend';
2425

2526
const fakeApiSettings: ApiSettings = {
2627
apiKey: 'key',
2728
project: 'my-project',
29+
appId: 'my-appid',
2830
location: 'us-central1',
31+
backend: new VertexAIBackend(),
2932
};
3033

3134
const requestOptions: RequestOptions = {

packages/ai/__tests__/count-tokens.test.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@ import { countTokens } from '../lib/methods/count-tokens';
2121
import { CountTokensRequest } from '../lib/types';
2222
import { ApiSettings } from '../lib/types/internal';
2323
import { Task } from '../lib/requests/request';
24+
import { GoogleAIBackend } from '../lib/backend';
2425

2526
const fakeApiSettings: ApiSettings = {
2627
apiKey: 'key',
2728
project: 'my-project',
2829
location: 'us-central1',
30+
appId: '',
31+
backend: new GoogleAIBackend(),
2932
};
3033

3134
const fakeRequestParams: CountTokensRequest = {

packages/ai/__tests__/generate-content.test.ts

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
import { describe, expect, it, afterEach, jest } from '@jest/globals';
17+
import { describe, expect, it, afterEach, jest, beforeEach } from '@jest/globals';
1818
import { getMockResponse } from './test-utils/mock-response';
1919
import * as request from '../lib/requests/request';
2020
import { generateContent } from '../lib/methods/generate-content';
@@ -27,11 +27,25 @@ import {
2727
} from '../lib/types';
2828
import { ApiSettings } from '../lib/types/internal';
2929
import { Task } from '../lib/requests/request';
30+
import { GoogleAIBackend, VertexAIBackend } from '../lib/backend';
31+
import { SpiedFunction } from 'jest-mock';
32+
import { AIError } from '../lib/errors';
33+
import { mapGenerateContentRequest } from '../lib/googleai-mappers';
3034

3135
const fakeApiSettings: ApiSettings = {
3236
apiKey: 'key',
3337
project: 'my-project',
38+
appId: 'my-appid',
3439
location: 'us-central1',
40+
backend: new VertexAIBackend(),
41+
};
42+
43+
const fakeGoogleAIApiSettings: ApiSettings = {
44+
apiKey: 'key',
45+
project: 'my-project',
46+
appId: 'my-appid',
47+
location: 'us-central1',
48+
backend: new GoogleAIBackend(),
3549
};
3650

3751
const fakeRequestParams: GenerateContentRequest = {
@@ -48,6 +62,19 @@ const fakeRequestParams: GenerateContentRequest = {
4862
],
4963
};
5064

65+
const fakeGoogleAIRequestParams: GenerateContentRequest = {
66+
contents: [{ parts: [{ text: 'hello' }], role: 'user' }],
67+
generationConfig: {
68+
topK: 16,
69+
},
70+
safetySettings: [
71+
{
72+
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
73+
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
74+
},
75+
],
76+
};
77+
5178
describe('generateContent()', () => {
5279
afterEach(() => {
5380
jest.restoreAllMocks();
@@ -88,6 +115,28 @@ describe('generateContent()', () => {
88115
);
89116
});
90117

118+
it('long response with token details', async () => {
119+
const mockResponse = getMockResponse('unary-success-basic-response-long-usage-metadata.json');
120+
const makeRequestStub = jest
121+
.spyOn(request, 'makeRequest')
122+
.mockResolvedValue(mockResponse as Response);
123+
const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams);
124+
expect(result.response.usageMetadata?.totalTokenCount).toEqual(1913);
125+
expect(result.response.usageMetadata?.candidatesTokenCount).toEqual(76);
126+
expect(result.response.usageMetadata?.promptTokensDetails?.[0]?.modality).toEqual('IMAGE');
127+
expect(result.response.usageMetadata?.promptTokensDetails?.[0]?.tokenCount).toEqual(1806);
128+
expect(result.response.usageMetadata?.candidatesTokensDetails?.[0]?.modality).toEqual('TEXT');
129+
expect(result.response.usageMetadata?.candidatesTokensDetails?.[0]?.tokenCount).toEqual(76);
130+
expect(makeRequestStub).toHaveBeenCalledWith(
131+
'model',
132+
Task.GENERATE_CONTENT,
133+
fakeApiSettings,
134+
false,
135+
expect.anything(),
136+
undefined,
137+
);
138+
});
139+
91140
it('citations', async () => {
92141
const mockResponse = getMockResponse('unary-success-citations.json');
93142
const makeRequestStub = jest
@@ -201,4 +250,54 @@ describe('generateContent()', () => {
201250
);
202251
expect(mockFetch).toHaveBeenCalled();
203252
});
253+
254+
describe('googleAI', () => {
255+
let makeRequestStub: SpiedFunction<typeof request.makeRequest>;
256+
257+
beforeEach(() => {
258+
makeRequestStub = jest.spyOn(request, 'makeRequest');
259+
});
260+
261+
afterEach(() => {
262+
jest.restoreAllMocks();
263+
});
264+
265+
it('throws error when method is defined', async () => {
266+
const mockResponse = getMockResponse('unary-success-basic-reply-short.txt');
267+
makeRequestStub.mockResolvedValue(mockResponse as Response);
268+
269+
const requestParamsWithMethod: GenerateContentRequest = {
270+
contents: [{ parts: [{ text: 'hello' }], role: 'user' }],
271+
safetySettings: [
272+
{
273+
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
274+
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
275+
method: HarmBlockMethod.SEVERITY, // Unsupported in Google AI.
276+
},
277+
],
278+
};
279+
280+
// Expect generateContent to throw a AIError that method is not supported.
281+
await expect(
282+
generateContent(fakeGoogleAIApiSettings, 'model', requestParamsWithMethod),
283+
).rejects.toThrow(AIError);
284+
expect(makeRequestStub).not.toHaveBeenCalled();
285+
});
286+
287+
it('maps request to GoogleAI format', async () => {
288+
const mockResponse = getMockResponse('unary-success-basic-reply-short.txt');
289+
makeRequestStub.mockResolvedValue(mockResponse as Response);
290+
291+
await generateContent(fakeGoogleAIApiSettings, 'model', fakeGoogleAIRequestParams);
292+
293+
expect(makeRequestStub).toHaveBeenCalledWith(
294+
'model',
295+
Task.GENERATE_CONTENT,
296+
fakeGoogleAIApiSettings,
297+
false,
298+
JSON.stringify(mapGenerateContentRequest(fakeGoogleAIRequestParams)),
299+
undefined,
300+
);
301+
});
302+
});
204303
});

0 commit comments

Comments
 (0)