Skip to content

Commit f27ed16

Browse files
committed
fix: extract and estimate tokens
1 parent 0189c7b commit f27ed16

File tree

5 files changed

+61
-16
lines changed

5 files changed

+61
-16
lines changed

internal/proxy/proxy.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,13 @@ func (p *Proxy) handleProxy(w http.ResponseWriter, r *http.Request) {
353353
isStreaming := contentType == "text/event-stream" || (claudeReq.Stream && strings.Contains(contentType, "text/event-stream"))
354354

355355
if resp.StatusCode == http.StatusOK && isStreaming {
356-
inputTokens, outputTokens, _ := p.handleStreamingResponse(w, resp, endpoint, trans, transformerName, thinkingEnabled)
356+
inputTokens, outputTokens, outputText := p.handleStreamingResponse(w, resp, endpoint, trans, transformerName, thinkingEnabled)
357+
358+
// Fallback: estimate tokens when usage is 0
359+
if inputTokens == 0 || outputTokens == 0 {
360+
inputTokens, outputTokens = p.estimateTokens(bodyBytes, outputText, inputTokens, outputTokens, endpoint.Name)
361+
}
362+
357363
p.stats.RecordTokens(endpoint.Name, inputTokens, outputTokens)
358364
p.markRequestInactive(endpoint.Name)
359365
logger.Debug("[%s] Request completed successfully (streaming)", endpoint.Name)

internal/proxy/streaming.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import (
1717
)
1818

1919
// handleStreamingResponse processes streaming SSE responses
20-
func (p *Proxy) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, endpoint config.Endpoint, trans transformer.Transformer, transformerName string, thinkingEnabled bool) (int, int, error) {
20+
func (p *Proxy) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, endpoint config.Endpoint, trans transformer.Transformer, transformerName string, thinkingEnabled bool) (int, int, string) {
2121
// Copy response headers except Content-Length (streaming response length is unknown)
2222
for key, values := range resp.Header {
2323
if key == "Content-Length" {
@@ -33,7 +33,7 @@ func (p *Proxy) handleStreamingResponse(w http.ResponseWriter, resp *http.Respon
3333
if !ok {
3434
logger.Error("[%s] ResponseWriter does not support flushing", endpoint.Name)
3535
resp.Body.Close()
36-
return 0, 0, nil
36+
return 0, 0, ""
3737
}
3838

3939
var streamCtx *transformer.StreamContext
@@ -110,7 +110,7 @@ func (p *Proxy) handleStreamingResponse(w http.ResponseWriter, resp *http.Respon
110110
}
111111

112112
resp.Body.Close()
113-
return inputTokens, outputTokens, nil
113+
return inputTokens, outputTokens, outputText.String()
114114
}
115115

116116
// transformStreamEvent transforms a single SSE event
@@ -134,7 +134,7 @@ func (p *Proxy) extractTokensFromEvent(eventData []byte, inputTokens, outputToke
134134
continue
135135
}
136136

137-
jsonData := strings.TrimPrefix(line, "data:")
137+
jsonData := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
138138
var event map[string]interface{}
139139
if err := json.Unmarshal([]byte(jsonData), &event); err != nil {
140140
continue
@@ -168,7 +168,7 @@ func (p *Proxy) extractTextFromEvent(transformedEvent []byte, outputText *string
168168
continue
169169
}
170170

171-
jsonData := strings.TrimPrefix(line, "data:")
171+
jsonData := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
172172
var event map[string]interface{}
173173
if err := json.Unmarshal([]byte(jsonData), &event); err != nil {
174174
continue

internal/proxy/utils.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"strings"
77

88
"github.com/lich0821/ccNexus/internal/logger"
9+
"github.com/lich0821/ccNexus/internal/tokencount"
910
)
1011

1112
// normalizeAPIUrl ensures the API URL has a protocol prefix
@@ -88,3 +89,21 @@ func cleanIncompleteToolCalls(bodyBytes []byte) ([]byte, error) {
8889
req["messages"] = messages
8990
return json.Marshal(req)
9091
}
92+
93+
// estimateTokens estimates tokens when API doesn't provide usage
94+
func (p *Proxy) estimateTokens(bodyBytes []byte, outputText string, inputTokens, outputTokens int, endpointName string) (int, int) {
95+
if inputTokens == 0 {
96+
var req tokencount.CountTokensRequest
97+
if json.Unmarshal(bodyBytes, &req) == nil {
98+
inputTokens = tokencount.EstimateInputTokens(&req)
99+
logger.Debug("[%s] Estimated input tokens: %d", endpointName, inputTokens)
100+
}
101+
}
102+
103+
if outputTokens == 0 && outputText != "" {
104+
outputTokens = tokencount.EstimateOutputTokens(outputText)
105+
logger.Debug("[%s] Estimated output tokens: %d", endpointName, outputTokens)
106+
}
107+
108+
return inputTokens, outputTokens
109+
}

internal/transformer/gemini/gemini.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,12 @@ func (t *GeminiTransformer) transformStreamingResponse(geminiStream []byte, ctx
410410
continue
411411
}
412412

413+
// Extract usage metadata from chunk
414+
if chunk.UsageMetadata != nil {
415+
ctx.InputTokens = chunk.UsageMetadata.PromptTokenCount
416+
ctx.OutputTokens = chunk.UsageMetadata.CandidatesTokenCount
417+
}
418+
413419
// Send message_start on first chunk
414420
if !ctx.MessageStartSent {
415421
ctx.MessageID = fmt.Sprintf("msg_%d", 0)
@@ -425,8 +431,8 @@ func (t *GeminiTransformer) transformStreamingResponse(geminiStream []byte, ctx
425431
"content": []interface{}{},
426432
"model": ctx.ModelName,
427433
"usage": map[string]interface{}{
428-
"input_tokens": 0,
429-
"output_tokens": 0,
434+
"input_tokens": ctx.InputTokens,
435+
"output_tokens": ctx.OutputTokens,
430436
},
431437
},
432438
}
@@ -610,12 +616,6 @@ func (t *GeminiTransformer) transformStreamingResponse(geminiStream []byte, ctx
610616
result.WriteString("\n")
611617
}
612618
}
613-
614-
// Update usage metadata
615-
if chunk.UsageMetadata != nil {
616-
ctx.InputTokens = chunk.UsageMetadata.PromptTokenCount
617-
ctx.OutputTokens = chunk.UsageMetadata.CandidatesTokenCount
618-
}
619619
}
620620

621621
if err := scanner.Err(); err != nil {

internal/transformer/openai/event_handler.go

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ func (h *MessageStartHandler) Handle(event *SSEEvent, state *StreamState) ([]*SS
2424
return []*SSEEvent{event}, nil
2525
}
2626

27+
// Extract usage from first event
28+
if usage, ok := event.Data["usage"].(map[string]interface{}); ok {
29+
if promptTokens, ok := usage["prompt_tokens"].(float64); ok {
30+
state.InputTokens = int(promptTokens)
31+
}
32+
if completionTokens, ok := usage["completion_tokens"].(float64); ok {
33+
state.OutputTokens = int(completionTokens)
34+
}
35+
}
36+
2737
if !state.MessageStarted {
2838
if id, ok := event.Data["id"].(string); ok {
2939
state.MessageID = id
@@ -46,8 +56,8 @@ func (h *MessageStartHandler) Handle(event *SSEEvent, state *StreamState) ([]*SS
4656
"content": []interface{}{},
4757
"model": state.ModelName,
4858
"usage": map[string]interface{}{
49-
"input_tokens": 0,
50-
"output_tokens": 0,
59+
"input_tokens": state.InputTokens,
60+
"output_tokens": state.OutputTokens,
5161
},
5262
},
5363
},
@@ -65,6 +75,16 @@ func (h *ContentDeltaHandler) Handle(event *SSEEvent, state *StreamState) ([]*SS
6575
return nil, nil
6676
}
6777

78+
// Extract usage from every event
79+
if usage, ok := event.Data["usage"].(map[string]interface{}); ok {
80+
if promptTokens, ok := usage["prompt_tokens"].(float64); ok {
81+
state.InputTokens = int(promptTokens)
82+
}
83+
if completionTokens, ok := usage["completion_tokens"].(float64); ok {
84+
state.OutputTokens = int(completionTokens)
85+
}
86+
}
87+
6888
choices, ok := event.Data["choices"].([]interface{})
6989
if !ok || len(choices) == 0 {
7090
return nil, nil

0 commit comments

Comments
 (0)