Skip to content

Commit a6f14f4

Browse files
committed
feat(gemini): support tool_choice and improve error handling
1 parent 9aeef6a commit a6f14f4

File tree

1 file changed

+105
-6
lines changed

1 file changed

+105
-6
lines changed

relay/channel/gemini/relay-gemini.go

Lines changed: 105 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,13 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
356356
})
357357
}
358358
geminiRequest.SetTools(geminiTools)
359+
360+
// [NEW] Convert OpenAI tool_choice to Gemini toolConfig.functionCallingConfig
361+
// Mapping: "auto" -> "AUTO", "none" -> "NONE", "required" -> "ANY"
362+
// Object format: {"type": "function", "function": {"name": "xxx"}} -> "ANY" + allowedFunctionNames
363+
if textRequest.ToolChoice != nil {
364+
geminiRequest.ToolConfig = convertToolChoiceToGeminiConfig(textRequest.ToolChoice)
365+
}
359366
}
360367

361368
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
@@ -960,6 +967,24 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
960967
choice.FinishReason = constant.FinishReasonStop
961968
case "MAX_TOKENS":
962969
choice.FinishReason = constant.FinishReasonLength
970+
case "SAFETY":
971+
// Safety filter triggered
972+
choice.FinishReason = constant.FinishReasonContentFilter
973+
case "RECITATION":
974+
// Recitation (citation) detected
975+
choice.FinishReason = constant.FinishReasonContentFilter
976+
case "BLOCKLIST":
977+
// Blocklist triggered
978+
choice.FinishReason = constant.FinishReasonContentFilter
979+
case "PROHIBITED_CONTENT":
980+
// Prohibited content detected
981+
choice.FinishReason = constant.FinishReasonContentFilter
982+
case "SPII":
983+
// Sensitive personally identifiable information
984+
choice.FinishReason = constant.FinishReasonContentFilter
985+
case "OTHER":
986+
// Other reasons
987+
choice.FinishReason = constant.FinishReasonContentFilter
963988
default:
964989
choice.FinishReason = constant.FinishReasonContentFilter
965990
}
@@ -997,6 +1022,18 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
9971022
choice.FinishReason = &constant.FinishReasonStop
9981023
case "MAX_TOKENS":
9991024
choice.FinishReason = &constant.FinishReasonLength
1025+
case "SAFETY":
1026+
choice.FinishReason = &constant.FinishReasonContentFilter
1027+
case "RECITATION":
1028+
choice.FinishReason = &constant.FinishReasonContentFilter
1029+
case "BLOCKLIST":
1030+
choice.FinishReason = &constant.FinishReasonContentFilter
1031+
case "PROHIBITED_CONTENT":
1032+
choice.FinishReason = &constant.FinishReasonContentFilter
1033+
case "SPII":
1034+
choice.FinishReason = &constant.FinishReasonContentFilter
1035+
case "OTHER":
1036+
choice.FinishReason = &constant.FinishReasonContentFilter
10001037
default:
10011038
choice.FinishReason = &constant.FinishReasonContentFilter
10021039
}
@@ -1214,12 +1251,20 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
12141251
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
12151252
}
12161253
if len(geminiResponse.Candidates) == 0 {
1217-
//return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
1218-
//if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
1219-
// return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest)
1220-
//} else {
1221-
// return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
1222-
//}
1254+
// [FIX] Return meaningful error when Candidates is empty
1255+
if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
1256+
return nil, types.NewOpenAIError(
1257+
errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason),
1258+
types.ErrorCodePromptBlocked,
1259+
http.StatusBadRequest,
1260+
)
1261+
} else {
1262+
return nil, types.NewOpenAIError(
1263+
errors.New("empty response from Gemini API"),
1264+
types.ErrorCodeEmptyResponse,
1265+
http.StatusInternalServerError,
1266+
)
1267+
}
12231268
}
12241269
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
12251270
fullTextResponse.Model = info.UpstreamModelName
@@ -1362,3 +1407,57 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
13621407

13631408
return usage, nil
13641409
}
1410+
1411+
// convertToolChoiceToGeminiConfig converts OpenAI tool_choice to Gemini toolConfig
1412+
// OpenAI tool_choice values:
1413+
// - "auto": Let the model decide (default)
1414+
// - "none": Don't call any tools
1415+
// - "required": Must call at least one tool
1416+
// - {"type": "function", "function": {"name": "xxx"}}: Call specific function
1417+
//
1418+
// Gemini functionCallingConfig.mode values:
1419+
// - "AUTO": Model decides whether to call functions
1420+
// - "NONE": Model won't call functions
1421+
// - "ANY": Model must call at least one function
1422+
func convertToolChoiceToGeminiConfig(toolChoice any) *dto.ToolConfig {
1423+
if toolChoice == nil {
1424+
return nil
1425+
}
1426+
1427+
config := &dto.ToolConfig{
1428+
FunctionCallingConfig: &dto.FunctionCallingConfig{},
1429+
}
1430+
1431+
// Handle string values: "auto", "none", "required"
1432+
if toolChoiceStr, ok := toolChoice.(string); ok {
1433+
switch toolChoiceStr {
1434+
case "auto":
1435+
config.FunctionCallingConfig.Mode = "AUTO"
1436+
case "none":
1437+
config.FunctionCallingConfig.Mode = "NONE"
1438+
case "required":
1439+
config.FunctionCallingConfig.Mode = "ANY"
1440+
default:
1441+
// Unknown value, default to AUTO
1442+
config.FunctionCallingConfig.Mode = "AUTO"
1443+
}
1444+
return config
1445+
}
1446+
1447+
// Handle object value: {"type": "function", "function": {"name": "xxx"}}
1448+
if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok {
1449+
if toolChoiceMap["type"] == "function" {
1450+
config.FunctionCallingConfig.Mode = "ANY"
1451+
1452+
// Extract function name if specified
1453+
if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok {
1454+
if name, ok := function["name"].(string); ok && name != "" {
1455+
config.FunctionCallingConfig.AllowedFunctionNames = []string{name}
1456+
}
1457+
}
1458+
}
1459+
return config
1460+
}
1461+
1462+
return nil
1463+
}

0 commit comments

Comments
 (0)