@@ -197,6 +197,10 @@ func isDebugMode() bool {
197197
198198// handleGenericStreamRequest 通用流式请求处理
199199func handleGenericStreamRequest (c * gin.Context , anthropicReq types.AnthropicRequest , tokenInfo types.TokenInfo , sender StreamEventSender , eventCreator func (string , string , string ) []map [string ]any ) {
200+ // 创建token计算器
201+ tokenCalculator := utils .NewTokenCalculator ()
202+ // 计算输入tokens
203+ inputTokens := tokenCalculator .CalculateInputTokens (anthropicReq )
200204 // 检测是否为包含tool_result的延续请求
201205 hasToolResult := containsToolResult (anthropicReq )
202206 if hasToolResult {
@@ -211,18 +215,23 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
211215
212216 // 确认底层Writer支持Flush
213217 if _ , ok := c .Writer .(http.Flusher ); ! ok {
214- sender .SendError (c , "连接不支持SSE刷新" , fmt .Errorf ("no flusher" ))
218+ err := sender .SendError (c , "连接不支持SSE刷新" , fmt .Errorf ("no flusher" ))
219+ if err != nil {
220+ return
221+ }
215222 return
216223 }
217224
218225 messageId := fmt .Sprintf ("msg_%s" , time .Now ().Format ("20060102150405" ))
219226
220227 resp , err := execCWRequest (c , anthropicReq , tokenInfo , true )
221228 if err != nil {
222- sender .SendError (c , "构建请求失败" , err )
229+ _ = sender .SendError (c , "构建请求失败" , err )
223230 return
224231 }
225- defer resp .Body .Close ()
232+ defer func (Body io.ReadCloser ) {
233+ _ = Body .Close ()
234+ }(resp .Body )
226235
227236 // 立即刷新响应头
228237 c .Writer .Flush ()
@@ -234,7 +243,21 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
234243 }
235244 initialEvents := eventCreator (messageId , inputContent , anthropicReq .Model )
236245 for _ , event := range initialEvents {
237- sender .SendEvent (c , event )
246+ // 更新流式事件中的input_tokens
247+ // event本身就是map[string]any类型,直接使用
248+ if message , exists := event ["message" ]; exists {
249+ if msgMap , ok := message .(map [string ]any ); ok {
250+ if usage , exists := msgMap ["usage" ]; exists {
251+ if usageMap , ok := usage .(map [string ]any ); ok {
252+ usageMap ["input_tokens" ] = inputTokens
253+ }
254+ }
255+ }
256+ }
257+ err := sender .SendEvent (c , event )
258+ if err != nil {
259+ return
260+ }
238261 }
239262
240263 // 创建符合AWS规范的流式解析器
@@ -272,7 +295,7 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
272295 "text" : pendingText ,
273296 },
274297 }
275- sender .SendEvent (c , flush )
298+ _ = sender .SendEvent (c , flush )
276299 lastFlushedText = pendingText
277300 totalOutputChars += len (pendingText )
278301 } else {
@@ -468,7 +491,7 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
468491 "text" : pendingText ,
469492 },
470493 }
471- sender .SendEvent (c , flush )
494+ _ = sender .SendEvent (c , flush )
472495 lastFlushedText = pendingText
473496 } else {
474497 logger .Debug ("跳过重复/过短文本片段" ,
@@ -534,7 +557,7 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
534557 }
535558
536559 // 发送当前事件(若上面未 continue 掉)
537- sender .SendEvent (c , event .Data )
560+ _ = sender .SendEvent (c , event .Data )
538561
539562 if event .Event == "content_block_delta" {
540563 content , _ := utils .GetMessageContent (event .Data )
@@ -570,8 +593,8 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
570593 "text" : defaultText ,
571594 },
572595 }
573- sender .SendEvent (c , textEvent )
574- totalOutputChars += len (defaultText )
596+ _ = sender .SendEvent (c , textEvent )
597+ totalOutputChars += tokenCalculator . CalculateOutputTokens (defaultText , false )
575598 c .Writer .Flush () // 立即刷新响应
576599 }
577600
@@ -604,8 +627,8 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
604627 "text" : defaultText ,
605628 },
606629 }
607- sender .SendEvent (c , textEvent )
608- totalOutputChars += len (defaultText )
630+ _ = sender .SendEvent (c , textEvent )
631+ totalOutputChars += tokenCalculator . CalculateOutputTokens (defaultText , false )
609632 c .Writer .Flush () // 立即刷新响应
610633 }
611634
@@ -647,9 +670,9 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
647670 logger .Bool ("has_tool_result" , hasToolResult ))
648671 }
649672
650- finalEvents := createAnthropicFinalEvents (totalOutputChars , stopReason )
673+ finalEvents := createAnthropicFinalEvents (tokenCalculator . CalculateOutputTokens ( rawDataBuffer . String ()[: min ( totalOutputChars * 4 , rawDataBuffer . Len ())], len ( toolUseIdByBlockIndex ) > 0 ) , stopReason )
651674 for _ , event := range finalEvents {
652- sender .SendEvent (c , event )
675+ _ = sender .SendEvent (c , event )
653676 }
654677
655678 // 输出接收到的所有原始数据,支持回放和测试
@@ -684,6 +707,11 @@ func handleGenericStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequ
684707
685708// createAnthropicStreamEvents 创建Anthropic流式初始事件
686709func createAnthropicStreamEvents (messageId , inputContent , model string ) []map [string ]any {
710+ // 创建token计算器来计算输入tokens
711+ tokenCalculator := utils .NewTokenCalculator ()
712+ // 基于输入内容估算输入tokens
713+ inputTokens := tokenCalculator .EstimateTokensFromChars (len (inputContent ))
714+
687715 events := []map [string ]any {
688716 {
689717 "type" : "message_start" ,
@@ -696,7 +724,7 @@ func createAnthropicStreamEvents(messageId, inputContent, model string) []map[st
696724 "stop_reason" : nil ,
697725 "stop_sequence" : nil ,
698726 "usage" : map [string ]any {
699- "input_tokens" : len ( inputContent ) ,
727+ "input_tokens" : inputTokens ,
700728 "output_tokens" : 1 ,
701729 },
702730 },
@@ -791,6 +819,10 @@ func containsToolResult(req types.AnthropicRequest) bool {
791819
792820// handleNonStreamRequest 处理非流式请求
793821func handleNonStreamRequest (c * gin.Context , anthropicReq types.AnthropicRequest , tokenInfo types.TokenInfo ) {
822+ // 创建token计算器
823+ tokenCalculator := utils .NewTokenCalculator ()
824+ // 计算输入tokens
825+ inputTokens := tokenCalculator .CalculateInputTokens (anthropicReq )
794826 // 检测是否为包含tool_result的延续请求
795827 hasToolResult := containsToolResult (anthropicReq )
796828 if hasToolResult {
@@ -808,7 +840,9 @@ func handleNonStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequest,
808840 if err != nil {
809841 return
810842 }
811- defer resp .Body .Close ()
843+ defer func (Body io.ReadCloser ) {
844+ _ = Body .Close ()
845+ }(resp .Body )
812846
813847 // 读取响应体
814848 body , err := utils .ReadHTTPResponse (resp .Body )
@@ -828,8 +862,8 @@ func handleNonStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequest,
828862 "stop_sequence" : nil ,
829863 "type" : "message" ,
830864 "usage" : map [string ]any {
831- "input_tokens" : estimateInputTokens ( anthropicReq ) ,
832- "output_tokens" : len (defaultText ) / 4 , // 粗略估算
865+ "input_tokens" : inputTokens ,
866+ "output_tokens" : tokenCalculator . CalculateOutputTokens (defaultText , false ),
833867 },
834868 }
835869 c .JSON (http .StatusOK , anthropicResp )
@@ -885,8 +919,8 @@ func handleNonStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequest,
885919 "stop_sequence" : nil ,
886920 "type" : "message" ,
887921 "usage" : map [string ]any {
888- "input_tokens" : 100 , // 估算值
889- "output_tokens" : len (fallbackText ),
922+ "input_tokens" : inputTokens ,
923+ "output_tokens" : tokenCalculator . CalculateOutputTokens (fallbackText , false ),
890924 },
891925 }
892926 c .JSON (http .StatusOK , anthropicResp )
@@ -897,7 +931,7 @@ func handleNonStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequest,
897931 }
898932
899933 // 转换为Anthropic格式
900- contexts : = []map [string ]any {}
934+ var contexts = []map [string ]any {}
901935 textAgg := result .GetCompletionText ()
902936
903937 // 检查文本内容是否包含XML工具标记
@@ -1110,11 +1144,6 @@ func handleNonStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequest,
11101144 return "end_turn"
11111145 }()
11121146
1113- inputContent := ""
1114- if len (anthropicReq .Messages ) > 0 {
1115- inputContent , _ = utils .GetMessageContent (anthropicReq .Messages [len (anthropicReq .Messages )- 1 ].Content )
1116- }
1117-
11181147 anthropicResp := map [string ]any {
11191148 "content" : contexts ,
11201149 "model" : anthropicReq .Model ,
@@ -1123,8 +1152,8 @@ func handleNonStreamRequest(c *gin.Context, anthropicReq types.AnthropicRequest,
11231152 "stop_sequence" : nil ,
11241153 "type" : "message" ,
11251154 "usage" : map [string ]any {
1126- "input_tokens" : len ( inputContent ) ,
1127- "output_tokens" : len (textAgg ),
1155+ "input_tokens" : inputTokens ,
1156+ "output_tokens" : tokenCalculator . CalculateOutputTokens (textAgg , sawToolUse ),
11281157 },
11291158 }
11301159
@@ -1234,27 +1263,3 @@ func extractToolNameFromId(toolUseId string) string {
12341263}
12351264
12361265// estimateInputTokens 估算输入token数量
1237- func estimateInputTokens (req types.AnthropicRequest ) int {
1238- totalChars := 0
1239-
1240- // 系统消息
1241- for _ , sysMsg := range req .System {
1242- totalChars += len (sysMsg .Text )
1243- }
1244-
1245- // 所有消息
1246- for _ , msg := range req .Messages {
1247- content , _ := utils .GetMessageContent (msg .Content )
1248- totalChars += len (content )
1249- }
1250-
1251- // 工具定义
1252- for _ , tool := range req .Tools {
1253- if tool .Name != "" {
1254- totalChars += len (tool .Name ) + 50 // 估算工具定义开销
1255- }
1256- }
1257-
1258- // 粗略按照 4 字符 = 1 token 计算
1259- return totalChars / 4
1260- }
0 commit comments