Skip to content

Commit 9c0fd0c

Browse files
authored
fix: correct openai input tokens value reported (#13)
Input tokens include cached tokens, so we must subtract them before reporting. Signed-off-by: Danny Kopping <[email protected]>
1 parent 068c87e commit 9c0fd0c

File tree

4 files changed

+15
-4
lines changed

4 files changed

+15
-4
lines changed

bridge_integration_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,11 +736,13 @@ func TestOpenAIInjectedTools(t *testing.T) {
736736
require.Contains(t, content.Message.Content, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned.
737737

738738
// Check the token usage from the client's perspective.
739-
assert.EqualValues(t, 9911, message.Usage.PromptTokens)
739+
// This *should* work but the openai SDK doesn't accumulate the prompt token details :(.
740+
// See https://github.com/openai/openai-go/blob/v2.7.0/streamaccumulator.go#L145-L147.
741+
// assert.EqualValues(t, 5047, message.Usage.PromptTokens-message.Usage.PromptTokensDetails.CachedTokens)
740742
assert.EqualValues(t, 105, message.Usage.CompletionTokens)
741743

742744
// Ensure tokens used during injected tool invocation are accounted for.
743-
require.EqualValues(t, 9911, calculateTotalInputTokens(recorderClient.tokenUsages))
745+
require.EqualValues(t, 5047, calculateTotalInputTokens(recorderClient.tokenUsages))
744746
require.EqualValues(t, 105, calculateTotalOutputTokens(recorderClient.tokenUsages))
745747
})
746748
}

intercept_openai_chat_blocking.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
8181
_ = i.recorder.RecordTokenUsage(ctx, &TokenUsageRecord{
8282
InterceptionID: i.ID().String(),
8383
MsgID: completion.ID,
84-
Input: lastUsage.PromptTokens,
84+
Input: calculateActualInputTokenUsage(lastUsage),
8585
Output: lastUsage.CompletionTokens,
8686
Metadata: Metadata{
8787
"prompt_audio": lastUsage.PromptTokensDetails.AudioTokens,

intercept_openai_chat_streaming.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
165165
_ = i.recorder.RecordTokenUsage(streamCtx, &TokenUsageRecord{
166166
InterceptionID: i.ID().String(),
167167
MsgID: processor.getMsgID(),
168-
Input: lastUsage.PromptTokens,
168+
Input: calculateActualInputTokenUsage(lastUsage),
169169
Output: lastUsage.CompletionTokens,
170170
Metadata: Metadata{
171171
"prompt_audio": lastUsage.PromptTokensDetails.AudioTokens,

openai.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,15 @@ func sumUsage(ref, in openai.CompletionUsage) openai.CompletionUsage {
9696
}
9797
}
9898

99+
// calculateActualInputTokenUsage accounts for cached tokens which are included in [openai.CompletionUsage].PromptTokens.
100+
func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 {
101+
// Input *includes* the cached tokens, so we subtract them here to reflect actual input token usage.
102+
// The original value can be reconstructed by referencing the "prompt_cached" field in metadata.
103+
// See https://platform.openai.com/docs/api-reference/usage/completions_object#usage/completions_object-input_tokens.
104+
return in.PromptTokens /* The aggregated number of text input tokens used, including cached tokens. */ -
105+
in.PromptTokensDetails.CachedTokens /* The aggregated number of text input tokens that has been cached from previous requests. */
106+
}
107+
99108
func getOpenAIErrorResponse(err error) *OpenAIErrorResponse {
100109
var apierr *openai.Error
101110
if !errors.As(err, &apierr) {

0 commit comments

Comments (0)