Skip to content

Commit c1492be

Browse files
authored
Merge pull request #1954 from QuantumNous/main
main -> alpha
2 parents aab82f2 + 2938246 commit c1492be

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+3732
-244
lines changed

common/endpoint_defaults.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{
2323
constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"},
2424
constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"},
2525
constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"},
26+
constant.EndpointTypeEmbeddings: {Path: "/v1/embeddings", Method: "POST"},
2627
}
2728

2829
// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在

constant/endpoint_type.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ const (
99
EndpointTypeGemini EndpointType = "gemini"
1010
EndpointTypeJinaRerank EndpointType = "jina-rerank"
1111
EndpointTypeImageGeneration EndpointType = "image-generation"
12+
EndpointTypeEmbeddings EndpointType = "embeddings"
1213
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
1314
//EndpointTypeSuno EndpointType = "suno-proxy"
1415
//EndpointTypeKling EndpointType = "kling"

controller/channel-test.go

Lines changed: 195 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ type testResult struct {
3838
newAPIError *types.NewAPIError
3939
}
4040

41-
func testChannel(channel *model.Channel, testModel string) testResult {
41+
func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
4242
tik := time.Now()
4343
if channel.Type == constant.ChannelTypeMidjourney {
4444
return testResult{
@@ -81,18 +81,26 @@ func testChannel(channel *model.Channel, testModel string) testResult {
8181

8282
requestPath := "/v1/chat/completions"
8383

84-
// 先判断是否为 Embedding 模型
85-
if strings.Contains(strings.ToLower(testModel), "embedding") ||
86-
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
87-
strings.Contains(testModel, "bge-") || // bge 系列模型
88-
strings.Contains(testModel, "embed") ||
89-
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
90-
requestPath = "/v1/embeddings" // 修改请求路径
91-
}
84+
// 如果指定了端点类型,使用指定的端点类型
85+
if endpointType != "" {
86+
if endpointInfo, ok := common.GetDefaultEndpointInfo(constant.EndpointType(endpointType)); ok {
87+
requestPath = endpointInfo.Path
88+
}
89+
} else {
90+
// 如果没有指定端点类型,使用原有的自动检测逻辑
91+
// 先判断是否为 Embedding 模型
92+
if strings.Contains(strings.ToLower(testModel), "embedding") ||
93+
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
94+
strings.Contains(testModel, "bge-") || // bge 系列模型
95+
strings.Contains(testModel, "embed") ||
96+
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
97+
requestPath = "/v1/embeddings" // 修改请求路径
98+
}
9299

93-
// VolcEngine 图像生成模型
94-
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
95-
requestPath = "/v1/images/generations"
100+
// VolcEngine 图像生成模型
101+
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
102+
requestPath = "/v1/images/generations"
103+
}
96104
}
97105

98106
c.Request = &http.Request{
@@ -114,21 +122,6 @@ func testChannel(channel *model.Channel, testModel string) testResult {
114122
}
115123
}
116124

117-
// 重新检查模型类型并更新请求路径
118-
if strings.Contains(strings.ToLower(testModel), "embedding") ||
119-
strings.HasPrefix(testModel, "m3e") ||
120-
strings.Contains(testModel, "bge-") ||
121-
strings.Contains(testModel, "embed") ||
122-
channel.Type == constant.ChannelTypeMokaAI {
123-
requestPath = "/v1/embeddings"
124-
c.Request.URL.Path = requestPath
125-
}
126-
127-
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
128-
requestPath = "/v1/images/generations"
129-
c.Request.URL.Path = requestPath
130-
}
131-
132125
cache, err := model.GetUserCache(1)
133126
if err != nil {
134127
return testResult{
@@ -153,17 +146,54 @@ func testChannel(channel *model.Channel, testModel string) testResult {
153146
newAPIError: newAPIError,
154147
}
155148
}
156-
request := buildTestRequest(testModel)
157149

158-
// Determine relay format based on request path
159-
relayFormat := types.RelayFormatOpenAI
160-
if c.Request.URL.Path == "/v1/embeddings" {
161-
relayFormat = types.RelayFormatEmbedding
162-
}
163-
if c.Request.URL.Path == "/v1/images/generations" {
164-
relayFormat = types.RelayFormatOpenAIImage
150+
// Determine relay format based on endpoint type or request path
151+
var relayFormat types.RelayFormat
152+
if endpointType != "" {
153+
// 根据指定的端点类型设置 relayFormat
154+
switch constant.EndpointType(endpointType) {
155+
case constant.EndpointTypeOpenAI:
156+
relayFormat = types.RelayFormatOpenAI
157+
case constant.EndpointTypeOpenAIResponse:
158+
relayFormat = types.RelayFormatOpenAIResponses
159+
case constant.EndpointTypeAnthropic:
160+
relayFormat = types.RelayFormatClaude
161+
case constant.EndpointTypeGemini:
162+
relayFormat = types.RelayFormatGemini
163+
case constant.EndpointTypeJinaRerank:
164+
relayFormat = types.RelayFormatRerank
165+
case constant.EndpointTypeImageGeneration:
166+
relayFormat = types.RelayFormatOpenAIImage
167+
case constant.EndpointTypeEmbeddings:
168+
relayFormat = types.RelayFormatEmbedding
169+
default:
170+
relayFormat = types.RelayFormatOpenAI
171+
}
172+
} else {
173+
// 根据请求路径自动检测
174+
relayFormat = types.RelayFormatOpenAI
175+
if c.Request.URL.Path == "/v1/embeddings" {
176+
relayFormat = types.RelayFormatEmbedding
177+
}
178+
if c.Request.URL.Path == "/v1/images/generations" {
179+
relayFormat = types.RelayFormatOpenAIImage
180+
}
181+
if c.Request.URL.Path == "/v1/messages" {
182+
relayFormat = types.RelayFormatClaude
183+
}
184+
if strings.Contains(c.Request.URL.Path, "/v1beta/models") {
185+
relayFormat = types.RelayFormatGemini
186+
}
187+
if c.Request.URL.Path == "/v1/rerank" || c.Request.URL.Path == "/rerank" {
188+
relayFormat = types.RelayFormatRerank
189+
}
190+
if c.Request.URL.Path == "/v1/responses" {
191+
relayFormat = types.RelayFormatOpenAIResponses
192+
}
165193
}
166194

195+
request := buildTestRequest(testModel, endpointType)
196+
167197
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
168198

169199
if err != nil {
@@ -186,7 +216,8 @@ func testChannel(channel *model.Channel, testModel string) testResult {
186216
}
187217

188218
testModel = info.UpstreamModelName
189-
request.Model = testModel
219+
// 更新请求中的模型名称
220+
request.SetModelName(testModel)
190221

191222
apiType, _ := common.ChannelType2APIType(channel.Type)
192223
adaptor := relay.GetAdaptor(apiType)
@@ -216,33 +247,62 @@ func testChannel(channel *model.Channel, testModel string) testResult {
216247

217248
var convertedRequest any
218249
// 根据 RelayMode 选择正确的转换函数
219-
if info.RelayMode == relayconstant.RelayModeEmbeddings {
220-
// 创建一个 EmbeddingRequest
221-
embeddingRequest := dto.EmbeddingRequest{
222-
Input: request.Input,
223-
Model: request.Model,
224-
}
225-
// 调用专门用于 Embedding 的转换函数
226-
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
227-
} else if info.RelayMode == relayconstant.RelayModeImagesGenerations {
228-
// 创建一个 ImageRequest
229-
prompt := "cat"
230-
if request.Prompt != nil {
231-
if promptStr, ok := request.Prompt.(string); ok && promptStr != "" {
232-
prompt = promptStr
250+
switch info.RelayMode {
251+
case relayconstant.RelayModeEmbeddings:
252+
// Embedding 请求 - request 已经是正确的类型
253+
if embeddingReq, ok := request.(*dto.EmbeddingRequest); ok {
254+
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, *embeddingReq)
255+
} else {
256+
return testResult{
257+
context: c,
258+
localErr: errors.New("invalid embedding request type"),
259+
newAPIError: types.NewError(errors.New("invalid embedding request type"), types.ErrorCodeConvertRequestFailed),
233260
}
234261
}
235-
imageRequest := dto.ImageRequest{
236-
Prompt: prompt,
237-
Model: request.Model,
238-
N: uint(request.N),
239-
Size: request.Size,
262+
case relayconstant.RelayModeImagesGenerations:
263+
// 图像生成请求 - request 已经是正确的类型
264+
if imageReq, ok := request.(*dto.ImageRequest); ok {
265+
convertedRequest, err = adaptor.ConvertImageRequest(c, info, *imageReq)
266+
} else {
267+
return testResult{
268+
context: c,
269+
localErr: errors.New("invalid image request type"),
270+
newAPIError: types.NewError(errors.New("invalid image request type"), types.ErrorCodeConvertRequestFailed),
271+
}
272+
}
273+
case relayconstant.RelayModeRerank:
274+
// Rerank 请求 - request 已经是正确的类型
275+
if rerankReq, ok := request.(*dto.RerankRequest); ok {
276+
convertedRequest, err = adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankReq)
277+
} else {
278+
return testResult{
279+
context: c,
280+
localErr: errors.New("invalid rerank request type"),
281+
newAPIError: types.NewError(errors.New("invalid rerank request type"), types.ErrorCodeConvertRequestFailed),
282+
}
283+
}
284+
case relayconstant.RelayModeResponses:
285+
// Response 请求 - request 已经是正确的类型
286+
if responseReq, ok := request.(*dto.OpenAIResponsesRequest); ok {
287+
convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *responseReq)
288+
} else {
289+
return testResult{
290+
context: c,
291+
localErr: errors.New("invalid response request type"),
292+
newAPIError: types.NewError(errors.New("invalid response request type"), types.ErrorCodeConvertRequestFailed),
293+
}
294+
}
295+
default:
296+
// Chat/Completion 等其他请求类型
297+
if generalReq, ok := request.(*dto.GeneralOpenAIRequest); ok {
298+
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, generalReq)
299+
} else {
300+
return testResult{
301+
context: c,
302+
localErr: errors.New("invalid general request type"),
303+
newAPIError: types.NewError(errors.New("invalid general request type"), types.ErrorCodeConvertRequestFailed),
304+
}
240305
}
241-
// 调用专门用于图像生成的转换函数
242-
convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest)
243-
} else {
244-
// 对其他所有请求类型(如 Chat),保持原有逻辑
245-
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
246306
}
247307

248308
if err != nil {
@@ -345,22 +405,82 @@ func testChannel(channel *model.Channel, testModel string) testResult {
345405
}
346406
}
347407

348-
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
349-
testRequest := &dto.GeneralOpenAIRequest{
350-
Model: "", // this will be set later
351-
Stream: false,
408+
func buildTestRequest(model string, endpointType string) dto.Request {
409+
// 根据端点类型构建不同的测试请求
410+
if endpointType != "" {
411+
switch constant.EndpointType(endpointType) {
412+
case constant.EndpointTypeEmbeddings:
413+
// 返回 EmbeddingRequest
414+
return &dto.EmbeddingRequest{
415+
Model: model,
416+
Input: []any{"hello world"},
417+
}
418+
case constant.EndpointTypeImageGeneration:
419+
// 返回 ImageRequest
420+
return &dto.ImageRequest{
421+
Model: model,
422+
Prompt: "a cute cat",
423+
N: 1,
424+
Size: "1024x1024",
425+
}
426+
case constant.EndpointTypeJinaRerank:
427+
// 返回 RerankRequest
428+
return &dto.RerankRequest{
429+
Model: model,
430+
Query: "What is Deep Learning?",
431+
Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."},
432+
TopN: 2,
433+
}
434+
case constant.EndpointTypeOpenAIResponse:
435+
// 返回 OpenAIResponsesRequest
436+
return &dto.OpenAIResponsesRequest{
437+
Model: model,
438+
Input: json.RawMessage("\"hi\""),
439+
}
440+
case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI:
441+
// 返回 GeneralOpenAIRequest
442+
maxTokens := uint(10)
443+
if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
444+
maxTokens = 3000
445+
}
446+
return &dto.GeneralOpenAIRequest{
447+
Model: model,
448+
Stream: false,
449+
Messages: []dto.Message{
450+
{
451+
Role: "user",
452+
Content: "hi",
453+
},
454+
},
455+
MaxTokens: maxTokens,
456+
}
457+
}
352458
}
353459

460+
// 自动检测逻辑(保持原有行为)
354461
// 先判断是否为 Embedding 模型
355-
if strings.Contains(strings.ToLower(model), "embedding") || // 其他 embedding 模型
356-
strings.HasPrefix(model, "m3e") || // m3e 系列模型
462+
if strings.Contains(strings.ToLower(model), "embedding") ||
463+
strings.HasPrefix(model, "m3e") ||
357464
strings.Contains(model, "bge-") {
358-
testRequest.Model = model
359-
// Embedding 请求
360-
testRequest.Input = []any{"hello world"} // 修改为any,因为dto/openai_request.go 的ParseInput方法无法处理[]string类型
361-
return testRequest
465+
// 返回 EmbeddingRequest
466+
return &dto.EmbeddingRequest{
467+
Model: model,
468+
Input: []any{"hello world"},
469+
}
362470
}
363-
// 并非Embedding 模型
471+
472+
// Chat/Completion 请求 - 返回 GeneralOpenAIRequest
473+
testRequest := &dto.GeneralOpenAIRequest{
474+
Model: model,
475+
Stream: false,
476+
Messages: []dto.Message{
477+
{
478+
Role: "user",
479+
Content: "hi",
480+
},
481+
},
482+
}
483+
364484
if strings.HasPrefix(model, "o") {
365485
testRequest.MaxCompletionTokens = 10
366486
} else if strings.Contains(model, "thinking") {
@@ -373,12 +493,6 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
373493
testRequest.MaxTokens = 10
374494
}
375495

376-
testMessage := dto.Message{
377-
Role: "user",
378-
Content: "hi",
379-
}
380-
testRequest.Model = model
381-
testRequest.Messages = append(testRequest.Messages, testMessage)
382496
return testRequest
383497
}
384498

@@ -402,8 +516,9 @@ func TestChannel(c *gin.Context) {
402516
// }
403517
//}()
404518
testModel := c.Query("model")
519+
endpointType := c.Query("endpoint_type")
405520
tik := time.Now()
406-
result := testChannel(channel, testModel)
521+
result := testChannel(channel, testModel, endpointType)
407522
if result.localErr != nil {
408523
c.JSON(http.StatusOK, gin.H{
409524
"success": false,
@@ -429,7 +544,6 @@ func TestChannel(c *gin.Context) {
429544
"message": "",
430545
"time": consumedTime,
431546
})
432-
return
433547
}
434548

435549
var testAllChannelsLock sync.Mutex
@@ -463,7 +577,7 @@ func testAllChannels(notify bool) error {
463577
for _, channel := range channels {
464578
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
465579
tik := time.Now()
466-
result := testChannel(channel, "")
580+
result := testChannel(channel, "", "")
467581
tok := time.Now()
468582
milliseconds := tok.Sub(tik).Milliseconds()
469583

@@ -477,7 +591,7 @@ func testAllChannels(notify bool) error {
477591
// 当错误检查通过,才检查响应时间
478592
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
479593
if milliseconds > disableThreshold {
480-
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
594+
err := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
481595
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
482596
shouldBanChannel = true
483597
}
@@ -514,7 +628,6 @@ func TestAllChannels(c *gin.Context) {
514628
"success": true,
515629
"message": "",
516630
})
517-
return
518631
}
519632

520633
var autoTestChannelsOnce sync.Once

0 commit comments

Comments
 (0)