Skip to content

Commit b6a25d9

Browse files
committed
feat(gemini): 支持 tool_choice 参数转换,优化错误处理
1 parent 9aeef6a commit b6a25d9

File tree

1 file changed

+121
-7
lines changed

1 file changed

+121
-7
lines changed

relay/channel/gemini/relay-gemini.go

Lines changed: 121 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,13 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
356356
})
357357
}
358358
geminiRequest.SetTools(geminiTools)
359+
360+
// [NEW] Convert OpenAI tool_choice to Gemini toolConfig.functionCallingConfig
361+
// Mapping: "auto" -> "AUTO", "none" -> "NONE", "required" -> "ANY"
362+
// Object format: {"type": "function", "function": {"name": "xxx"}} -> "ANY" + allowedFunctionNames
363+
if textRequest.ToolChoice != nil {
364+
geminiRequest.ToolConfig = convertToolChoiceToGeminiConfig(textRequest.ToolChoice)
365+
}
359366
}
360367

361368
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
@@ -960,6 +967,24 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
960967
choice.FinishReason = constant.FinishReasonStop
961968
case "MAX_TOKENS":
962969
choice.FinishReason = constant.FinishReasonLength
970+
case "SAFETY":
971+
// Safety filter triggered
972+
choice.FinishReason = constant.FinishReasonContentFilter
973+
case "RECITATION":
974+
// Recitation (citation) detected
975+
choice.FinishReason = constant.FinishReasonContentFilter
976+
case "BLOCKLIST":
977+
// Blocklist triggered
978+
choice.FinishReason = constant.FinishReasonContentFilter
979+
case "PROHIBITED_CONTENT":
980+
// Prohibited content detected
981+
choice.FinishReason = constant.FinishReasonContentFilter
982+
case "SPII":
983+
// Sensitive personally identifiable information
984+
choice.FinishReason = constant.FinishReasonContentFilter
985+
case "OTHER":
986+
// Other reasons
987+
choice.FinishReason = constant.FinishReasonContentFilter
963988
default:
964989
choice.FinishReason = constant.FinishReasonContentFilter
965990
}
@@ -991,13 +1016,34 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
9911016
isTools := false
9921017
isThought := false
9931018
if candidate.FinishReason != nil {
994-
// p := GeminiConvertFinishReason(*candidate.FinishReason)
1019+
// Map Gemini FinishReason to OpenAI finish_reason
9951020
switch *candidate.FinishReason {
9961021
case "STOP":
1022+
// Normal completion
9971023
choice.FinishReason = &constant.FinishReasonStop
9981024
case "MAX_TOKENS":
1025+
// Reached maximum token limit
9991026
choice.FinishReason = &constant.FinishReasonLength
1027+
case "SAFETY":
1028+
// Safety filter triggered
1029+
choice.FinishReason = &constant.FinishReasonContentFilter
1030+
case "RECITATION":
1031+
// Recitation (citation) detected
1032+
choice.FinishReason = &constant.FinishReasonContentFilter
1033+
case "BLOCKLIST":
1034+
// Blocklist triggered
1035+
choice.FinishReason = &constant.FinishReasonContentFilter
1036+
case "PROHIBITED_CONTENT":
1037+
// Prohibited content detected
1038+
choice.FinishReason = &constant.FinishReasonContentFilter
1039+
case "SPII":
1040+
// Sensitive personally identifiable information
1041+
choice.FinishReason = &constant.FinishReasonContentFilter
1042+
case "OTHER":
1043+
// Other reasons
1044+
choice.FinishReason = &constant.FinishReasonContentFilter
10001045
default:
1046+
// Unknown reason, treat as content filter
10011047
choice.FinishReason = &constant.FinishReasonContentFilter
10021048
}
10031049
}
@@ -1214,12 +1260,20 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
12141260
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
12151261
}
12161262
if len(geminiResponse.Candidates) == 0 {
1217-
//return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
1218-
//if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
1219-
// return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest)
1220-
//} else {
1221-
// return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
1222-
//}
1263+
// [FIX] Return meaningful error when Candidates is empty
1264+
if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
1265+
return nil, types.NewOpenAIError(
1266+
errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason),
1267+
types.ErrorCodePromptBlocked,
1268+
http.StatusBadRequest,
1269+
)
1270+
} else {
1271+
return nil, types.NewOpenAIError(
1272+
errors.New("empty response from Gemini API"),
1273+
types.ErrorCodeEmptyResponse,
1274+
http.StatusInternalServerError,
1275+
)
1276+
}
12231277
}
12241278
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
12251279
fullTextResponse.Model = info.UpstreamModelName
@@ -1362,3 +1416,63 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
13621416

13631417
return usage, nil
13641418
}
1419+
1420+
// convertToolChoiceToGeminiConfig converts OpenAI tool_choice to Gemini toolConfig
1421+
// OpenAI tool_choice values:
1422+
// - "auto": Let the model decide (default)
1423+
// - "none": Don't call any tools
1424+
// - "required": Must call at least one tool
1425+
// - {"type": "function", "function": {"name": "xxx"}}: Call specific function
1426+
//
1427+
// Gemini functionCallingConfig.mode values:
1428+
// - "AUTO": Model decides whether to call functions
1429+
// - "NONE": Model won't call functions
1430+
// - "ANY": Model must call at least one function
1431+
func convertToolChoiceToGeminiConfig(toolChoice any) *dto.ToolConfig {
1432+
if toolChoice == nil {
1433+
return nil
1434+
}
1435+
1436+
// Handle string values: "auto", "none", "required"
1437+
if toolChoiceStr, ok := toolChoice.(string); ok {
1438+
config := &dto.ToolConfig{
1439+
FunctionCallingConfig: &dto.FunctionCallingConfig{},
1440+
}
1441+
switch toolChoiceStr {
1442+
case "auto":
1443+
config.FunctionCallingConfig.Mode = "AUTO"
1444+
case "none":
1445+
config.FunctionCallingConfig.Mode = "NONE"
1446+
case "required":
1447+
config.FunctionCallingConfig.Mode = "ANY"
1448+
default:
1449+
// Unknown string value, default to AUTO
1450+
config.FunctionCallingConfig.Mode = "AUTO"
1451+
}
1452+
return config
1453+
}
1454+
1455+
// Handle object value: {"type": "function", "function": {"name": "xxx"}}
1456+
if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok {
1457+
if toolChoiceMap["type"] == "function" {
1458+
config := &dto.ToolConfig{
1459+
FunctionCallingConfig: &dto.FunctionCallingConfig{
1460+
Mode: "ANY",
1461+
},
1462+
}
1463+
// Extract function name if specified
1464+
if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok {
1465+
if name, ok := function["name"].(string); ok && name != "" {
1466+
config.FunctionCallingConfig.AllowedFunctionNames = []string{name}
1467+
}
1468+
}
1469+
return config
1470+
}
1471+
// Unsupported map structure (type is not "function"), return nil
1472+
return nil
1473+
}
1474+
1475+
// Unsupported type, return nil
1476+
return nil
1477+
}
1478+

0 commit comments

Comments
 (0)