Skip to content

Commit 65a5494

Browse files
committed
add suppport for grounding in gemini
1 parent 5f461e5 commit 65a5494

File tree

2 files changed

+42
-5
lines changed

2 files changed

+42
-5
lines changed

src/providers/google-vertex-ai/chatComplete.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ import type {
4141
} from './types';
4242
import { getMimeType } from './utils';
4343

44-
const buildGoogleSearchRetrievalTool = (tool: Tool) => {
44+
export const buildGoogleSearchRetrievalTool = (tool: Tool) => {
4545
const googleSearchRetrievalTool: GoogleSearchRetrievalTool = {
4646
googleSearchRetrieval: {},
4747
};

src/providers/google/chatComplete.ts

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import {
77
ToolCall,
88
ToolChoice,
99
} from '../../types/requestBody';
10+
import { buildGoogleSearchRetrievalTool } from '../google-vertex-ai/chatComplete';
1011
import { derefer, getMimeType } from '../google-vertex-ai/utils';
1112
import {
1213
ChatCompletionResponse,
@@ -325,12 +326,20 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
325326
default: '',
326327
transform: (params: Params) => {
327328
const functionDeclarations: any = [];
329+
const tools: any = [];
328330
params.tools?.forEach((tool) => {
329331
if (tool.type === 'function') {
330-
functionDeclarations.push(tool.function);
332+
if (tool.function.name === 'googleSearchRetrieval') {
333+
tools.push(buildGoogleSearchRetrievalTool(tool));
334+
} else {
335+
functionDeclarations.push(tool.function);
336+
}
331337
}
332338
});
333-
return { functionDeclarations };
339+
if (functionDeclarations.length) {
340+
tools.push({ functionDeclarations });
341+
}
342+
return tools;
334343
},
335344
},
336345
tool_choice: {
@@ -388,6 +397,24 @@ interface GoogleGenerateContentResponse {
388397
category: string;
389398
probability: string;
390399
}[];
400+
groundingMetadata?: {
401+
webSearchQueries?: string[];
402+
searchEntryPoint?: {
403+
renderedContent: string;
404+
};
405+
groundingSupports?: Array<{
406+
segment: {
407+
startIndex: number;
408+
endIndex: number;
409+
text: string;
410+
};
411+
groundingChunkIndices: number[];
412+
confidenceScores: number[];
413+
}>;
414+
retrievalMetadata?: {
415+
webDynamicRetrievalScore: number;
416+
};
417+
};
391418
}[];
392419
promptFeedback: {
393420
safetyRatings: {
@@ -423,8 +450,15 @@ export const GoogleErrorResponseTransform: (
423450

424451
export const GoogleChatCompleteResponseTransform: (
425452
response: GoogleGenerateContentResponse | GoogleErrorResponse,
426-
responseStatus: number
427-
) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => {
453+
responseStatus: number,
454+
responseHeaders: Headers,
455+
strictOpenAiCompliance: boolean
456+
) => ChatCompletionResponse | ErrorResponse = (
457+
response,
458+
responseStatus,
459+
_responseHeaders,
460+
strictOpenAiCompliance
461+
) => {
428462
if (responseStatus !== 200) {
429463
const errorResponse = GoogleErrorResponseTransform(
430464
response as GoogleErrorResponse
@@ -468,6 +502,9 @@ export const GoogleChatCompleteResponseTransform: (
468502
message: message,
469503
index: generation.index ?? idx,
470504
finish_reason: generation.finishReason,
505+
...(!strictOpenAiCompliance && generation.groundingMetadata
506+
? { groundingMetadata: generation.groundingMetadata }
507+
: {}),
471508
};
472509
}) ?? [],
473510
usage: {

0 commit comments

Comments
 (0)