Skip to content

Commit d78bbb9

Browse files
hustxiayangyuzisun
andauthored
feat: enable enterprise web search for gemini models (#1526)
**Description** This is to enable feature `Web Grounding for Enterprise`: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/grounding/web-grounding-enterprise **Related Issues/PRs (if applicable)** #1417 --------- Signed-off-by: yxia216 <[email protected]> Co-authored-by: Dan Sun <[email protected]>
1 parent 03919d3 commit d78bbb9

File tree

5 files changed

+262
-4
lines changed

5 files changed

+262
-4
lines changed

internal/apischema/openai/openai.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,8 +1071,9 @@ type StreamOptions struct {
10711071
type ToolType string
10721072

10731073
const (
1074-
ToolTypeFunction ToolType = "function"
1075-
ToolTypeImageGeneration ToolType = "image_generation"
1074+
ToolTypeFunction ToolType = "function"
1075+
ToolTypeImageGeneration ToolType = "image_generation"
1076+
ToolTypeEnterpriseWebSearch ToolType = "enterprise_search"
10761077
)
10771078

10781079
type Tool struct {
@@ -1314,6 +1315,10 @@ type ChatCompletionResponseChoiceMessage struct {
13141315
// List of ratings for the safety of a response candidate. There is at most one rating per category.
13151316
// https://cloud.google.com/vertex-ai/generative-ai/docs/reference/rest/v1/GenerateContentResponse#SafetyRating
13161317
SafetyRatings []*genai.SafetyRating `json:"safety_ratings,omitempty"`
1318+
1319+
// GroundingMetadata specifies sources used to ground generated content.
1320+
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/reference/rest/v1beta1/GroundingMetadata
1321+
GroundingMetadata *genai.GroundingMetadata `json:"grounding_metadata,omitempty"`
13171322
}
13181323

13191324
// URLCitation contains citation information for web search results.

internal/apischema/openai/openai_test.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,133 @@ func TestChatCompletionRequest(t *testing.T) {
11931193
},
11941194
},
11951195
},
1196+
{
1197+
name: "enterprise search tool",
1198+
jsonStr: `{
1199+
"model": "gemini-1.5-pro",
1200+
"messages": [
1201+
{
1202+
"role": "user",
1203+
"content": "Hello with enterprise search!"
1204+
}
1205+
],
1206+
"tools": [
1207+
{
1208+
"type": "enterprise_search"
1209+
}
1210+
]
1211+
}`,
1212+
expected: &ChatCompletionRequest{
1213+
Model: "gemini-1.5-pro",
1214+
Messages: []ChatCompletionMessageParamUnion{
1215+
{
1216+
OfUser: &ChatCompletionUserMessageParam{
1217+
Role: ChatMessageRoleUser,
1218+
Content: StringOrUserRoleContentUnion{Value: "Hello with enterprise search!"},
1219+
},
1220+
},
1221+
},
1222+
Tools: []Tool{
1223+
{
1224+
Type: ToolTypeEnterpriseWebSearch,
1225+
},
1226+
},
1227+
},
1228+
},
1229+
{
1230+
name: "mixed function and enterprise search tools",
1231+
jsonStr: `{
1232+
"model": "gemini-1.5-pro",
1233+
"messages": [
1234+
{
1235+
"role": "user",
1236+
"content": "Mixed tools test"
1237+
}
1238+
],
1239+
"tools": [
1240+
{
1241+
"type": "function",
1242+
"function": {
1243+
"name": "get_weather",
1244+
"description": "Get current weather"
1245+
}
1246+
},
1247+
{
1248+
"type": "enterprise_search"
1249+
}
1250+
]
1251+
}`,
1252+
expected: &ChatCompletionRequest{
1253+
Model: "gemini-1.5-pro",
1254+
Messages: []ChatCompletionMessageParamUnion{
1255+
{
1256+
OfUser: &ChatCompletionUserMessageParam{
1257+
Role: ChatMessageRoleUser,
1258+
Content: StringOrUserRoleContentUnion{Value: "Mixed tools test"},
1259+
},
1260+
},
1261+
},
1262+
Tools: []Tool{
1263+
{
1264+
Type: ToolTypeFunction,
1265+
Function: &FunctionDefinition{
1266+
Name: "get_weather",
1267+
Description: "Get current weather",
1268+
},
1269+
},
1270+
{
1271+
Type: ToolTypeEnterpriseWebSearch,
1272+
},
1273+
},
1274+
},
1275+
},
1276+
{
1277+
name: "enterprise search with vendor fields",
1278+
jsonStr: `{
1279+
"model": "gemini-1.5-pro",
1280+
"messages": [
1281+
{
1282+
"role": "user",
1283+
"content": "Combined enterprise search and safety settings"
1284+
}
1285+
],
1286+
"tools": [
1287+
{
1288+
"type": "enterprise_search"
1289+
}
1290+
],
1291+
"safetySettings": [
1292+
{
1293+
"category": "HARM_CATEGORY_HARASSMENT",
1294+
"threshold": "BLOCK_ONLY_HIGH"
1295+
}
1296+
]
1297+
}`,
1298+
expected: &ChatCompletionRequest{
1299+
Model: "gemini-1.5-pro",
1300+
Messages: []ChatCompletionMessageParamUnion{
1301+
{
1302+
OfUser: &ChatCompletionUserMessageParam{
1303+
Role: ChatMessageRoleUser,
1304+
Content: StringOrUserRoleContentUnion{Value: "Combined enterprise search and safety settings"},
1305+
},
1306+
},
1307+
},
1308+
Tools: []Tool{
1309+
{
1310+
Type: ToolTypeEnterpriseWebSearch,
1311+
},
1312+
},
1313+
GCPVertexAIVendorFields: &GCPVertexAIVendorFields{
1314+
SafetySettings: []*genai.SafetySetting{
1315+
{
1316+
Category: genai.HarmCategoryHarassment,
1317+
Threshold: genai.HarmBlockThresholdBlockOnlyHigh,
1318+
},
1319+
},
1320+
},
1321+
},
1322+
},
11961323
}
11971324

