Skip to content

Commit 84745d5

Browse files
committed
feat: Add ContextKeyLocalCountTokens and update ResponseText2Usage to use context in multiple channels
1 parent ef06472 commit 84745d5

File tree

17 files changed: +96 additions, -127 deletions (lines changed)

constant/context_key.go

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,5 +46,7 @@ const (
4646
ContextKeyUsingGroup ContextKey = "group"
4747
ContextKeyUserName ContextKey = "username"
4848

49+
ContextKeyLocalCountTokens ContextKey = "local_count_tokens"
50+
4951
ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
5052
)

relay/channel/claude/relay-claude.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -673,7 +673,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
673673
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
674674

675675
if requestMode == RequestModeCompletion {
676-
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
676+
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
677677
} else {
678678
if claudeInfo.Usage.PromptTokens == 0 {
679679
//上游出错
@@ -682,7 +682,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
682682
if common.DebugEnabled {
683683
common.SysLog("claude response usage is not complete, maybe upstream error")
684684
}
685-
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
685+
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
686686
}
687687
}
688688

relay/channel/cloudflare/relay_cloudflare.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
7474
if err := scanner.Err(); err != nil {
7575
logger.LogError(c, "error_scanning_stream_response: "+err.Error())
7676
}
77-
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
77+
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
7878
if info.ShouldIncludeUsage {
7979
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
8080
err := helper.ObjectData(c, response)
@@ -105,7 +105,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response)
105105
for _, choice := range response.Choices {
106106
responseText += choice.Message.StringContent()
107107
}
108-
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
108+
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
109109
response.Usage = *usage
110110
response.Id = helper.GetResponseID(c)
111111
jsonResponse, err := json.Marshal(response)

relay/channel/cohere/relay-cohere.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -165,7 +165,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
165165
}
166166
})
167167
if usage.PromptTokens == 0 {
168-
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
168+
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
169169
}
170170
return usage, nil
171171
}

relay/channel/coze/relay-coze.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -142,7 +142,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
142142
helper.Done(c)
143143

144144
if usage.TotalTokens == 0 {
145-
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
145+
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
146146
}
147147

148148
return usage, nil

relay/channel/dify/relay-dify.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -246,7 +246,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
246246
})
247247
helper.Done(c)
248248
if usage.TotalTokens == 0 {
249-
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
249+
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
250250
}
251251
usage.CompletionTokens += nodeToken
252252
return usage, nil

relay/channel/gemini/relay-gemini-native.go

Lines changed: 2 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@ package gemini
33
import (
44
"io"
55
"net/http"
6-
"strings"
76

87
"github.com/QuantumNous/new-api/common"
98
"github.com/QuantumNous/new-api/dto"
@@ -13,8 +12,6 @@ import (
1312
"github.com/QuantumNous/new-api/service"
1413
"github.com/QuantumNous/new-api/types"
1514

16-
"github.com/pkg/errors"
17-
1815
"github.com/gin-gonic/gin"
1916
)
2017

@@ -97,80 +94,15 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
9794
}
9895

9996
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
100-
var usage = &dto.Usage{}
101-
var imageCount int
102-
10397
helper.SetEventStreamHeaders(c)
10498

105-
responseText := strings.Builder{}
106-
107-
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
108-
var geminiResponse dto.GeminiChatResponse
109-
err := common.UnmarshalJsonStr(data, &geminiResponse)
110-
if err != nil {
111-
logger.LogError(c, "error unmarshalling stream response: "+err.Error())
112-
return false
113-
}
114-
115-
// 统计图片数量
116-
for _, candidate := range geminiResponse.Candidates {
117-
for _, part := range candidate.Content.Parts {
118-
if part.InlineData != nil && part.InlineData.MimeType != "" {
119-
imageCount++
120-
}
121-
if part.Text != "" {
122-
responseText.WriteString(part.Text)
123-
}
124-
}
125-
}
126-
127-
// 更新使用量统计
128-
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
129-
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
130-
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
131-
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
132-
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
133-
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
134-
if detail.Modality == "AUDIO" {
135-
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
136-
} else if detail.Modality == "TEXT" {
137-
usage.PromptTokensDetails.TextTokens = detail.TokenCount
138-
}
139-
}
140-
}
141-
99+
return geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
142100
// 直接发送 GeminiChatResponse 响应
143-
err = helper.StringData(c, data)
101+
err := helper.StringData(c, data)
144102
if err != nil {
145103
logger.LogError(c, err.Error())
146104
}
147105
info.SendResponseCount++
148106
return true
149107
})
150-
151-
if info.SendResponseCount == 0 {
152-
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
153-
}
154-
155-
if imageCount != 0 {
156-
if usage.CompletionTokens == 0 {
157-
usage.CompletionTokens = imageCount * 258
158-
}
159-
}
160-
161-
// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
162-
if usage.CompletionTokens == 0 {
163-
str := responseText.String()
164-
if len(str) > 0 {
165-
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
166-
} else {
167-
// 空补全,不需要使用量
168-
usage = &dto.Usage{}
169-
}
170-
}
171-
172-
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
173-
//helper.Done(c)
174-
175-
return usage, nil
176108
}

relay/channel/gemini/relay-gemini.go

Lines changed: 52 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -954,14 +954,10 @@ func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.Ch
954954
return nil
955955
}
956956

