diff --git a/fixtures/fixtures.go b/fixtures/fixtures.go
index 21bff9d..af1e454 100644
--- a/fixtures/fixtures.go
+++ b/fixtures/fixtures.go
@@ -54,6 +54,9 @@ var (
 	//go:embed openai/responses/blocking/builtin_tool.txtar
 	OaiResponsesBlockingBuiltinTool []byte
 
+	//go:embed openai/responses/blocking/cached_input_tokens.txtar
+	OaiResponsesBlockingCachedInputTokens []byte
+
 	//go:embed openai/responses/blocking/custom_tool.txtar
 	OaiResponsesBlockingCustomTool []byte
 
diff --git a/fixtures/openai/responses/blocking/cached_input_tokens.txtar b/fixtures/openai/responses/blocking/cached_input_tokens.txtar
new file mode 100644
index 0000000..41a6d7c
--- /dev/null
+++ b/fixtures/openai/responses/blocking/cached_input_tokens.txtar
@@ -0,0 +1,81 @@
+-- request --
+{
+  "input": "This was a large input...",
+  "model": "gpt-4.1",
+  "prompt_cache_key": "key-123",
+  "prompt_cache_retention": "24h",
+  "stream": false
+}
+
+-- non-streaming --
+{
+  "id": "resp_0cd5d6b8310055d600696a1776b42c81a199fbb02248a8bfa0",
+  "object": "response",
+  "created_at": 1768560502,
+  "status": "completed",
+  "background": false,
+  "billing": {
+    "payer": "developer"
+  },
+  "completed_at": 1768560504,
+  "error": null,
+  "frequency_penalty": 0.0,
+  "incomplete_details": null,
+  "instructions": null,
+  "max_output_tokens": null,
+  "max_tool_calls": null,
+  "model": "gpt-4.1-2025-04-14",
+  "output": [
+    {
+      "id": "msg_0cd5d6b8310055d600696a177708b881a1bb53034def764104",
+      "type": "message",
+      "status": "completed",
+      "content": [
+        {
+          "type": "output_text",
+          "annotations": [],
+          "logprobs": [],
+          "text": "- I provide clear, accurate, and concise answers tailored to your requests.\n- I can process and summarize large volumes of information quickly.\n- I adapt my responses based on your needs and instructions for precision and relevance."
+        }
+      ],
+      "role": "assistant"
+    }
+  ],
+  "parallel_tool_calls": true,
+  "presence_penalty": 0.0,
+  "previous_response_id": null,
+  "prompt_cache_key": "key-123",
+  "prompt_cache_retention": "24h",
+  "reasoning": {
+    "effort": null,
+    "summary": null
+  },
+  "safety_identifier": null,
+  "service_tier": "default",
+  "store": true,
+  "temperature": 1.0,
+  "text": {
+    "format": {
+      "type": "text"
+    },
+    "verbosity": "medium"
+  },
+  "tool_choice": "auto",
+  "tools": [],
+  "top_logprobs": 0,
+  "top_p": 1.0,
+  "truncation": "disabled",
+  "usage": {
+    "input_tokens": 12033,
+    "input_tokens_details": {
+      "cached_tokens": 11904
+    },
+    "output_tokens": 44,
+    "output_tokens_details": {
+      "reasoning_tokens": 0
+    },
+    "total_tokens": 12077
+  },
+  "user": null,
+  "metadata": {}
+}
diff --git a/intercept/responses/base.go b/intercept/responses/base.go
index 99b4020..f369dc1 100644
--- a/intercept/responses/base.go
+++ b/intercept/responses/base.go
@@ -265,6 +265,34 @@ func (i *responsesInterceptionBase) parseFunctionCallJSONArgs(ctx context.Contex
 	return trimmed
 }
 
+func (i *responsesInterceptionBase) recordTokenUsage(ctx context.Context, response *responses.Response) {
+	if response == nil {
+		i.logger.Warn(ctx, "got empty response, skipping token usage recording")
+		return
+	}
+
+	usage := response.Usage
+
+	// Keeping logic consistent with chat completions:
+	// input *includes* the cached tokens, so we subtract them here to reflect actual input token usage.
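+	// e.g. with the cached_input_tokens fixture above: 12033 input - 11904 cached = 129 non-cached input tokens.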
+	inputNonCacheTokens := usage.InputTokens - usage.InputTokensDetails.CachedTokens
+
+	if err := i.recorder.RecordTokenUsage(ctx, &recorder.TokenUsageRecord{
+		InterceptionID: i.ID().String(),
+		MsgID:          response.ID,
+		Input:          inputNonCacheTokens,
+		Output:         usage.OutputTokens,
+		ExtraTokenTypes: map[string]int64{
+			"input_cached":     usage.InputTokensDetails.CachedTokens,
+			"output_reasoning": usage.OutputTokensDetails.ReasoningTokens,
+			"total_tokens":     usage.TotalTokens,
+		},
+	}); err != nil {
+		i.logger.Warn(ctx, "failed to record token usage", slog.Error(err))
+	}
+}
+
 // responseCopier helper struct to send original response to the client
 type responseCopier struct {
 	buff deltaBuffer
diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go
index 9f723dd..0374629 100644
--- a/intercept/responses/base_test.go
+++ b/intercept/responses/base_test.go
@@ -379,3 +379,74 @@ func TestParseJSONArgs(t *testing.T) {
 		})
 	}
 }
+
+func TestRecordTokenUsage(t *testing.T) {
+	t.Parallel()
+
+	id := uuid.MustParse("22222222-2222-2222-2222-222222222222")
+
+	tests := []struct {
+		name     string
+		response *oairesponses.Response
+		expected *recorder.TokenUsageRecord
+	}{
+		{
+			name:     "nil_response",
+			response: nil,
+			expected: nil,
+		},
+		{
+			name: "with_all_token_details",
+			response: &oairesponses.Response{
+				ID: "resp_full",
+				Usage: oairesponses.ResponseUsage{
+					InputTokens:  10,
+					OutputTokens: 20,
+					TotalTokens:  30,
+					InputTokensDetails: oairesponses.ResponseUsageInputTokensDetails{
+						CachedTokens: 5,
+					},
+					OutputTokensDetails: oairesponses.ResponseUsageOutputTokensDetails{
+						ReasoningTokens: 5,
+					},
+				},
+			},
+			expected: &recorder.TokenUsageRecord{
+				InterceptionID: id.String(),
+				MsgID:          "resp_full",
+				Input:          5, // 10 input - 5 cached
+				Output:         20,
+				ExtraTokenTypes: map[string]int64{
+					"input_cached":     5,
+					"output_reasoning": 5,
+					"total_tokens":     30,
+				},
+			},
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			rec := &testutil.MockRecorder{}
+			base := &responsesInterceptionBase{
+				id:       id,
+				recorder: rec,
+				logger:   slog.Make(),
+			}
+
+			base.recordTokenUsage(t.Context(), tc.response)
+
+			tokens := rec.RecordedTokenUsages()
+			if tc.expected == nil {
+				require.Empty(t, tokens)
+			} else {
+				require.Len(t, tokens, 1)
+				got := tokens[0]
+				got.CreatedAt = time.Time{} // ignore time
+				require.Equal(t, tc.expected, got)
+			}
+		})
+	}
+}
diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go
index 726701d..6161d6f 100644
--- a/intercept/responses/blocking.go
+++ b/intercept/responses/blocking.go
@@ -56,6 +56,9 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *
 	if response != nil {
 		i.recordUserPrompt(ctx, response.ID)
 		i.recordToolUsage(ctx, response)
+		i.recordTokenUsage(ctx, response)
+	} else {
+		i.logger.Warn(ctx, "got empty response, skipping prompt, tool usage and token usage recording")
 	}
 
 	if upstreamErr != nil && !respCopy.responseReceived.Load() {
diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go
index 269360f..8eedbb9 100644
--- a/intercept/responses/streaming.go
+++ b/intercept/responses/streaming.go
@@ -116,7 +116,11 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r
 		}
 	}
 	i.recordUserPrompt(ctx, responseID)
-	i.recordToolUsage(ctx, completedResponse)
+	if completedResponse != nil {
+		i.recordToolUsage(ctx, completedResponse)
+	} else {
+		i.logger.Warn(ctx, "got empty response, skipping tool usage recording")
+	}
 
 	b, err := respCopy.readAll()
 	if err != nil {
diff --git a/responses_integration_test.go b/responses_integration_test.go
index 5c12576..51f620d 100644
--- a/responses_integration_test.go
+++ b/responses_integration_test.go
@@ -38,12 +38,23 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) {
 		expectModel          string
 		expectPromptRecorded string
 		expectToolRecorded   *recorder.ToolUsageRecord
+		expectTokenUsage     *recorder.TokenUsageRecord
 	}{
 		{
 			name:                 "blocking_simple",
 			fixture:              fixtures.OaiResponsesBlockingSimple,
 			expectModel:          "gpt-4o-mini",
 			expectPromptRecorded: "tell me a joke",
+			expectTokenUsage: &recorder.TokenUsageRecord{
+				MsgID:  "resp_0388c79043df3e3400695f9f83cd6481959062cec6830d8d51",
+				Input:  11,
+				Output: 18,
+				ExtraTokenTypes: map[string]int64{
+					"input_cached":     0,
+					"output_reasoning": 0,
+					"total_tokens":     29,
+				},
+			},
 		},
 		{
 			name:                 "blocking_builtin_tool",
@@ -56,6 +67,32 @@
 				Args:     map[string]any{"a": float64(3), "b": float64(5)},
 				Injected: false,
 			},
+			expectTokenUsage: &recorder.TokenUsageRecord{
+				MsgID:  "resp_0da6045a8b68fa5200695fa23dcc2c81a19c849f627abf8a31",
+				Input:  58,
+				Output: 18,
+				ExtraTokenTypes: map[string]int64{
+					"input_cached":     0,
+					"output_reasoning": 0,
+					"total_tokens":     76,
+				},
+			},
+		},
+		{
+			name:                 "blocking_cached_input_tokens",
+			fixture:              fixtures.OaiResponsesBlockingCachedInputTokens,
+			expectModel:          "gpt-4.1",
+			expectPromptRecorded: "This was a large input...",
+			expectTokenUsage: &recorder.TokenUsageRecord{
+				MsgID:  "resp_0cd5d6b8310055d600696a1776b42c81a199fbb02248a8bfa0",
+				Input:  129, // 12033 input - 11904 cached
+				Output: 44,
+				ExtraTokenTypes: map[string]int64{
+					"input_cached":     11904,
+					"output_reasoning": 0,
+					"total_tokens":     12077,
+				},
+			},
 		},
 		{
 			name:                 "blocking_custom_tool",
@@ -68,18 +105,48 @@
 				Args:     "print(\"hello world\")",
 				Injected: false,
 			},
+			expectTokenUsage: &recorder.TokenUsageRecord{
+				MsgID:  "resp_09c614364030cdf000696942589da081a0af07f5859acb7308",
+				Input:  64,
+				Output: 148,
+				ExtraTokenTypes: map[string]int64{
+					"input_cached":     0,
+					"output_reasoning": 128,
+					"total_tokens":     212,
+				},
+			},
 		},
 		{
 			name:                 "blocking_conversation",
 			fixture:              fixtures.OaiResponsesBlockingConversation,
 			expectModel:          "gpt-4o-mini",
 			expectPromptRecorded: "explain why this is funny.",
+			expectTokenUsage: &recorder.TokenUsageRecord{
+				MsgID:  "resp_0c9f1f0524a858fa00695fa15fc5a081958f4304aafd3bdec2",
+				Input:  48,
+				Output: 116,
+				ExtraTokenTypes: map[string]int64{
+					"input_cached":     0,
+					"output_reasoning": 0,
+					"total_tokens":     164,
+				},
+			},
 		},
 		{
 			name:                 "blocking_prev_response_id",
 			fixture:              fixtures.OaiResponsesBlockingPrevResponseID,
 			expectModel:          "gpt-4o-mini",
 			expectPromptRecorded: "explain why this is funny.",
+			expectTokenUsage: &recorder.TokenUsageRecord{
+				MsgID:  "resp_0388c79043df3e3400695f9f86cfa08195af1f015c60117a83",
+				Input:  43,
+				Output: 129,
+				ExtraTokenTypes: map[string]int64{
+					"input_cached":     0,
+					"output_reasoning": 0,
+					"total_tokens":     172,
+				},
+			},
 		},
 		{
 			name:                 "streaming_simple",
@@ -226,6 +293,16 @@
 			} else {
 				require.Empty(t, recordedTools)
 			}
+
+			recordedTokens := mockRecorder.RecordedTokenUsages()
+			if tc.expectTokenUsage != nil {
+				require.Len(t, recordedTokens, 1)
+				recordedTokens[0].InterceptionID = tc.expectTokenUsage.InterceptionID // ignore interception id
+				recordedTokens[0].CreatedAt = tc.expectTokenUsage.CreatedAt           // ignore time
+				require.Equal(t, tc.expectTokenUsage, recordedTokens[0])
+			} else {
+				require.Empty(t, recordedTokens)
+			}
 		})
 	}
 }