11981325
for _, tc := range testCases {

internal/translator/gemini_helper.go

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,9 @@ func openAIToolsToGeminiTools(openaiTools []openai.Tool, parametersJSONSchemaAva
364364
if len(openaiTools) == 0 {
365365
return nil, nil
366366
}
367+
368+
var genaiTools []genai.Tool
369+
367370
var functionDecls []*genai.FunctionDeclaration
368371

369372
for _, tool := range openaiTools {
@@ -395,14 +398,27 @@ func openAIToolsToGeminiTools(openaiTools []openai.Tool, parametersJSONSchemaAva
395398
}
396399
case openai.ToolTypeImageGeneration:
397400
return nil, fmt.Errorf("tool-type image generation not supported yet when translating OpenAI req to Gemini")
401+
case openai.ToolTypeEnterpriseWebSearch:
402+
genaiTools = append(genaiTools, genai.Tool{
403+
EnterpriseWebSearch: &genai.EnterpriseWebSearch{},
404+
})
398405
default:
399406
return nil, fmt.Errorf("unsupported tool type: %s", tool.Type)
400407
}
401408
}
402-
if len(functionDecls) == 0 {
409+
// Only return nil if there are no tools at all (neither function declarations nor other tools)
410+
if len(functionDecls) == 0 && len(genaiTools) == 0 {
403411
return nil, nil
404412
}
405-
return []genai.Tool{{FunctionDeclarations: functionDecls}}, nil
413+
414+
// Only append function declarations if there are any
415+
if len(functionDecls) > 0 {
416+
genaiTools = append(genaiTools, genai.Tool{
417+
FunctionDeclarations: functionDecls,
418+
})
419+
}
420+
421+
return genaiTools, nil
406422
}
407423

408424
// openAIToolChoiceToGeminiToolConfig converts OpenAI tool_choice to Gemini ToolConfig.
@@ -688,6 +704,14 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode
688704
choice.Message.SafetyRatings = candidate.SafetyRatings
689705
}
690706

