Skip to content

Commit ce6fb95

Browse files
committed
refactor(relay): update channel retrieval to use RelayInfo structure
1 parent 2ac6a5b commit ce6fb95

File tree

4 files changed

+29
-21
lines changed

4 files changed

+29
-21
lines changed

controller/playground.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/QuantumNous/new-api/constant"
1010
"github.com/QuantumNous/new-api/middleware"
1111
"github.com/QuantumNous/new-api/model"
12+
relaycommon "github.com/QuantumNous/new-api/relay/common"
1213
"github.com/QuantumNous/new-api/types"
1314

1415
"github.com/gin-gonic/gin"
@@ -31,8 +32,11 @@ func Playground(c *gin.Context) {
3132
return
3233
}
3334

34-
group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
35-
modelName := c.GetString("original_model")
35+
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatOpenAI, nil, nil)
36+
if err != nil {
37+
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
38+
return
39+
}
3640

3741
userId := c.GetInt("id")
3842

@@ -46,11 +50,11 @@ func Playground(c *gin.Context) {
4650

4751
tempToken := &model.Token{
4852
UserId: userId,
49-
Name: fmt.Sprintf("playground-%s", group),
50-
Group: group,
53+
Name: fmt.Sprintf("playground-%s", relayInfo.UsingGroup),
54+
Group: relayInfo.UsingGroup,
5155
}
5256
_ = middleware.SetupContextForToken(c, tempToken)
53-
_, newAPIError = getChannel(c, group, modelName, 0)
57+
_, newAPIError = getChannel(c, relayInfo, 0)
5458
if newAPIError != nil {
5559
return
5660
}

controller/relay.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewA
6464
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
6565

6666
requestId := c.GetString(common.RequestIdKey)
67-
group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
68-
originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
67+
//group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
68+
//originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
6969

7070
var (
7171
newAPIError *types.NewAPIError
@@ -158,7 +158,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
158158
}()
159159

160160
for i := 0; i <= common.RetryTimes; i++ {
161-
channel, err := getChannel(c, group, originalModel, i)
161+
channel, err := getChannel(c, relayInfo, i)
162162
if err != nil {
163163
logger.LogError(c, err.Error())
164164
newAPIError = err
@@ -211,7 +211,7 @@ func addUsedChannel(c *gin.Context, channelId int) {
211211
c.Set("use_channel", useChannel)
212212
}
213213

214-
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
214+
func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryCount int) (*model.Channel, *types.NewAPIError) {
215215
if retryCount == 0 {
216216
autoBan := c.GetBool("auto_ban")
217217
autoBanInt := 1
@@ -225,14 +225,18 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
225225
AutoBan: &autoBanInt,
226226
}, nil
227227
}
228-
channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
228+
channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, info.TokenGroup, info.OriginModelName, retryCount)
229+
230+
info.PriceData.GroupRatioInfo = helper.HandleGroupRatio(c, info)
231+
229232
if err != nil {
230-
return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
233+
return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, info.OriginModelName, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
231234
}
232235
if channel == nil {
233-
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
236+
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(retry)", selectGroup, info.OriginModelName), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
234237
}
235-
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
238+
239+
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, info.OriginModelName)
236240
if newAPIError != nil {
237241
return nil, newAPIError
238242
}
@@ -392,8 +396,6 @@ func RelayNotFound(c *gin.Context) {
392396
func RelayTask(c *gin.Context) {
393397
retryTimes := common.RetryTimes
394398
channelId := c.GetInt("channel_id")
395-
group := c.GetString("group")
396-
originalModel := c.GetString("original_model")
397399
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
398400
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
399401
if err != nil {
@@ -404,7 +406,7 @@ func RelayTask(c *gin.Context) {
404406
retryTimes = 0
405407
}
406408
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
407-
channel, newAPIError := getChannel(c, group, originalModel, i)
409+
channel, newAPIError := getChannel(c, relayInfo, i)
408410
if newAPIError != nil {
409411
logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
410412
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)

relay/common/relay_info.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ type TokenCountMeta struct {
8181
type RelayInfo struct {
8282
TokenId int
8383
TokenKey string
84+
TokenGroup string
8485
UserId int
8586
UsingGroup string // 使用的分组
8687
UserGroup string // 用户所在分组
@@ -400,6 +401,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
400401
TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
401402
TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
402403
TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
404+
TokenGroup: common.GetContextKeyString(c, constant.ContextKeyTokenGroup),
403405

404406
isFirstResponse: true,
405407
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),

service/channel_select.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ import (
1212
)
1313

1414
// CacheGetRandomSatisfiedChannel tries to get a random channel that satisfies the requirements.
15-
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName string, retry int) (*model.Channel, string, error) {
15+
func CacheGetRandomSatisfiedChannel(c *gin.Context, tokenGroup string, modelName string, retry int) (*model.Channel, string, error) {
1616
var channel *model.Channel
1717
var err error
18-
selectGroup := group
18+
selectGroup := tokenGroup
1919
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
20-
if group == "auto" {
20+
if tokenGroup == "auto" {
2121
if len(setting.GetAutoGroups()) == 0 {
2222
return nil, selectGroup, errors.New("auto groups is not enabled")
2323
}
@@ -49,9 +49,9 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName stri
4949
}
5050
}
5151
} else {
52-
channel, err = model.GetRandomSatisfiedChannel(group, modelName, retry)
52+
channel, err = model.GetRandomSatisfiedChannel(tokenGroup, modelName, retry)
5353
if err != nil {
54-
return nil, group, err
54+
return nil, tokenGroup, err
5555
}
5656
}
5757
return channel, selectGroup, nil

0 commit comments

Comments
 (0)