@@ -38,7 +38,7 @@ type testResult struct {
3838 newAPIError * types.NewAPIError
3939}
4040
41- func testChannel (channel * model.Channel , testModel string ) testResult {
41+ func testChannel (channel * model.Channel , testModel string , endpointType string ) testResult {
4242 tik := time .Now ()
4343 if channel .Type == constant .ChannelTypeMidjourney {
4444 return testResult {
@@ -81,18 +81,26 @@ func testChannel(channel *model.Channel, testModel string) testResult {
8181
8282 requestPath := "/v1/chat/completions"
8383
84- // 先判断是否为 Embedding 模型
85- if strings .Contains (strings .ToLower (testModel ), "embedding" ) ||
86- strings .HasPrefix (testModel , "m3e" ) || // m3e 系列模型
87- strings .Contains (testModel , "bge-" ) || // bge 系列模型
88- strings .Contains (testModel , "embed" ) ||
89- channel .Type == constant .ChannelTypeMokaAI { // 其他 embedding 模型
90- requestPath = "/v1/embeddings" // 修改请求路径
91- }
84+ // 如果指定了端点类型,使用指定的端点类型
85+ if endpointType != "" {
86+ if endpointInfo , ok := common .GetDefaultEndpointInfo (constant .EndpointType (endpointType )); ok {
87+ requestPath = endpointInfo .Path
88+ }
89+ } else {
90+ // 如果没有指定端点类型,使用原有的自动检测逻辑
91+ // 先判断是否为 Embedding 模型
92+ if strings .Contains (strings .ToLower (testModel ), "embedding" ) ||
93+ strings .HasPrefix (testModel , "m3e" ) || // m3e 系列模型
94+ strings .Contains (testModel , "bge-" ) || // bge 系列模型
95+ strings .Contains (testModel , "embed" ) ||
96+ channel .Type == constant .ChannelTypeMokaAI { // 其他 embedding 模型
97+ requestPath = "/v1/embeddings" // 修改请求路径
98+ }
9299
93- // VolcEngine 图像生成模型
94- if channel .Type == constant .ChannelTypeVolcEngine && strings .Contains (testModel , "seedream" ) {
95- requestPath = "/v1/images/generations"
100+ // VolcEngine 图像生成模型
101+ if channel .Type == constant .ChannelTypeVolcEngine && strings .Contains (testModel , "seedream" ) {
102+ requestPath = "/v1/images/generations"
103+ }
96104 }
97105
98106 c .Request = & http.Request {
@@ -114,21 +122,6 @@ func testChannel(channel *model.Channel, testModel string) testResult {
114122 }
115123 }
116124
117- // 重新检查模型类型并更新请求路径
118- if strings .Contains (strings .ToLower (testModel ), "embedding" ) ||
119- strings .HasPrefix (testModel , "m3e" ) ||
120- strings .Contains (testModel , "bge-" ) ||
121- strings .Contains (testModel , "embed" ) ||
122- channel .Type == constant .ChannelTypeMokaAI {
123- requestPath = "/v1/embeddings"
124- c .Request .URL .Path = requestPath
125- }
126-
127- if channel .Type == constant .ChannelTypeVolcEngine && strings .Contains (testModel , "seedream" ) {
128- requestPath = "/v1/images/generations"
129- c .Request .URL .Path = requestPath
130- }
131-
132125 cache , err := model .GetUserCache (1 )
133126 if err != nil {
134127 return testResult {
@@ -153,17 +146,54 @@ func testChannel(channel *model.Channel, testModel string) testResult {
153146 newAPIError : newAPIError ,
154147 }
155148 }
156- request := buildTestRequest (testModel )
157149
158- // Determine relay format based on request path
159- relayFormat := types .RelayFormatOpenAI
160- if c .Request .URL .Path == "/v1/embeddings" {
161- relayFormat = types .RelayFormatEmbedding
162- }
163- if c .Request .URL .Path == "/v1/images/generations" {
164- relayFormat = types .RelayFormatOpenAIImage
150+ // Determine relay format based on endpoint type or request path
151+ var relayFormat types.RelayFormat
152+ if endpointType != "" {
153+ // 根据指定的端点类型设置 relayFormat
154+ switch constant .EndpointType (endpointType ) {
155+ case constant .EndpointTypeOpenAI :
156+ relayFormat = types .RelayFormatOpenAI
157+ case constant .EndpointTypeOpenAIResponse :
158+ relayFormat = types .RelayFormatOpenAIResponses
159+ case constant .EndpointTypeAnthropic :
160+ relayFormat = types .RelayFormatClaude
161+ case constant .EndpointTypeGemini :
162+ relayFormat = types .RelayFormatGemini
163+ case constant .EndpointTypeJinaRerank :
164+ relayFormat = types .RelayFormatRerank
165+ case constant .EndpointTypeImageGeneration :
166+ relayFormat = types .RelayFormatOpenAIImage
167+ case constant .EndpointTypeEmbeddings :
168+ relayFormat = types .RelayFormatEmbedding
169+ default :
170+ relayFormat = types .RelayFormatOpenAI
171+ }
172+ } else {
173+ // 根据请求路径自动检测
174+ relayFormat = types .RelayFormatOpenAI
175+ if c .Request .URL .Path == "/v1/embeddings" {
176+ relayFormat = types .RelayFormatEmbedding
177+ }
178+ if c .Request .URL .Path == "/v1/images/generations" {
179+ relayFormat = types .RelayFormatOpenAIImage
180+ }
181+ if c .Request .URL .Path == "/v1/messages" {
182+ relayFormat = types .RelayFormatClaude
183+ }
184+ if strings .Contains (c .Request .URL .Path , "/v1beta/models" ) {
185+ relayFormat = types .RelayFormatGemini
186+ }
187+ if c .Request .URL .Path == "/v1/rerank" || c .Request .URL .Path == "/rerank" {
188+ relayFormat = types .RelayFormatRerank
189+ }
190+ if c .Request .URL .Path == "/v1/responses" {
191+ relayFormat = types .RelayFormatOpenAIResponses
192+ }
165193 }
166194
195+ request := buildTestRequest (testModel , endpointType )
196+
167197 info , err := relaycommon .GenRelayInfo (c , relayFormat , request , nil )
168198
169199 if err != nil {
@@ -186,7 +216,8 @@ func testChannel(channel *model.Channel, testModel string) testResult {
186216 }
187217
188218 testModel = info .UpstreamModelName
189- request .Model = testModel
219+ // 更新请求中的模型名称
220+ request .SetModelName (testModel )
190221
191222 apiType , _ := common .ChannelType2APIType (channel .Type )
192223 adaptor := relay .GetAdaptor (apiType )
@@ -216,33 +247,62 @@ func testChannel(channel *model.Channel, testModel string) testResult {
216247
217248 var convertedRequest any
218249 // 根据 RelayMode 选择正确的转换函数
219- if info .RelayMode == relayconstant .RelayModeEmbeddings {
220- // 创建一个 EmbeddingRequest
221- embeddingRequest := dto.EmbeddingRequest {
222- Input : request .Input ,
223- Model : request .Model ,
224- }
225- // 调用专门用于 Embedding 的转换函数
226- convertedRequest , err = adaptor .ConvertEmbeddingRequest (c , info , embeddingRequest )
227- } else if info .RelayMode == relayconstant .RelayModeImagesGenerations {
228- // 创建一个 ImageRequest
229- prompt := "cat"
230- if request .Prompt != nil {
231- if promptStr , ok := request .Prompt .(string ); ok && promptStr != "" {
232- prompt = promptStr
250+ switch info .RelayMode {
251+ case relayconstant .RelayModeEmbeddings :
252+ // Embedding 请求 - request 已经是正确的类型
253+ if embeddingReq , ok := request .(* dto.EmbeddingRequest ); ok {
254+ convertedRequest , err = adaptor .ConvertEmbeddingRequest (c , info , * embeddingReq )
255+ } else {
256+ return testResult {
257+ context : c ,
258+ localErr : errors .New ("invalid embedding request type" ),
259+ newAPIError : types .NewError (errors .New ("invalid embedding request type" ), types .ErrorCodeConvertRequestFailed ),
233260 }
234261 }
235- imageRequest := dto.ImageRequest {
236- Prompt : prompt ,
237- Model : request .Model ,
238- N : uint (request .N ),
239- Size : request .Size ,
262+ case relayconstant .RelayModeImagesGenerations :
263+ // 图像生成请求 - request 已经是正确的类型
264+ if imageReq , ok := request .(* dto.ImageRequest ); ok {
265+ convertedRequest , err = adaptor .ConvertImageRequest (c , info , * imageReq )
266+ } else {
267+ return testResult {
268+ context : c ,
269+ localErr : errors .New ("invalid image request type" ),
270+ newAPIError : types .NewError (errors .New ("invalid image request type" ), types .ErrorCodeConvertRequestFailed ),
271+ }
272+ }
273+ case relayconstant .RelayModeRerank :
274+ // Rerank 请求 - request 已经是正确的类型
275+ if rerankReq , ok := request .(* dto.RerankRequest ); ok {
276+ convertedRequest , err = adaptor .ConvertRerankRequest (c , info .RelayMode , * rerankReq )
277+ } else {
278+ return testResult {
279+ context : c ,
280+ localErr : errors .New ("invalid rerank request type" ),
281+ newAPIError : types .NewError (errors .New ("invalid rerank request type" ), types .ErrorCodeConvertRequestFailed ),
282+ }
283+ }
284+ case relayconstant .RelayModeResponses :
285+ // Response 请求 - request 已经是正确的类型
286+ if responseReq , ok := request .(* dto.OpenAIResponsesRequest ); ok {
287+ convertedRequest , err = adaptor .ConvertOpenAIResponsesRequest (c , info , * responseReq )
288+ } else {
289+ return testResult {
290+ context : c ,
291+ localErr : errors .New ("invalid response request type" ),
292+ newAPIError : types .NewError (errors .New ("invalid response request type" ), types .ErrorCodeConvertRequestFailed ),
293+ }
294+ }
295+ default :
296+ // Chat/Completion 等其他请求类型
297+ if generalReq , ok := request .(* dto.GeneralOpenAIRequest ); ok {
298+ convertedRequest , err = adaptor .ConvertOpenAIRequest (c , info , generalReq )
299+ } else {
300+ return testResult {
301+ context : c ,
302+ localErr : errors .New ("invalid general request type" ),
303+ newAPIError : types .NewError (errors .New ("invalid general request type" ), types .ErrorCodeConvertRequestFailed ),
304+ }
240305 }
241- // 调用专门用于图像生成的转换函数
242- convertedRequest , err = adaptor .ConvertImageRequest (c , info , imageRequest )
243- } else {
244- // 对其他所有请求类型(如 Chat),保持原有逻辑
245- convertedRequest , err = adaptor .ConvertOpenAIRequest (c , info , request )
246306 }
247307
248308 if err != nil {
@@ -345,22 +405,82 @@ func testChannel(channel *model.Channel, testModel string) testResult {
345405 }
346406}
347407
348- func buildTestRequest (model string ) * dto.GeneralOpenAIRequest {
349- testRequest := & dto.GeneralOpenAIRequest {
350- Model : "" , // this will be set later
351- Stream : false ,
408+ func buildTestRequest (model string , endpointType string ) dto.Request {
409+ // 根据端点类型构建不同的测试请求
410+ if endpointType != "" {
411+ switch constant .EndpointType (endpointType ) {
412+ case constant .EndpointTypeEmbeddings :
413+ // 返回 EmbeddingRequest
414+ return & dto.EmbeddingRequest {
415+ Model : model ,
416+ Input : []any {"hello world" },
417+ }
418+ case constant .EndpointTypeImageGeneration :
419+ // 返回 ImageRequest
420+ return & dto.ImageRequest {
421+ Model : model ,
422+ Prompt : "a cute cat" ,
423+ N : 1 ,
424+ Size : "1024x1024" ,
425+ }
426+ case constant .EndpointTypeJinaRerank :
427+ // 返回 RerankRequest
428+ return & dto.RerankRequest {
429+ Model : model ,
430+ Query : "What is Deep Learning?" ,
431+ Documents : []any {"Deep Learning is a subset of machine learning." , "Machine learning is a field of artificial intelligence." },
432+ TopN : 2 ,
433+ }
434+ case constant .EndpointTypeOpenAIResponse :
435+ // 返回 OpenAIResponsesRequest
436+ return & dto.OpenAIResponsesRequest {
437+ Model : model ,
438+ Input : json .RawMessage ("\" hi\" " ),
439+ }
440+ case constant .EndpointTypeAnthropic , constant .EndpointTypeGemini , constant .EndpointTypeOpenAI :
441+ // 返回 GeneralOpenAIRequest
442+ maxTokens := uint (10 )
443+ if constant .EndpointType (endpointType ) == constant .EndpointTypeGemini {
444+ maxTokens = 3000
445+ }
446+ return & dto.GeneralOpenAIRequest {
447+ Model : model ,
448+ Stream : false ,
449+ Messages : []dto.Message {
450+ {
451+ Role : "user" ,
452+ Content : "hi" ,
453+ },
454+ },
455+ MaxTokens : maxTokens ,
456+ }
457+ }
352458 }
353459
460+ // 自动检测逻辑(保持原有行为)
354461 // 先判断是否为 Embedding 模型
355- if strings .Contains (strings .ToLower (model ), "embedding" ) || // 其他 embedding 模型
356- strings .HasPrefix (model , "m3e" ) || // m3e 系列模型
462+ if strings .Contains (strings .ToLower (model ), "embedding" ) ||
463+ strings .HasPrefix (model , "m3e" ) ||
357464 strings .Contains (model , "bge-" ) {
358- testRequest .Model = model
359- // Embedding 请求
360- testRequest .Input = []any {"hello world" } // 修改为any,因为dto/openai_request.go 的ParseInput方法无法处理[]string类型
361- return testRequest
465+ // 返回 EmbeddingRequest
466+ return & dto.EmbeddingRequest {
467+ Model : model ,
468+ Input : []any {"hello world" },
469+ }
362470 }
363- // 并非Embedding 模型
471+
472+ // Chat/Completion 请求 - 返回 GeneralOpenAIRequest
473+ testRequest := & dto.GeneralOpenAIRequest {
474+ Model : model ,
475+ Stream : false ,
476+ Messages : []dto.Message {
477+ {
478+ Role : "user" ,
479+ Content : "hi" ,
480+ },
481+ },
482+ }
483+
364484 if strings .HasPrefix (model , "o" ) {
365485 testRequest .MaxCompletionTokens = 10
366486 } else if strings .Contains (model , "thinking" ) {
@@ -373,12 +493,6 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
373493 testRequest .MaxTokens = 10
374494 }
375495
376- testMessage := dto.Message {
377- Role : "user" ,
378- Content : "hi" ,
379- }
380- testRequest .Model = model
381- testRequest .Messages = append (testRequest .Messages , testMessage )
382496 return testRequest
383497}
384498
@@ -402,8 +516,9 @@ func TestChannel(c *gin.Context) {
402516 // }
403517 //}()
404518 testModel := c .Query ("model" )
519+ endpointType := c .Query ("endpoint_type" )
405520 tik := time .Now ()
406- result := testChannel (channel , testModel )
521+ result := testChannel (channel , testModel , endpointType )
407522 if result .localErr != nil {
408523 c .JSON (http .StatusOK , gin.H {
409524 "success" : false ,
@@ -429,7 +544,6 @@ func TestChannel(c *gin.Context) {
429544 "message" : "" ,
430545 "time" : consumedTime ,
431546 })
432- return
433547}
434548
435549var testAllChannelsLock sync.Mutex
@@ -463,7 +577,7 @@ func testAllChannels(notify bool) error {
463577 for _ , channel := range channels {
464578 isChannelEnabled := channel .Status == common .ChannelStatusEnabled
465579 tik := time .Now ()
466- result := testChannel (channel , "" )
580+ result := testChannel (channel , "" , "" )
467581 tok := time .Now ()
468582 milliseconds := tok .Sub (tik ).Milliseconds ()
469583
@@ -477,7 +591,7 @@ func testAllChannels(notify bool) error {
477591 // 当错误检查通过,才检查响应时间
478592 if common .AutomaticDisableChannelEnabled && ! shouldBanChannel {
479593 if milliseconds > disableThreshold {
480- err := errors . New ( fmt .Sprintf ("响应时间 %.2fs 超过阈值 %.2fs" , float64 (milliseconds )/ 1000.0 , float64 (disableThreshold )/ 1000.0 ) )
594+ err := fmt .Errorf ("响应时间 %.2fs 超过阈值 %.2fs" , float64 (milliseconds )/ 1000.0 , float64 (disableThreshold )/ 1000.0 )
481595 newAPIError = types .NewOpenAIError (err , types .ErrorCodeChannelResponseTimeExceeded , http .StatusRequestTimeout )
482596 shouldBanChannel = true
483597 }
@@ -514,7 +628,6 @@ func TestAllChannels(c *gin.Context) {
514628 "success" : true ,
515629 "message" : "" ,
516630 })
517- return
518631}
519632
520633var autoTestChannelsOnce sync.Once
0 commit comments