957-
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
958-
// responseText := ""
959-
id := helper.GetResponseID(c)
960-
createAt := common.GetTimestamp()
961-
responseText := strings.Builder{}
957+
func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response, callback func(data string, geminiResponse *dto.GeminiChatResponse) bool) (*dto.Usage, *types.NewAPIError) {
962958
var usage = &dto.Usage{}
963959
var imageCount int
964-
finishReason := constant.FinishReasonStop
960+
responseText := strings.Builder{}
965961

966962
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
967963
var geminiResponse dto.GeminiChatResponse
@@ -971,6 +967,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
971967
return false
972968
}
973969

970+
// 统计图片数量
974971
for _, candidate := range geminiResponse.Candidates {
975972
for _, part := range candidate.Content.Parts {
976973
if part.InlineData != nil && part.InlineData.MimeType != "" {
@@ -982,14 +979,10 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
982979
}
983980
}
984981

985-
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
986-
987-
response.Id = id
988-
response.Created = createAt
989-
response.Model = info.UpstreamModelName
982+
// 更新使用量统计
990983
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
991984
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
992-
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
985+
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
993986
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
994987
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
995988
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
@@ -1000,6 +993,45 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
1000993
}
1001994
}
1002995
}
996+
997+
return callback(data, &geminiResponse)
998+
})
999+
1000+
if imageCount != 0 {
1001+
if usage.CompletionTokens == 0 {
1002+
usage.CompletionTokens = imageCount * 1400
1003+
}
1004+
}
1005+
1006+
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
1007+
if usage.TotalTokens > 0 {
1008+
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
1009+
}
1010+
1011+
if usage.CompletionTokens <= 0 {
1012+
str := responseText.String()
1013+
if len(str) > 0 {
1014+
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.PromptTokens)
1015+
} else {
1016+
usage = &dto.Usage{}
1017+
}
1018+
}
1019+
1020+
return usage, nil
1021+
}
1022+
1023+
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
1024+
id := helper.GetResponseID(c)
1025+
createAt := common.GetTimestamp()
1026+
finishReason := constant.FinishReasonStop
1027+
1028+
usage, err := geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
1029+
response, isStop := streamResponseGeminiChat2OpenAI(geminiResponse)
1030+
1031+
response.Id = id
1032+
response.Created = createAt
1033+
response.Model = info.UpstreamModelName
1034+
10031035
logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
10041036
if info.SendResponseCount == 0 {
10051037
// send first response
@@ -1015,7 +1047,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
10151047
emptyResponse.Choices[0].Delta.ToolCalls = copiedToolCalls
10161048
}
10171049
finishReason = constant.FinishReasonToolCalls
1018-
err = handleStream(c, info, emptyResponse)
1050+
err := handleStream(c, info, emptyResponse)
10191051
if err != nil {
10201052
logger.LogError(c, err.Error())
10211053
}
@@ -1025,14 +1057,14 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
10251057
response.Choices[0].FinishReason = nil
10261058
}
10271059
} else {
1028-
err = handleStream(c, info, emptyResponse)
1060+
err := handleStream(c, info, emptyResponse)
10291061
if err != nil {
10301062
logger.LogError(c, err.Error())
10311063
}
10321064
}
10331065
}
10341066

1035-
err = handleStream(c, info, response)
1067+
err := handleStream(c, info, response)
10361068
if err != nil {
10371069
logger.LogError(c, err.Error())
10381070
}
@@ -1042,40 +1074,15 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
10421074
return true
10431075
})
10441076

1045-
if info.SendResponseCount == 0 {
1046-
// 空补全,报错不计费
1047-
// empty response, throw an error
1048-
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
1049-
}
1050-
1051-
if imageCount != 0 {
1052-
if usage.CompletionTokens == 0 {
1053-
usage.CompletionTokens = imageCount * 258
1054-
}
1055-
}
1056-
1057-
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
1058-
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
1059-
1060-
if usage.CompletionTokens == 0 {
1061-
str := responseText.String()
1062-
if len(str) > 0 {
1063-
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
1064-
} else {
1065-
// 空补全,不需要使用量
1066-
usage = &dto.Usage{}
1067-
}
1077+
if err != nil {
1078+
return usage, err
10681079
}
10691080

10701081
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
1071-
err := handleFinalStream(c, info, response)
1072-
if err != nil {
1073-
common.SysLog("send final response failed: " + err.Error())
1082+
handleErr := handleFinalStream(c, info, response)
1083+
if handleErr != nil {
1084+
common.SysLog("send final response failed: " + handleErr.Error())
10741085
}
1075-
//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
1076-
// helper.Done(c)
1077-
//}
1078-
//resp.Body.Close()
10791086
return usage, nil
10801087
}
10811088

relay/channel/openai/relay-openai.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -183,7 +183,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
183183
}
184184

185185
if !containStreamUsage {
186-
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
186+
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
187187
usage.CompletionTokens += toolCount * 7
188188
}
189189

relay/channel/palm/adaptor.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
8181
if info.IsStream {
8282
var responseText string
8383
err, responseText = palmStreamHandler(c, resp)
84-
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
84+
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
8585
} else {
8686
usage, err = palmHandler(c, info, resp)
8787
}

0 commit comments

Comments
 (0)