707+
if candidate.GroundingMetadata != nil {
708+
if choice.Message.Role == "" {
709+
choice.Message.Role = openai.ChatMessageRoleAssistant
710+
}
711+
712+
choice.Message.GroundingMetadata = candidate.GroundingMetadata
713+
}
714+
691715
// Handle logprobs if available.
692716
if candidate.LogprobsResult != nil {
693717
choice.Logprobs = geminiLogprobsToOpenAILogprobs(*candidate.LogprobsResult)

internal/translator/gemini_helper_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1540,6 +1540,58 @@ func TestOpenAIToolsToGeminiTools(t *testing.T) {
15401540
parametersJSONSchemaAvailable: false,
15411541
expected: nil,
15421542
},
1543+
{
1544+
name: "enterprise search tool only",
1545+
openaiTools: []openai.Tool{
1546+
{
1547+
Type: openai.ToolTypeEnterpriseWebSearch,
1548+
},
1549+
},
1550+
parametersJSONSchemaAvailable: false,
1551+
expected: []genai.Tool{
1552+
{
1553+
EnterpriseWebSearch: &genai.EnterpriseWebSearch{},
1554+
},
1555+
},
1556+
},
1557+
{
1558+
name: "mixed function and enterprise search tools",
1559+
openaiTools: []openai.Tool{
1560+
{
1561+
Type: openai.ToolTypeFunction,
1562+
Function: &openai.FunctionDefinition{
1563+
Name: "get_weather",
1564+
Description: "Get current weather",
1565+
Parameters: funcParams,
1566+
},
1567+
},
1568+
{
1569+
Type: openai.ToolTypeEnterpriseWebSearch,
1570+
},
1571+
},
1572+
parametersJSONSchemaAvailable: false,
1573+
expected: []genai.Tool{
1574+
{
1575+
EnterpriseWebSearch: &genai.EnterpriseWebSearch{},
1576+
},
1577+
{
1578+
FunctionDeclarations: []*genai.FunctionDeclaration{
1579+
{
1580+
Name: "get_weather",
1581+
Description: "Get current weather",
1582+
Parameters: &genai.Schema{
1583+
Type: "object",
1584+
Properties: map[string]*genai.Schema{
1585+
"a": {Type: "integer"},
1586+
"b": {Type: "integer"},
1587+
},
1588+
Required: []string{"a", "b"},
1589+
},
1590+
},
1591+
},
1592+
},
1593+
},
1594+
},
15431595
}
15441596

15451597
for _, tc := range tests {

internal/translator/openai_gcpvertexai_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,28 @@ func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_RequestBody(t *testing.T)
335335
}
336336
}`)
337337

338+
wantBdyWithEnterpriseWebSearch := []byte(`{
339+
"contents": [
340+
{
341+
"parts": [
342+
{
343+
"text": "Test with web grounding for enterprise"
344+
}
345+
],
346+
"role": "user"
347+
}
348+
],
349+
"tools": [
350+
{
351+
"enterpriseWebSearch": {}
352+
}
353+
],
354+
"generation_config": {
355+
"maxOutputTokens": 1024,
356+
"temperature": 0.7
357+
}
358+
}`)
359+
338360
tests := []struct {
339361
name string
340362
modelNameOverride internalapi.ModelNameOverride
@@ -739,6 +761,34 @@ func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_RequestBody(t *testing.T)
739761
},
740762
wantBody: wantBdyWithGuidedRegex,
741763
},
764+
{
765+
name: "Request with gcp web grounding for enterprise",
766+
input: openai.ChatCompletionRequest{
767+
Model: "gemini-1.5-pro",
768+
Temperature: ptr.To(0.7),
769+
MaxTokens: ptr.To(int64(1024)),
770+
Messages: []openai.ChatCompletionMessageParamUnion{
771+
{
772+
OfUser: &openai.ChatCompletionUserMessageParam{
773+
Role: openai.ChatMessageRoleUser,
774+
Content: openai.StringOrUserRoleContentUnion{Value: "Test with web grounding for enterprise"},
775+
},
776+
},
777+
},
778+
Tools: []openai.Tool{
779+
{
780+
Type: "enterprise_search",
781+
},
782+
},
783+
},
784+
onRetry: false,
785+
wantError: false,
786+
wantHeaderMut: []internalapi.Header{
787+
{":path", "publishers/google/models/gemini-1.5-pro:generateContent"},
788+
{"content-length", "190"},
789+
},
790+
wantBody: wantBdyWithEnterpriseWebSearch,
791+
},
742792
}
743793

744794
for _, tc := range tests {

0 commit comments

Comments
 (0)