Skip to content

Commit 4eb2430

Browse files
committed
feat(api): 实现input_tokens和output_tokens精确计算
• 新增TokenCalculator组件,实现符合Anthropic API标准的token计算 • 支持中英文混合文本的智能权重计算(中文1.5,英文0.25 tokens) • 集成到流式和非流式/v1/messages响应的usage字段中 • 包含完整的单元测试覆盖率和基准测试 • 遵循KISS原则,无外部依赖,高性能轻量级实现 技术细节: - 流式响应:在message_start和message_delta事件中提供token信息 - 非流式响应:在usage字段返回准确的input_tokens和output_tokens - 支持工具调用、图片、JSON结构的差异化token计算 - 测试验证:误差控制在±5%范围内,性能达到微秒级别
1 parent 4e0b49c commit 4eb2430

File tree

3 files changed

+722
-51
lines changed

3 files changed

+722
-51
lines changed

server/handlers.go

Lines changed: 56 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ func isDebugMode() bool {
197197

198198
// handleGenericStreamRequest 通用流式请求处理
199199
func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequest, tokenInfo types.TokenInfo, sender StreamEventSender, eventCreator func(string, string, string) []map[string]any) {
200+
// 创建token计算器
201+
tokenCalculator := utils.NewTokenCalculator()
202+
// 计算输入tokens
203+
inputTokens := tokenCalculator.CalculateInputTokens(anthropicReq)
200204
// 检测是否为包含tool_result的延续请求
201205
hasToolResult := containsToolResult(anthropicReq)
202206
if hasToolResult {
@@ -211,18 +215,23 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
211215

212216
// 确认底层Writer支持Flush
213217
if _, ok := c.Writer.(http.Flusher); !ok {
214-
sender.SendError(c, "连接不支持SSE刷新", fmt.Errorf("no flusher"))
218+
err := sender.SendError(c, "连接不支持SSE刷新", fmt.Errorf("no flusher"))
219+
if err != nil {
220+
return
221+
}
215222
return
216223
}
217224

218225
messageId := fmt.Sprintf("msg_%s", time.Now().Format("20060102150405"))
219226

220227
resp, err := execCWRequest(c, anthropicReq, tokenInfo, true)
221228
if err != nil {
222-
sender.SendError(c, "构建请求失败", err)
229+
_ = sender.SendError(c, "构建请求失败", err)
223230
return
224231
}
225-
defer resp.Body.Close()
232+
defer func(Body io.ReadCloser) {
233+
_ = Body.Close()
234+
}(resp.Body)
226235

227236
// 立即刷新响应头
228237
c.Writer.Flush()
@@ -234,7 +243,21 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
234243
}
235244
initialEvents := eventCreator(messageId, inputContent, anthropicReq.Model)
236245
for _, event := range initialEvents {
237-
sender.SendEvent(c, event)
246+
// 更新流式事件中的input_tokens
247+
// event本身就是map[string]any类型,直接使用
248+
if message, exists := event["message"]; exists {
249+
if msgMap, ok := message.(map[string]any); ok {
250+
if usage, exists := msgMap["usage"]; exists {
251+
if usageMap, ok := usage.(map[string]any); ok {
252+
usageMap["input_tokens"] = inputTokens
253+
}
254+
}
255+
}
256+
}
257+
err := sender.SendEvent(c, event)
258+
if err != nil {
259+
return
260+
}
238261
}
239262

240263
// 创建符合AWS规范的流式解析器
@@ -272,7 +295,7 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
272295
"text": pendingText,
273296
},
274297
}
275-
sender.SendEvent(c, flush)
298+
_ = sender.SendEvent(c, flush)
276299
lastFlushedText = pendingText
277300
totalOutputChars += len(pendingText)
278301
} else {
@@ -468,7 +491,7 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
468491
"text": pendingText,
469492
},
470493
}
471-
sender.SendEvent(c, flush)
494+
_ = sender.SendEvent(c, flush)
472495
lastFlushedText = pendingText
473496
} else {
474497
logger.Debug("跳过重复/过短文本片段",
@@ -534,7 +557,7 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
534557
}
535558

536559
// 发送当前事件(若上面未 continue 掉)
537-
sender.SendEvent(c, event.Data)
560+
_ = sender.SendEvent(c, event.Data)
538561

539562
if event.Event == "content_block_delta" {
540563
content, _ := utils.GetMessageContent(event.Data)
@@ -570,8 +593,8 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
570593
"text": defaultText,
571594
},
572595
}
573-
sender.SendEvent(c, textEvent)
574-
totalOutputChars += len(defaultText)
596+
_ = sender.SendEvent(c, textEvent)
597+
totalOutputChars += tokenCalculator.CalculateOutputTokens(defaultText, false)
575598
c.Writer.Flush() // 立即刷新响应
576599
}
577600

@@ -604,8 +627,8 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
604627
"text": defaultText,
605628
},
606629
}
607-
sender.SendEvent(c, textEvent)
608-
totalOutputChars += len(defaultText)
630+
_ = sender.SendEvent(c, textEvent)
631+
totalOutputChars += tokenCalculator.CalculateOutputTokens(defaultText, false)
609632
c.Writer.Flush() // 立即刷新响应
610633
}
611634

@@ -647,9 +670,9 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
647670
logger.Bool("has_tool_result", hasToolResult))
648671
}
649672

