Skip to content

Commit 5018945

Browse files
authored
Merge pull request #1171 from RedwindA/feat/ali-rerank
feat: ali rerank
2 parents 7e9bd35 + 49e77fb commit 5018945

File tree

4 files changed

+116
-1
lines changed

4 files changed

+116
-1
lines changed

relay/channel/ali/adaptor.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
3131
switch info.RelayMode {
3232
case constant.RelayModeEmbeddings:
3333
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
34+
case constant.RelayModeRerank:
35+
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
3436
case constant.RelayModeImagesGenerations:
3537
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
3638
case constant.RelayModeCompletions:
@@ -76,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
7678
}
7779

7880
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
79-
return nil, errors.New("not implemented")
81+
return ConvertRerankRequest(request), nil
8082
}
8183

8284
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
@@ -103,6 +105,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
103105
err, usage = aliImageHandler(c, resp, info)
104106
case constant.RelayModeEmbeddings:
105107
err, usage = aliEmbeddingHandler(c, resp)
108+
case constant.RelayModeRerank:
109+
err, usage = RerankHandler(c, resp, info)
106110
default:
107111
if info.IsStream {
108112
err, usage = openai.OaiStreamHandler(c, resp, info)

relay/channel/ali/constants.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ var ModelList = []string{
88
"qwq-32b",
99
"qwen3-235b-a22b",
1010
"text-embedding-v1",
11+
"gte-rerank-v2",
1112
}
1213

1314
var ChannelName = "ali"

relay/channel/ali/dto.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package ali
22

3+
import "one-api/dto"
4+
35
type AliMessage struct {
46
Content string `json:"content"`
57
Role string `json:"role"`
@@ -97,3 +99,28 @@ type AliImageRequest struct {
9799
} `json:"parameters,omitempty"`
98100
ResponseFormat string `json:"response_format,omitempty"`
99101
}
102+
103+
// AliRerankParameters carries the optional settings of an Ali rerank request.
//
// Both fields are pointers tagged with omitempty, so a nil field is left out of
// the encoded JSON entirely while an explicitly set zero value (0 / false) is
// still sent.
type AliRerankParameters struct {
	// TopN is encoded as "top_n"; presumably the maximum number of ranked
	// results to return — confirm against the upstream API.
	TopN *int `json:"top_n,omitempty"`
	// ReturnDocuments is encoded as "return_documents"; presumably asks the
	// service to echo the document content in each result — confirm upstream.
	ReturnDocuments *bool `json:"return_documents,omitempty"`
}
107+
108+
// AliRerankInput is the "input" object of an Ali rerank request: the query and
// the candidate documents to rank against it.
type AliRerankInput struct {
	// Query is encoded as "query".
	Query string `json:"query"`
	// Documents is encoded as "documents". Elements are untyped, so whatever
	// values the caller supplies are serialized as-is.
	Documents []any `json:"documents"`
}
112+
113+
type AliRerankRequest struct {
114+
Model string `json:"model"`
115+
Input AliRerankInput `json:"input"`
116+
Parameters AliRerankParameters `json:"parameters,omitempty"`
117+
}
118+
119+
type AliRerankResponse struct {
120+
Output struct {
121+
Results []dto.RerankResponseResult `json:"results"`
122+
} `json:"output"`
123+
Usage AliUsage `json:"usage"`
124+
RequestId string `json:"request_id"`
125+
AliError
126+
}

relay/channel/ali/rerank.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package ali
2+
3+
import (
4+
"encoding/json"
5+
"io"
6+
"net/http"
7+
"one-api/dto"
8+
relaycommon "one-api/relay/common"
9+
"one-api/service"
10+
11+
"github.com/gin-gonic/gin"
12+
)
13+
14+
func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
15+
returnDocuments := request.ReturnDocuments
16+
if returnDocuments == nil {
17+
t := true
18+
returnDocuments = &t
19+
}
20+
return &AliRerankRequest{
21+
Model: request.Model,
22+
Input: AliRerankInput{
23+
Query: request.Query,
24+
Documents: request.Documents,
25+
},
26+
Parameters: AliRerankParameters{
27+
TopN: &request.TopN,
28+
ReturnDocuments: returnDocuments,
29+
},
30+
}
31+
}
32+
33+
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
34+
responseBody, err := io.ReadAll(resp.Body)
35+
if err != nil {
36+
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
37+
}
38+
err = resp.Body.Close()
39+
if err != nil {
40+
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
41+
}
42+
43+
var aliResponse AliRerankResponse
44+
err = json.Unmarshal(responseBody, &aliResponse)
45+
if err != nil {
46+
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
47+
}
48+
49+
if aliResponse.Code != "" {
50+
return &dto.OpenAIErrorWithStatusCode{
51+
Error: dto.OpenAIError{
52+
Message: aliResponse.Message,
53+
Type: aliResponse.Code,
54+
Param: aliResponse.RequestId,
55+
Code: aliResponse.Code,
56+
},
57+
StatusCode: resp.StatusCode,
58+
}, nil
59+
}
60+
61+
usage := dto.Usage{
62+
PromptTokens: aliResponse.Usage.TotalTokens,
63+
CompletionTokens: 0,
64+
TotalTokens: aliResponse.Usage.TotalTokens,
65+
}
66+
rerankResponse := dto.RerankResponse{
67+
Results: aliResponse.Output.Results,
68+
Usage: usage,
69+
}
70+
71+
jsonResponse, err := json.Marshal(rerankResponse)
72+
if err != nil {
73+
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
74+
}
75+
c.Writer.Header().Set("Content-Type", "application/json")
76+
c.Writer.WriteHeader(resp.StatusCode)
77+
_, err = c.Writer.Write(jsonResponse)
78+
if err != nil {
79+
return service.OpenAIErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
80+
}
81+
82+
return nil, &usage
83+
}

0 commit comments

Comments
 (0)