Skip to content

Commit 5fa5302

Browse files
russellwheatleymikehardy
authored andcommitted
test: googleai-mapper
1 parent eff7735 commit 5fa5302

File tree

1 file changed

+358
-0
lines changed

1 file changed

+358
-0
lines changed
Lines changed: 358 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,358 @@
1+
/**
2+
* @license
3+
* Copyright 2024 Google LLC
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
import { describe, it, expect, beforeEach, afterEach, jest } from '@jest/globals';
18+
import { AIError } from 'lib';
19+
import {
20+
mapCountTokensRequest,
21+
mapGenerateContentCandidates,
22+
mapGenerateContentRequest,
23+
mapGenerateContentResponse,
24+
mapPromptFeedback,
25+
} from 'lib/googleai-mappers';
26+
import {
27+
AIErrorCode,
28+
BlockReason,
29+
CountTokensRequest,
30+
Content,
31+
FinishReason,
32+
GenerateContentRequest,
33+
GoogleAICountTokensRequest,
34+
GoogleAIGenerateContentCandidate,
35+
GoogleAIGenerateContentResponse,
36+
HarmBlockMethod,
37+
HarmBlockThreshold,
38+
HarmCategory,
39+
HarmProbability,
40+
HarmSeverity,
41+
PromptFeedback,
42+
SafetyRating,
43+
} from 'lib/public-types';
44+
import { getMockResponse } from './test-utils/mock-response';
45+
import { SpiedFunction } from 'jest-mock';
46+
47+
const fakeModel = 'models/gemini-pro';
48+
49+
const fakeContents: Content[] = [{ role: 'user', parts: [{ text: 'hello' }] }];
50+
51+
describe('Google AI Mappers', () => {
52+
let loggerWarnSpy: SpiedFunction<{
53+
(message?: any, ...optionalParams: any[]): void;
54+
(message?: any, ...optionalParams: any[]): void;
55+
}>;
56+
57+
beforeEach(() => {
58+
loggerWarnSpy = jest.spyOn(console, 'warn').mockImplementation(() => {});
59+
});
60+
61+
afterEach(() => {
62+
jest.restoreAllMocks();
63+
});
64+
65+
describe('mapGenerateContentRequest', () => {
66+
it('should throw if safetySettings contain method', () => {
67+
const request: GenerateContentRequest = {
68+
contents: fakeContents,
69+
safetySettings: [
70+
{
71+
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
72+
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
73+
method: HarmBlockMethod.SEVERITY,
74+
},
75+
],
76+
};
77+
const error = new AIError(
78+
AIErrorCode.UNSUPPORTED,
79+
'SafetySettings.method is not supported in requests to the Gemini Developer API',
80+
);
81+
expect(() => mapGenerateContentRequest(request)).toThrowError(error);
82+
});
83+
84+
it('should warn and round topK if present', () => {
85+
const request: GenerateContentRequest = {
86+
contents: fakeContents,
87+
generationConfig: {
88+
topK: 15.7,
89+
},
90+
};
91+
const mappedRequest = mapGenerateContentRequest(request);
92+
expect(loggerWarnSpy).toHaveBeenCalledWith(
93+
'topK in GenerationConfig has been rounded to the nearest integer to match the format for requests to the Gemini Developer API.',
94+
);
95+
expect(mappedRequest.generationConfig?.topK).toBe(16);
96+
});
97+
98+
it('should not modify topK if it is already an integer', () => {
99+
const request: GenerateContentRequest = {
100+
contents: fakeContents,
101+
generationConfig: {
102+
topK: 16,
103+
},
104+
};
105+
const mappedRequest = mapGenerateContentRequest(request);
106+
expect(loggerWarnSpy).not.toHaveBeenCalled();
107+
expect(mappedRequest.generationConfig?.topK).toBe(16);
108+
});
109+
110+
it('should return the request mostly unchanged if valid', () => {
111+
const request: GenerateContentRequest = {
112+
contents: fakeContents,
113+
safetySettings: [
114+
{
115+
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
116+
threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
117+
},
118+
],
119+
generationConfig: {
120+
temperature: 0.5,
121+
},
122+
};
123+
const mappedRequest = mapGenerateContentRequest({ ...request });
124+
expect(mappedRequest).toEqual(request);
125+
expect(loggerWarnSpy).not.toHaveBeenCalled();
126+
});
127+
});
128+
129+
describe('mapGenerateContentResponse', () => {
130+
it('should map a full Google AI response', async () => {
131+
const googleAIMockResponse: GoogleAIGenerateContentResponse = await (
132+
getMockResponse('unary-success-citations.json') as Response
133+
).json();
134+
const mappedResponse = mapGenerateContentResponse(googleAIMockResponse);
135+
136+
expect(mappedResponse.candidates).toBeDefined();
137+
expect(mappedResponse.candidates?.[0]?.content.parts[0]?.text).toContain('quantum mechanics');
138+
139+
// Mapped citations
140+
expect(mappedResponse.candidates?.[0]?.citationMetadata?.citations[0]?.startIndex).toBe(
141+
googleAIMockResponse.candidates?.[0]?.citationMetadata?.citationSources[0]?.startIndex,
142+
);
143+
expect(mappedResponse.candidates?.[0]?.citationMetadata?.citations[0]?.endIndex).toBe(
144+
googleAIMockResponse.candidates?.[0]?.citationMetadata?.citationSources[0]?.endIndex,
145+
);
146+
147+
// Mapped safety ratings
148+
expect(mappedResponse.candidates?.[0]?.safetyRatings?.[0]?.probabilityScore).toBe(0);
149+
expect(mappedResponse.candidates?.[0]?.safetyRatings?.[0]?.severityScore).toBe(0);
150+
expect(mappedResponse.candidates?.[0]?.safetyRatings?.[0]?.severity).toBe(
151+
HarmSeverity.HARM_SEVERITY_UNSUPPORTED,
152+
);
153+
154+
expect(mappedResponse.candidates?.[0]?.finishReason).toBe(FinishReason.STOP);
155+
156+
// Check usage metadata passthrough
157+
expect(mappedResponse.usageMetadata).toEqual(googleAIMockResponse.usageMetadata);
158+
});
159+
160+
it('should handle missing candidates and promptFeedback', () => {
161+
const googleAIResponse: GoogleAIGenerateContentResponse = {
162+
// No candidates
163+
// No promptFeedback
164+
usageMetadata: {
165+
promptTokenCount: 5,
166+
candidatesTokenCount: 0,
167+
totalTokenCount: 5,
168+
},
169+
};
170+
const mappedResponse = mapGenerateContentResponse(googleAIResponse);
171+
expect(mappedResponse.candidates).toBeUndefined();
172+
expect(mappedResponse.promptFeedback).toBeUndefined(); // Mapped to undefined
173+
expect(mappedResponse.usageMetadata).toEqual(googleAIResponse.usageMetadata);
174+
});
175+
176+
it('should handle empty candidates array', () => {
177+
const googleAIResponse: GoogleAIGenerateContentResponse = {
178+
candidates: [],
179+
usageMetadata: {
180+
promptTokenCount: 5,
181+
candidatesTokenCount: 0,
182+
totalTokenCount: 5,
183+
},
184+
};
185+
const mappedResponse = mapGenerateContentResponse(googleAIResponse);
186+
expect(mappedResponse.candidates).toEqual([]);
187+
expect(mappedResponse.promptFeedback).toBeUndefined();
188+
expect(mappedResponse.usageMetadata).toEqual(googleAIResponse.usageMetadata);
189+
});
190+
});
191+
192+
describe('mapCountTokensRequest', () => {
193+
it('should map a Vertex AI CountTokensRequest to Google AI format', () => {
194+
const vertexRequest: CountTokensRequest = {
195+
contents: fakeContents,
196+
systemInstruction: { role: 'system', parts: [{ text: 'Be nice' }] },
197+
tools: [{ functionDeclarations: [{ name: 'foo', description: 'bar' }] }],
198+
generationConfig: { temperature: 0.8 },
199+
};
200+
201+
const expectedGoogleAIRequest: GoogleAICountTokensRequest = {
202+
generateContentRequest: {
203+
model: fakeModel,
204+
contents: vertexRequest.contents,
205+
systemInstruction: vertexRequest.systemInstruction,
206+
tools: vertexRequest.tools,
207+
generationConfig: vertexRequest.generationConfig,
208+
},
209+
};
210+
211+
const mappedRequest = mapCountTokensRequest(vertexRequest, fakeModel);
212+
expect(mappedRequest).toEqual(expectedGoogleAIRequest);
213+
});
214+
215+
it('should map a minimal Vertex AI CountTokensRequest', () => {
216+
const vertexRequest: CountTokensRequest = {
217+
contents: fakeContents,
218+
systemInstruction: { role: 'system', parts: [{ text: 'Be nice' }] },
219+
generationConfig: { temperature: 0.8 },
220+
};
221+
222+
const expectedGoogleAIRequest: GoogleAICountTokensRequest = {
223+
generateContentRequest: {
224+
model: fakeModel,
225+
contents: vertexRequest.contents,
226+
systemInstruction: { role: 'system', parts: [{ text: 'Be nice' }] },
227+
generationConfig: { temperature: 0.8 },
228+
},
229+
};
230+
231+
const mappedRequest = mapCountTokensRequest(vertexRequest, fakeModel);
232+
expect(mappedRequest).toEqual(expectedGoogleAIRequest);
233+
});
234+
});
235+
236+
describe('mapGenerateContentCandidates', () => {
237+
it('should map citationSources to citationMetadata.citations', () => {
238+
const candidates: GoogleAIGenerateContentCandidate[] = [
239+
{
240+
index: 0,
241+
content: { role: 'model', parts: [{ text: 'Cited text' }] },
242+
citationMetadata: {
243+
citationSources: [
244+
{ startIndex: 0, endIndex: 5, uri: 'uri1', license: 'MIT' },
245+
{ startIndex: 6, endIndex: 10, uri: 'uri2' },
246+
],
247+
},
248+
},
249+
];
250+
const mapped = mapGenerateContentCandidates(candidates);
251+
expect(mapped[0]?.citationMetadata).toBeDefined();
252+
expect(mapped[0]?.citationMetadata?.citations).toEqual(
253+
candidates[0]?.citationMetadata?.citationSources,
254+
);
255+
expect(mapped[0]?.citationMetadata?.citations[0]?.title).toBeUndefined(); // Not in Google AI
256+
expect(mapped[0]?.citationMetadata?.citations[0]?.publicationDate).toBeUndefined(); // Not in Google AI
257+
});
258+
259+
it('should add default safety rating properties', () => {
260+
const candidates: GoogleAIGenerateContentCandidate[] = [
261+
{
262+
index: 0,
263+
content: { role: 'model', parts: [{ text: 'Maybe unsafe' }] },
264+
safetyRatings: [
265+
{
266+
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
267+
probability: HarmProbability.MEDIUM,
268+
blocked: false,
269+
// Missing severity, probabilityScore, severityScore
270+
} as any,
271+
],
272+
},
273+
];
274+
const mapped = mapGenerateContentCandidates(candidates);
275+
expect(mapped[0]?.safetyRatings).toBeDefined();
276+
const safetyRating = mapped[0]?.safetyRatings?.[0] as SafetyRating; // Type assertion
277+
expect(safetyRating.severity).toBe(HarmSeverity.HARM_SEVERITY_UNSUPPORTED);
278+
expect(safetyRating.probabilityScore).toBe(0);
279+
expect(safetyRating.severityScore).toBe(0);
280+
// Existing properties should be preserved
281+
expect(safetyRating.category).toBe(HarmCategory.HARM_CATEGORY_HARASSMENT);
282+
expect(safetyRating.probability).toBe(HarmProbability.MEDIUM);
283+
expect(safetyRating.blocked).toBe(false);
284+
});
285+
286+
it('should throw if videoMetadata is present in parts', () => {
287+
const candidates: GoogleAIGenerateContentCandidate[] = [
288+
{
289+
index: 0,
290+
content: {
291+
role: 'model',
292+
parts: [
293+
{
294+
inlineData: { mimeType: 'video/mp4', data: 'base64==' },
295+
videoMetadata: { startOffset: '0s', endOffset: '5s' }, // Unsupported
296+
},
297+
],
298+
},
299+
},
300+
];
301+
expect(() => mapGenerateContentCandidates(candidates)).toThrowError(
302+
new AIError(AIErrorCode.UNSUPPORTED, 'Part.videoMetadata is not supported'),
303+
);
304+
});
305+
306+
it('should handle candidates without citation or safety ratings', () => {
307+
const candidates: GoogleAIGenerateContentCandidate[] = [
308+
{
309+
index: 0,
310+
content: { role: 'model', parts: [{ text: 'Simple text' }] },
311+
finishReason: FinishReason.STOP,
312+
},
313+
];
314+
const mapped = mapGenerateContentCandidates(candidates);
315+
expect(mapped[0]?.citationMetadata).toBeUndefined();
316+
expect(mapped[0]?.safetyRatings).toBeUndefined();
317+
expect(mapped[0]?.content?.parts[0]?.text).toBe('Simple text');
318+
expect(loggerWarnSpy).not.toHaveBeenCalled();
319+
});
320+
321+
it('should handle empty candidate array', () => {
322+
const candidates: GoogleAIGenerateContentCandidate[] = [];
323+
const mapped = mapGenerateContentCandidates(candidates);
324+
expect(mapped).toEqual([]);
325+
expect(loggerWarnSpy).not.toHaveBeenCalled();
326+
});
327+
});
328+
329+
describe('mapPromptFeedback', () => {
330+
it('should add default safety rating properties', () => {
331+
const feedback: PromptFeedback = {
332+
blockReason: BlockReason.OTHER,
333+
safetyRatings: [
334+
{
335+
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
336+
probability: HarmProbability.HIGH,
337+
blocked: true,
338+
// Missing severity, probabilityScore, severityScore
339+
} as any,
340+
],
341+
// Missing blockReasonMessage
342+
};
343+
const mapped = mapPromptFeedback(feedback);
344+
expect(mapped.safetyRatings).toBeDefined();
345+
const safetyRating = mapped.safetyRatings[0] as SafetyRating; // Type assertion
346+
expect(safetyRating.severity).toBe(HarmSeverity.HARM_SEVERITY_UNSUPPORTED);
347+
expect(safetyRating.probabilityScore).toBe(0);
348+
expect(safetyRating.severityScore).toBe(0);
349+
// Existing properties should be preserved
350+
expect(safetyRating.category).toBe(HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT);
351+
expect(safetyRating.probability).toBe(HarmProbability.HIGH);
352+
expect(safetyRating.blocked).toBe(true);
353+
// Other properties
354+
expect(mapped.blockReason).toBe(BlockReason.OTHER);
355+
expect(mapped.blockReasonMessage).toBeUndefined(); // Not present in input
356+
});
357+
});
358+
});

0 commit comments

Comments
 (0)