|
7 | 7 | ToolCall,
|
8 | 8 | ToolChoice,
|
9 | 9 | } from '../../types/requestBody';
|
| 10 | +import { buildGoogleSearchRetrievalTool } from '../google-vertex-ai/chatComplete'; |
10 | 11 | import { derefer, getMimeType } from '../google-vertex-ai/utils';
|
11 | 12 | import {
|
12 | 13 | ChatCompletionResponse,
|
@@ -325,12 +326,20 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
|
325 | 326 | default: '',
|
326 | 327 | transform: (params: Params) => {
|
327 | 328 | const functionDeclarations: any = [];
|
| 329 | + const tools: any = []; |
328 | 330 | params.tools?.forEach((tool) => {
|
329 | 331 | if (tool.type === 'function') {
|
330 |
| - functionDeclarations.push(tool.function); |
| 332 | + if (tool.function.name === 'googleSearchRetrieval') { |
| 333 | + tools.push(buildGoogleSearchRetrievalTool(tool)); |
| 334 | + } else { |
| 335 | + functionDeclarations.push(tool.function); |
| 336 | + } |
331 | 337 | }
|
332 | 338 | });
|
333 |
| - return { functionDeclarations }; |
| 339 | + if (functionDeclarations.length) { |
| 340 | + tools.push({ functionDeclarations }); |
| 341 | + } |
| 342 | + return tools; |
334 | 343 | },
|
335 | 344 | },
|
336 | 345 | tool_choice: {
|
@@ -388,6 +397,24 @@ interface GoogleGenerateContentResponse {
|
388 | 397 | category: string;
|
389 | 398 | probability: string;
|
390 | 399 | }[];
|
| 400 | + groundingMetadata?: { |
| 401 | + webSearchQueries?: string[]; |
| 402 | + searchEntryPoint?: { |
| 403 | + renderedContent: string; |
| 404 | + }; |
| 405 | + groundingSupports?: Array<{ |
| 406 | + segment: { |
| 407 | + startIndex: number; |
| 408 | + endIndex: number; |
| 409 | + text: string; |
| 410 | + }; |
| 411 | + groundingChunkIndices: number[]; |
| 412 | + confidenceScores: number[]; |
| 413 | + }>; |
| 414 | + retrievalMetadata?: { |
| 415 | + webDynamicRetrievalScore: number; |
| 416 | + }; |
| 417 | + }; |
391 | 418 | }[];
|
392 | 419 | promptFeedback: {
|
393 | 420 | safetyRatings: {
|
@@ -423,8 +450,15 @@ export const GoogleErrorResponseTransform: (
|
423 | 450 |
|
424 | 451 | export const GoogleChatCompleteResponseTransform: (
|
425 | 452 | response: GoogleGenerateContentResponse | GoogleErrorResponse,
|
426 |
| - responseStatus: number |
427 |
| -) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => { |
| 453 | + responseStatus: number, |
| 454 | + responseHeaders: Headers, |
| 455 | + strictOpenAiCompliance: boolean |
| 456 | +) => ChatCompletionResponse | ErrorResponse = ( |
| 457 | + response, |
| 458 | + responseStatus, |
| 459 | + _responseHeaders, |
| 460 | + strictOpenAiCompliance |
| 461 | +) => { |
428 | 462 | if (responseStatus !== 200) {
|
429 | 463 | const errorResponse = GoogleErrorResponseTransform(
|
430 | 464 | response as GoogleErrorResponse
|
@@ -468,6 +502,9 @@ export const GoogleChatCompleteResponseTransform: (
|
468 | 502 | message: message,
|
469 | 503 | index: generation.index ?? idx,
|
470 | 504 | finish_reason: generation.finishReason,
|
| 505 | + ...(!strictOpenAiCompliance && generation.groundingMetadata |
| 506 | + ? { groundingMetadata: generation.groundingMetadata } |
| 507 | + : {}), |
471 | 508 | };
|
472 | 509 | }) ?? [],
|
473 | 510 | usage: {
|
|
0 commit comments