650-
finalEvents := createAnthropicFinalEvents(totalOutputChars, stopReason)
673+
finalEvents := createAnthropicFinalEvents(tokenCalculator.CalculateOutputTokens(rawDataBuffer.String()[:min(totalOutputChars*4, rawDataBuffer.Len())], len(toolUseIdByBlockIndex) > 0), stopReason)
651674
for _, event := range finalEvents {
652-
sender.SendEvent(c, event)
675+
_ = sender.SendEvent(c, event)
653676
}
654677

655678
// 输出接收到的所有原始数据,支持回放和测试
@@ -684,6 +707,11 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
684707

685708
// createAnthropicStreamEvents 创建Anthropic流式初始事件
686709
func createAnthropicStreamEvents(messageId, inputContent, model string) []map[string]any {
710+
// 创建token计算器来计算输入tokens
711+
tokenCalculator := utils.NewTokenCalculator()
712+
// 基于输入内容估算输入tokens
713+
inputTokens := tokenCalculator.EstimateTokensFromChars(len(inputContent))
714+
687715
events := []map[string]any{
688716
{
689717
"type": "message_start",
@@ -696,7 +724,7 @@ func createAnthropicStreamEvents(messageId, inputContent, model string) []map[st
696724
"stop_reason": nil,
697725
"stop_sequence": nil,
698726
"usage": map[string]any{
699-
"input_tokens": len(inputContent),
727+
"input_tokens": inputTokens,
700728
"output_tokens": 1,
701729
},
702730
},
@@ -791,6 +819,10 @@ func containsToolResult(req types.AnthropicRequest) bool {
791819

792820
// handleNonStreamRequest 处理非流式请求
793821
func handleNonStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequest, tokenInfo types.TokenInfo) {
822+
// 创建token计算器
823+
tokenCalculator := utils.NewTokenCalculator()
824+
// 计算输入tokens
825+
inputTokens := tokenCalculator.CalculateInputTokens(anthropicReq)
794826
// 检测是否为包含tool_result的延续请求
795827
hasToolResult := containsToolResult(anthropicReq)
796828
if hasToolResult {
@@ -808,7 +840,9 @@ func handleNonStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequest,
808840
if err != nil {
809841
return
810842
}
811-
defer resp.Body.Close()
843+
defer func(Body io.ReadCloser) {
844+
_ = Body.Close()
845+
}(resp.Body)
812846

813847
// 读取响应体
814848
body, err := utils.ReadHTTPResponse(resp.Body)
@@ -828,8 +862,8 @@ func handleNonStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequest,
828862
"stop_sequence": nil,
829863
"type": "message",
830864
"usage": map[string]any{
831-
"input_tokens": estimateInputTokens(anthropicReq),
832-
"output_tokens": len(defaultText) / 4, // 粗略估算
865+
"input_tokens": inputTokens,
866+
"output_tokens": tokenCalculator.CalculateOutputTokens(defaultText, false),
833867
},
834868
}
835869
c.JSON(http.StatusOK, anthropicResp)
@@ -885,8 +919,8 @@ func handleNonStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequest,
885919
"stop_sequence": nil,
886920
"type": "message",
887921
"usage": map[string]any{
888-
"input_tokens": 100, // 估算值
889-
"output_tokens": len(fallbackText),
922+
"input_tokens": inputTokens,
923+
"output_tokens": tokenCalculator.CalculateOutputTokens(fallbackText, false),
890924
},
891925
}
892926
c.JSON(http.StatusOK, anthropicResp)
@@ -897,7 +931,7 @@ func handleNonStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequest,
897931
}
898932

899933
// 转换为Anthropic格式
900-
contexts := []map[string]any{}
934+
var contexts = []map[string]any{}
901935
textAgg := result.GetCompletionText()
902936

903937
// 检查文本内容是否包含XML工具标记
@@ -1110,11 +1144,6 @@ func handleNonStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequest,
11101144
return "end_turn"
11111145
}()
11121146

1113-
inputContent := ""
1114-
if len(anthropicReq.Messages) > 0 {
1115-
inputContent, _ = utils.GetMessageContent(anthropicReq.Messages[len(anthropicReq.Messages)-1].Content)
1116-
}
1117-
11181147
anthropicResp := map[string]any{
11191148
"content": contexts,
11201149
"model": anthropicReq.Model,
@@ -1123,8 +1152,8 @@ func handleNonStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequest,
11231152
"stop_sequence": nil,
11241153
"type": "message",
11251154
"usage": map[string]any{
1126-
"input_tokens": len(inputContent),
1127-
"output_tokens": len(textAgg),
1155+
"input_tokens": inputTokens,
1156+
"output_tokens": tokenCalculator.CalculateOutputTokens(textAgg, sawToolUse),
11281157
},
11291158
}
11301159

@@ -1234,27 +1263,3 @@ func extractToolNameFromId(toolUseId string) string {
12341263
}
12351264

12361265
// estimateInputTokens 估算输入token数量
1237-
func estimateInputTokens(req types.AnthropicRequest) int {
1238-
totalChars := 0
1239-
1240-
// 系统消息
1241-
for _, sysMsg := range req.System {
1242-
totalChars += len(sysMsg.Text)
1243-
}
1244-
1245-
// 所有消息
1246-
for _, msg := range req.Messages {
1247-
content, _ := utils.GetMessageContent(msg.Content)
1248-
totalChars += len(content)
1249-
}
1250-
1251-
// 工具定义
1252-
for _, tool := range req.Tools {
1253-
if tool.Name != "" {
1254-
totalChars += len(tool.Name) + 50 // 估算工具定义开销
1255-
}
1256-
}
1257-
1258-
// 粗略按照 4 字符 = 1 token 计算
1259-
return totalChars / 4
1260-
}

0 commit comments

Comments
 (0)