Skip to content

Commit 6e52f52

Browse files
authored
Merge pull request #786 from narengogi/enhancement/gemini-grounding
Enhancement: support grounding mode in gemini
2 parents 56c3573 + e12b317 commit 6e52f52

File tree

4 files changed: 110 additions and 9 deletions.

src/providers/google-vertex-ai/chatComplete.ts

Lines changed: 29 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ import {
66
ContentType,
77
Message,
88
Params,
9+
Tool,
910
ToolCall,
1011
} from '../../types/requestBody';
1112
import {
@@ -36,9 +37,21 @@ import type {
3637
GoogleGenerateContentResponse,
3738
VertexLlamaChatCompleteStreamChunk,
3839
VertexLLamaChatCompleteResponse,
40+
GoogleSearchRetrievalTool,
3941
} from './types';
4042
import { getMimeType } from './utils';
4143

44+
/**
 * Builds the Gemini `googleSearchRetrieval` (grounding) tool entry from an
 * OpenAI-style function tool.
 *
 * The result always contains a `googleSearchRetrieval` object. When the
 * incoming tool carries a truthy `function.parameters.dynamicRetrievalConfig`,
 * that value is forwarded as-is (same reference; no validation or copy).
 *
 * @param tool - Tool whose `function.parameters` may hold a
 *   `dynamicRetrievalConfig` entry.
 * @returns A tool object suitable for pushing into the request's tools list.
 */
export const buildGoogleSearchRetrievalTool = (
  tool: Tool
): GoogleSearchRetrievalTool => {
  const dynamicRetrievalConfig =
    tool.function.parameters?.dynamicRetrievalConfig;
  const googleSearchRetrievalTool: GoogleSearchRetrievalTool = {
    googleSearchRetrieval: {},
  };
  // Only set the key when a config was supplied, so the default case
  // serializes as an empty `googleSearchRetrieval: {}` object.
  if (dynamicRetrievalConfig) {
    googleSearchRetrievalTool.googleSearchRetrieval.dynamicRetrievalConfig =
      dynamicRetrievalConfig;
  }
  return googleSearchRetrievalTool;
};
54+
4255
export const VertexGoogleChatCompleteConfig: ProviderConfig = {
4356
// https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#gemini-model-versions
4457
model: {
@@ -253,12 +266,20 @@ export const VertexGoogleChatCompleteConfig: ProviderConfig = {
253266
default: '',
254267
transform: (params: Params) => {
255268
const functionDeclarations: any = [];
269+
const tools: any = [];
256270
params.tools?.forEach((tool) => {
257271
if (tool.type === 'function') {
258-
functionDeclarations.push(tool.function);
272+
if (tool.function.name === 'googleSearchRetrieval') {
273+
tools.push(buildGoogleSearchRetrievalTool(tool));
274+
} else {
275+
functionDeclarations.push(tool.function);
276+
}
259277
}
260278
});
261-
return { functionDeclarations };
279+
if (functionDeclarations.length) {
280+
tools.push({ functionDeclarations });
281+
}
282+
return tools;
262283
},
263284
},
264285
tool_choice: {
@@ -648,6 +669,9 @@ export const GoogleChatCompleteResponseTransform: (
648669
...(!strictOpenAiCompliance && {
649670
safetyRatings: generation.safetyRatings,
650671
}),
672+
...(!strictOpenAiCompliance && generation.groundingMetadata
673+
? { groundingMetadata: generation.groundingMetadata }
674+
: {}),
651675
};
652676
}) ?? [],
653677
usage: {
@@ -778,6 +802,9 @@ export const GoogleChatCompleteStreamChunkTransform: (
778802
...(!strictOpenAiCompliance && {
779803
safetyRatings: generation.safetyRatings,
780804
}),
805+
...(!strictOpenAiCompliance && generation.groundingMetadata
806+
? { groundingMetadata: generation.groundingMetadata }
807+
: {}),
781808
};
782809
}) ?? [],
783810
usage: usageMetadata,

src/providers/google-vertex-ai/types.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,24 @@ export interface GoogleGenerateContentResponse {
2828
category: string;
2929
probability: string;
3030
}[];
31+
groundingMetadata?: {
32+
webSearchQueries?: string[];
33+
searchEntryPoint?: {
34+
renderedContent: string;
35+
};
36+
groundingSupports?: Array<{
37+
segment: {
38+
startIndex: number;
39+
endIndex: number;
40+
text: string;
41+
};
42+
groundingChunkIndices: number[];
43+
confidenceScores: number[];
44+
}>;
45+
retrievalMetadata?: {
46+
webDynamicRetrievalScore: number;
47+
};
48+
};
3149
}[];
3250
promptFeedback: {
3351
safetyRatings: {
@@ -90,3 +108,12 @@ export interface GoogleEmbedResponse {
90108
billableCharacterCount: number;
91109
};
92110
}
111+
112+
/**
 * Request-side shape of the Google Search retrieval (grounding) tool entry.
 */
export interface GoogleSearchRetrievalTool {
  googleSearchRetrieval: {
    /** Optional dynamic-retrieval settings; when absent the tool is sent as `{}`. */
    dynamicRetrievalConfig?: {
      /** Retrieval mode name, passed through as supplied by the caller. */
      mode: string;
      /**
       * Threshold used for dynamic retrieval. Google's API reference documents
       * this field as a number; `string` is still accepted here so existing
       * callers that pass a string keep type-checking.
       */
      dynamicThreshold?: number | string;
    };
  };
}

src/providers/google/chatComplete.ts

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import {
77
ToolCall,
88
ToolChoice,
99
} from '../../types/requestBody';
10+
import { buildGoogleSearchRetrievalTool } from '../google-vertex-ai/chatComplete';
1011
import { derefer, getMimeType } from '../google-vertex-ai/utils';
1112
import {
1213
ChatCompletionResponse,
@@ -325,12 +326,20 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
325326
default: '',
326327
transform: (params: Params) => {
327328
const functionDeclarations: any = [];
329+
const tools: any = [];
328330
params.tools?.forEach((tool) => {
329331
if (tool.type === 'function') {
330-
functionDeclarations.push(tool.function);
332+
if (tool.function.name === 'googleSearchRetrieval') {
333+
tools.push(buildGoogleSearchRetrievalTool(tool));
334+
} else {
335+
functionDeclarations.push(tool.function);
336+
}
331337
}
332338
});
333-
return { functionDeclarations };
339+
if (functionDeclarations.length) {
340+
tools.push({ functionDeclarations });
341+
}
342+
return tools;
334343
},
335344
},
336345
tool_choice: {
@@ -388,6 +397,24 @@ interface GoogleGenerateContentResponse {
388397
category: string;
389398
probability: string;
390399
}[];
400+
groundingMetadata?: {
401+
webSearchQueries?: string[];
402+
searchEntryPoint?: {
403+
renderedContent: string;
404+
};
405+
groundingSupports?: Array<{
406+
segment: {
407+
startIndex: number;
408+
endIndex: number;
409+
text: string;
410+
};
411+
groundingChunkIndices: number[];
412+
confidenceScores: number[];
413+
}>;
414+
retrievalMetadata?: {
415+
webDynamicRetrievalScore: number;
416+
};
417+
};
391418
}[];
392419
promptFeedback: {
393420
safetyRatings: {
@@ -423,8 +450,15 @@ export const GoogleErrorResponseTransform: (
423450

424451
export const GoogleChatCompleteResponseTransform: (
425452
response: GoogleGenerateContentResponse | GoogleErrorResponse,
426-
responseStatus: number
427-
) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => {
453+
responseStatus: number,
454+
responseHeaders: Headers,
455+
strictOpenAiCompliance: boolean
456+
) => ChatCompletionResponse | ErrorResponse = (
457+
response,
458+
responseStatus,
459+
_responseHeaders,
460+
strictOpenAiCompliance
461+
) => {
428462
if (responseStatus !== 200) {
429463
const errorResponse = GoogleErrorResponseTransform(
430464
response as GoogleErrorResponse
@@ -468,6 +502,9 @@ export const GoogleChatCompleteResponseTransform: (
468502
message: message,
469503
index: generation.index ?? idx,
470504
finish_reason: generation.finishReason,
505+
...(!strictOpenAiCompliance && generation.groundingMetadata
506+
? { groundingMetadata: generation.groundingMetadata }
507+
: {}),
471508
};
472509
}) ?? [],
473510
usage: {
@@ -483,8 +520,15 @@ export const GoogleChatCompleteResponseTransform: (
483520

484521
export const GoogleChatCompleteStreamChunkTransform: (
485522
response: string,
486-
fallbackId: string
487-
) => string = (responseChunk, fallbackId) => {
523+
fallbackId: string,
524+
streamState: any,
525+
strictOpenAiCompliance: boolean
526+
) => string = (
527+
responseChunk,
528+
fallbackId,
529+
_streamState,
530+
strictOpenAiCompliance
531+
) => {
488532
let chunk = responseChunk.trim();
489533
if (chunk.startsWith('[')) {
490534
chunk = chunk.slice(1);
@@ -541,6 +585,9 @@ export const GoogleChatCompleteStreamChunkTransform: (
541585
delta: message,
542586
index: generation.index ?? index,
543587
finish_reason: generation.finishReason,
588+
...(!strictOpenAiCompliance && generation.groundingMetadata
589+
? { groundingMetadata: generation.groundingMetadata }
590+
: {}),
544591
};
545592
}) ?? [],
546593
usage: {

src/types/requestBody.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ export interface Tool extends AnthropicPromptCache {
301301
/** The name of the function. */
302302
type: string;
303303
/** A description of the function. */
304-
function?: Function;
304+
function: Function;
305305
}
306306

307307
/**

0 commit comments

Comments
 (0)