Skip to content

Commit b0cbf71

Browse files
Merge pull request #1107 from QuantumNous/gemini-relay
Gemini 格式
2 parents 40e6405 + 156ad5c commit b0cbf71

File tree

9 files changed

+365
-4
lines changed

9 files changed

+365
-4
lines changed

controller/relay.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
4040
err = relay.EmbeddingHelper(c)
4141
case relayconstant.RelayModeResponses:
4242
err = relay.ResponsesHelper(c)
43+
case relayconstant.RelayModeGemini:
44+
err = relay.GeminiHelper(c)
4345
default:
4446
err = relay.TextHelper(c)
4547
}

middleware/auth.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
package middleware
22

33
import (
4-
"github.com/gin-contrib/sessions"
5-
"github.com/gin-gonic/gin"
64
"net/http"
75
"one-api/common"
86
"one-api/model"
97
"strconv"
108
"strings"
9+
10+
"github.com/gin-contrib/sessions"
11+
"github.com/gin-gonic/gin"
1112
)
1213

1314
func validUserInfo(username string, role int) bool {
@@ -182,6 +183,13 @@ func TokenAuth() func(c *gin.Context) {
182183
c.Request.Header.Set("Authorization", "Bearer "+key)
183184
}
184185
}
186+
// gemini api 从query中获取key
187+
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
188+
skKey := c.Query("key")
189+
if skKey != "" {
190+
c.Request.Header.Set("Authorization", "Bearer "+skKey)
191+
}
192+
}
185193
key := c.Request.Header.Get("Authorization")
186194
parts := make([]string, 0)
187195
key = strings.TrimPrefix(key, "Bearer ")

middleware/distributor.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
162162
}
163163
c.Set("platform", string(constant.TaskPlatformSuno))
164164
c.Set("relay_mode", relayMode)
165+
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
166+
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
167+
relayMode := relayconstant.RelayModeGemini
168+
modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
169+
if modelName != "" {
170+
modelRequest.Model = modelName
171+
}
172+
c.Set("relay_mode", relayMode)
165173
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
166174
err = common.UnmarshalBodyReusable(c, &modelRequest)
167175
}
@@ -244,3 +252,31 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
244252
c.Set("bot_id", channel.Other)
245253
}
246254
}
255+
256+
// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名
257+
// 输入格式: /v1beta/models/gemini-2.0-flash:generateContent
258+
// 输出: gemini-2.0-flash
259+
func extractModelNameFromGeminiPath(path string) string {
260+
// 查找 "/models/" 的位置
261+
modelsPrefix := "/models/"
262+
modelsIndex := strings.Index(path, modelsPrefix)
263+
if modelsIndex == -1 {
264+
return ""
265+
}
266+
267+
// 从 "/models/" 之后开始提取
268+
startIndex := modelsIndex + len(modelsPrefix)
269+
if startIndex >= len(path) {
270+
return ""
271+
}
272+
273+
// 查找 ":" 的位置,模型名在 ":" 之前
274+
colonIndex := strings.Index(path[startIndex:], ":")
275+
if colonIndex == -1 {
276+
// 如果没有找到 ":",返回从 "/models/" 到路径结尾的部分
277+
return path[startIndex:]
278+
}
279+
280+
// 返回模型名部分
281+
return path[startIndex : startIndex+colonIndex]
282+
}

relay/channel/gemini/adaptor.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"one-api/dto"
1111
"one-api/relay/channel"
1212
relaycommon "one-api/relay/common"
13+
"one-api/relay/constant"
1314
"one-api/service"
1415
"one-api/setting/model_setting"
1516
"strings"
@@ -165,6 +166,14 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
165166
}
166167

167168
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
169+
if info.RelayMode == constant.RelayModeGemini {
170+
if info.IsStream {
171+
return GeminiTextGenerationStreamHandler(c, resp, info)
172+
} else {
173+
return GeminiTextGenerationHandler(c, resp, info)
174+
}
175+
}
176+
168177
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
169178
return GeminiImageHandler(c, resp, info)
170179
}
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
package gemini
2+
3+
import (
4+
"encoding/json"
5+
"io"
6+
"net/http"
7+
"one-api/common"
8+
"one-api/dto"
9+
relaycommon "one-api/relay/common"
10+
"one-api/relay/helper"
11+
"one-api/service"
12+
13+
"github.com/gin-gonic/gin"
14+
)
15+
16+
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
17+
// 读取响应体
18+
responseBody, err := io.ReadAll(resp.Body)
19+
if err != nil {
20+
return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
21+
}
22+
err = resp.Body.Close()
23+
if err != nil {
24+
return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
25+
}
26+
27+
if common.DebugEnabled {
28+
println(string(responseBody))
29+
}
30+
31+
// 解析为 Gemini 原生响应格式
32+
var geminiResponse GeminiChatResponse
33+
err = common.DecodeJson(responseBody, &geminiResponse)
34+
if err != nil {
35+
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
36+
}
37+
38+
// 检查是否有候选响应
39+
if len(geminiResponse.Candidates) == 0 {
40+
return nil, &dto.OpenAIErrorWithStatusCode{
41+
Error: dto.OpenAIError{
42+
Message: "No candidates returned",
43+
Type: "server_error",
44+
Param: "",
45+
Code: 500,
46+
},
47+
StatusCode: resp.StatusCode,
48+
}
49+
}
50+
51+
// 计算使用量(基于 UsageMetadata)
52+
usage := dto.Usage{
53+
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
54+
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
55+
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
56+
}
57+
58+
// 直接返回 Gemini 原生格式的 JSON 响应
59+
jsonResponse, err := json.Marshal(geminiResponse)
60+
if err != nil {
61+
return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
62+
}
63+
64+
// 设置响应头并写入响应
65+
c.Writer.Header().Set("Content-Type", "application/json")
66+
c.Writer.WriteHeader(resp.StatusCode)
67+
_, err = c.Writer.Write(jsonResponse)
68+
if err != nil {
69+
return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
70+
}
71+
72+
return &usage, nil
73+
}
74+
75+
func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
76+
var usage = &dto.Usage{}
77+
var imageCount int
78+
79+
helper.SetEventStreamHeaders(c)
80+
81+
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
82+
var geminiResponse GeminiChatResponse
83+
err := common.DecodeJsonStr(data, &geminiResponse)
84+
if err != nil {
85+
common.LogError(c, "error unmarshalling stream response: "+err.Error())
86+
return false
87+
}
88+
89+
// 统计图片数量
90+
for _, candidate := range geminiResponse.Candidates {
91+
for _, part := range candidate.Content.Parts {
92+
if part.InlineData != nil && part.InlineData.MimeType != "" {
93+
imageCount++
94+
}
95+
}
96+
}
97+
98+
// 更新使用量统计
99+
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
100+
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
101+
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
102+
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
103+
}
104+
105+
// 直接发送 GeminiChatResponse 响应
106+
err = helper.ObjectData(c, geminiResponse)
107+
if err != nil {
108+
common.LogError(c, err.Error())
109+
}
110+
111+
return true
112+
})
113+
114+
if imageCount != 0 {
115+
if usage.CompletionTokens == 0 {
116+
usage.CompletionTokens = imageCount * 258
117+
}
118+
}
119+
120+
// 计算最终使用量
121+
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
122+
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
123+
124+
// 结束流式响应
125+
helper.Done(c)
126+
127+
return usage, nil
128+
}

relay/channel/vertex/adaptor.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"one-api/relay/channel/gemini"
1313
"one-api/relay/channel/openai"
1414
relaycommon "one-api/relay/common"
15+
"one-api/relay/constant"
1516
"one-api/setting/model_setting"
1617
"strings"
1718

@@ -201,7 +202,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
201202
case RequestModeClaude:
202203
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
203204
case RequestModeGemini:
204-
err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
205+
if info.RelayMode == constant.RelayModeGemini {
206+
usage, err = gemini.GeminiTextGenerationStreamHandler(c, resp, info)
207+
} else {
208+
err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
209+
}
205210
case RequestModeLlama:
206211
err, usage = openai.OaiStreamHandler(c, resp, info)
207212
}
@@ -210,7 +215,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
210215
case RequestModeClaude:
211216
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
212217
case RequestModeGemini:
213-
err, usage = gemini.GeminiChatHandler(c, resp, info)
218+
if info.RelayMode == constant.RelayModeGemini {
219+
usage, err = gemini.GeminiTextGenerationHandler(c, resp, info)
220+
} else {
221+
err, usage = gemini.GeminiChatHandler(c, resp, info)
222+
}
214223
case RequestModeLlama:
215224
err, usage = openai.OpenaiHandler(c, resp, info)
216225
}

relay/constant/relay_mode.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ const (
4343
RelayModeResponses
4444

4545
RelayModeRealtime
46+
47+
RelayModeGemini
4648
)
4749

4850
func Path2RelayMode(path string) int {
@@ -75,6 +77,8 @@ func Path2RelayMode(path string) int {
7577
relayMode = RelayModeRerank
7678
} else if strings.HasPrefix(path, "/v1/realtime") {
7779
relayMode = RelayModeRealtime
80+
} else if strings.HasPrefix(path, "/v1beta/models") {
81+
relayMode = RelayModeGemini
7882
}
7983
return relayMode
8084
}

0 commit comments

Comments
 (0)