3 changes: 3 additions & 0 deletions fixtures/fixtures.go
@@ -54,6 +54,9 @@ var (
//go:embed openai/responses/blocking/builtin_tool.txtar
OaiResponsesBlockingBuiltinTool []byte

//go:embed openai/responses/blocking/cached_input_tokens.txtar
OaiResponsesBlockingCachedInputTokens []byte

//go:embed openai/responses/blocking/custom_tool.txtar
OaiResponsesBlockingCustomTool []byte

81 changes: 81 additions & 0 deletions fixtures/openai/responses/blocking/cached_input_tokens.txtar
@@ -0,0 +1,81 @@
-- request --
{
"input": "This was a large input...",
"model": "gpt-4.1",
"prompt_cache_key": "key-123",
"prompt_cache_retention": "24h",
"stream": false
}

-- non-streaming --
{
"id": "resp_0cd5d6b8310055d600696a1776b42c81a199fbb02248a8bfa0",
"object": "response",
"created_at": 1768560502,
"status": "completed",
"background": false,
"billing": {
"payer": "developer"
},
"completed_at": 1768560504,
"error": null,
"frequency_penalty": 0.0,
"incomplete_details": null,
"instructions": null,
"max_output_tokens": null,
"max_tool_calls": null,
"model": "gpt-4.1-2025-04-14",
"output": [
{
"id": "msg_0cd5d6b8310055d600696a177708b881a1bb53034def764104",
"type": "message",
"status": "completed",
"content": [
{
"type": "output_text",
"annotations": [],
"logprobs": [],
"text": "- I provide clear, accurate, and concise answers tailored to your requests.\n- I can process and summarize large volumes of information quickly.\n- I adapt my responses based on your needs and instructions for precision and relevance."
}
],
"role": "assistant"
}
],
"parallel_tool_calls": true,
"presence_penalty": 0.0,
"previous_response_id": null,
"prompt_cache_key": "key-123",
"prompt_cache_retention": "24h",
"reasoning": {
"effort": null,
"summary": null
},
"safety_identifier": null,
"service_tier": "default",
"store": true,
"temperature": 1.0,
"text": {
"format": {
"type": "text"
},
"verbosity": "medium"
},
"tool_choice": "auto",
"tools": [],
"top_logprobs": 0,
"top_p": 1.0,
"truncation": "disabled",
"usage": {
"input_tokens": 12033,
"input_tokens_details": {
"cached_tokens": 11904
},
"output_tokens": 44,
"output_tokens_details": {
"reasoning_tokens": 0
},
"total_tokens": 12077
},
"user": null,
"metadata": {}
}
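The fixture above follows the same txtar layout as the other blocking fixtures: a "-- request --" section holding the client request body and a "-- non-streaming --" section holding the canned upstream response. Below is a minimal, self-contained sketch of reading such an archive with golang.org/x/tools/txtar; whether the test harness uses this exact package is an assumption, and the inline archive is a trimmed stand-in for the real fixture.

package main

import (
	"fmt"

	"golang.org/x/tools/txtar"
)

func main() {
	// Trimmed stand-in for the cached_input_tokens fixture above; the real
	// file embeds the full request body and the canned upstream response.
	src := []byte("-- request --\n{\"model\": \"gpt-4.1\"}\n-- non-streaming --\n{\"status\": \"completed\"}\n")

	archive := txtar.Parse(src)
	for _, f := range archive.Files {
		fmt.Printf("section %q: %d bytes\n", f.Name, len(f.Data))
	}
}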
27 changes: 27 additions & 0 deletions intercept/responses/base.go
@@ -265,6 +265,33 @@ func (i *responsesInterceptionBase) parseFunctionCallJSONArgs(ctx context.Contex
return trimmed
}

func (i *responsesInterceptionBase) recordTokenUsage(ctx context.Context, response *responses.Response) {
if response == nil {
i.logger.Warn(ctx, "got empty response, skipping token usage recording")
return
}

usage := response.Usage

// Keeping logic consistent with chat completions: input *includes* the cached
// tokens, so we subtract them here to reflect actual (non-cached) input token usage.
inputNonCacheTokens := usage.InputTokens - usage.InputTokensDetails.CachedTokens

if err := i.recorder.RecordTokenUsage(ctx, &recorder.TokenUsageRecord{
InterceptionID: i.ID().String(),
MsgID: response.ID,
Input: inputNonCacheTokens,
Output: usage.OutputTokens,
ExtraTokenTypes: map[string]int64{
"input_cached": usage.InputTokensDetails.CachedTokens,
"output_reasoning": usage.OutputTokensDetails.ReasoningTokens,
"total_tokens": usage.TotalTokens,
},
}); err != nil {
i.logger.Warn(ctx, "failed to record token usage", slog.Error(err))
}
}

// responseCopier helper struct to send original response to the client
type responseCopier struct {
buff deltaBuffer
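As the comment in recordTokenUsage notes, the Responses API reports input_tokens inclusive of cached tokens, so only the non-cached remainder is recorded as Input while the cached count is preserved under the "input_cached" extra token type. Below is a standalone sketch of that arithmetic using the numbers from the cached_input_tokens fixture; the program and variable names are illustrative only.

package main

import "fmt"

func main() {
	// Values from the cached_input_tokens fixture's usage block.
	inputTokens := int64(12033)  // usage.input_tokens (includes cached tokens)
	cachedTokens := int64(11904) // usage.input_tokens_details.cached_tokens
	outputTokens := int64(44)    // usage.output_tokens
	totalTokens := int64(12077)  // usage.total_tokens

	// recordTokenUsage subtracts the cached tokens so Input reflects only
	// the freshly processed input, mirroring the chat completions interceptor.
	inputNonCached := inputTokens - cachedTokens

	fmt.Println(inputNonCached) // 129, the Input value the integration test expects
	fmt.Println(outputTokens)   // 44
	fmt.Println(inputNonCached+cachedTokens+outputTokens == totalTokens) // true
}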
71 changes: 71 additions & 0 deletions intercept/responses/base_test.go
@@ -379,3 +379,74 @@ func TestParseJSONArgs(t *testing.T) {
})
}
}

func TestRecordTokenUsage(t *testing.T) {
t.Parallel()

id := uuid.MustParse("22222222-2222-2222-2222-222222222222")

tests := []struct {
name string
response *oairesponses.Response
expected *recorder.TokenUsageRecord
}{
{
name: "nil_response",
response: nil,
expected: nil,
},
{
name: "with_all_token_details",
response: &oairesponses.Response{
ID: "resp_full",
Usage: oairesponses.ResponseUsage{
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
InputTokensDetails: oairesponses.ResponseUsageInputTokensDetails{
CachedTokens: 5,
},
OutputTokensDetails: oairesponses.ResponseUsageOutputTokensDetails{
ReasoningTokens: 5,
},
},
},
expected: &recorder.TokenUsageRecord{
InterceptionID: id.String(),
MsgID: "resp_full",
Input: 5, // 10 input - 5 cached
Output: 20,
ExtraTokenTypes: map[string]int64{
"input_cached": 5,
"output_reasoning": 5,
"total_tokens": 30,
},
},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

rec := &testutil.MockRecorder{}
base := &responsesInterceptionBase{
id: id,
recorder: rec,
logger: slog.Make(),
}

base.recordTokenUsage(t.Context(), tc.response)

tokens := rec.RecordedTokenUsages()
if tc.expected == nil {
require.Empty(t, tokens)
} else {
require.Len(t, tokens, 1)
got := tokens[0]
got.CreatedAt = time.Time{} // ignore time
require.Equal(t, tc.expected, got)
}
})
}
}
3 changes: 3 additions & 0 deletions intercept/responses/blocking.go
@@ -56,6 +56,9 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *
if response != nil {
i.recordUserPrompt(ctx, response.ID)
i.recordToolUsage(ctx, response)
i.recordTokenUsage(ctx, response)
} else {
i.logger.Warn(ctx, "got empty response, skipping prompt, tool usage and token usage recording")
}

if upstreamErr != nil && !respCopy.responseReceived.Load() {
6 changes: 5 additions & 1 deletion intercept/responses/streaming.go
@@ -116,7 +116,11 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r
}
}
i.recordUserPrompt(ctx, responseID)
i.recordToolUsage(ctx, completedResponse)
if completedResponse != nil {
i.recordToolUsage(ctx, completedResponse)
} else {
i.logger.Warn(ctx, "got empty response, skipping tool usage recording")
}

b, err := respCopy.readAll()
if err != nil {
77 changes: 77 additions & 0 deletions responses_integration_test.go
@@ -38,12 +38,23 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) {
expectModel string
expectPromptRecorded string
expectToolRecorded *recorder.ToolUsageRecord
expectTokenUsage *recorder.TokenUsageRecord
}{
{
name: "blocking_simple",
fixture: fixtures.OaiResponsesBlockingSimple,
expectModel: "gpt-4o-mini",
expectPromptRecorded: "tell me a joke",
expectTokenUsage: &recorder.TokenUsageRecord{
MsgID: "resp_0388c79043df3e3400695f9f83cd6481959062cec6830d8d51",
Input: 11,
Output: 18,
ExtraTokenTypes: map[string]int64{
"input_cached": 0,
"output_reasoning": 0,
"total_tokens": 29,
},
},
},
{
name: "blocking_builtin_tool",
@@ -56,6 +67,32 @@
Args: map[string]any{"a": float64(3), "b": float64(5)},
Injected: false,
},
expectTokenUsage: &recorder.TokenUsageRecord{
MsgID: "resp_0da6045a8b68fa5200695fa23dcc2c81a19c849f627abf8a31",
Input: 58,
Output: 18,
ExtraTokenTypes: map[string]int64{
"input_cached": 0,
"output_reasoning": 0,
"total_tokens": 76,
},
},
},
{
name: "blocking_cached_input_tokens",
fixture: fixtures.OaiResponsesBlockingCachedInputTokens,
expectModel: "gpt-4.1",
expectPromptRecorded: "This was a large input...",
expectTokenUsage: &recorder.TokenUsageRecord{
MsgID: "resp_0cd5d6b8310055d600696a1776b42c81a199fbb02248a8bfa0",
Input: 129, // 12033 input - 11904 cached
Output: 44,
ExtraTokenTypes: map[string]int64{
"input_cached": 11904,
"output_reasoning": 0,
"total_tokens": 12077,
},
},
},
{
name: "blocking_custom_tool",
@@ -68,18 +105,48 @@
Args: "print(\"hello world\")",
Injected: false,
},
expectTokenUsage: &recorder.TokenUsageRecord{
MsgID: "resp_09c614364030cdf000696942589da081a0af07f5859acb7308",
Input: 64,
Output: 148,
ExtraTokenTypes: map[string]int64{
"input_cached": 0,
"output_reasoning": 128,
"total_tokens": 212,
},
},
},
{
name: "blocking_conversation",
fixture: fixtures.OaiResponsesBlockingConversation,
expectModel: "gpt-4o-mini",
expectPromptRecorded: "explain why this is funny.",
expectTokenUsage: &recorder.TokenUsageRecord{
MsgID: "resp_0c9f1f0524a858fa00695fa15fc5a081958f4304aafd3bdec2",
Input: 48,
Output: 116,
ExtraTokenTypes: map[string]int64{
"input_cached": 0,
"output_reasoning": 0,
"total_tokens": 164,
},
},
},
{
name: "blocking_prev_response_id",
fixture: fixtures.OaiResponsesBlockingPrevResponseID,
expectModel: "gpt-4o-mini",
expectPromptRecorded: "explain why this is funny.",
expectTokenUsage: &recorder.TokenUsageRecord{
MsgID: "resp_0388c79043df3e3400695f9f86cfa08195af1f015c60117a83",
Input: 43,
Output: 129,
ExtraTokenTypes: map[string]int64{
"input_cached": 0,
"output_reasoning": 0,
"total_tokens": 172,
},
},
},
{
name: "streaming_simple",
@@ -226,6 +293,16 @@
} else {
require.Empty(t, recordedTools)
}

recordedTokens := mockRecorder.RecordedTokenUsages()
if tc.expectTokenUsage != nil {
require.Len(t, recordedTokens, 1)
recordedTokens[0].InterceptionID = tc.expectTokenUsage.InterceptionID // ignore interception id
recordedTokens[0].CreatedAt = tc.expectTokenUsage.CreatedAt // ignore time
require.Equal(t, tc.expectTokenUsage, recordedTokens[0])
} else {
require.Empty(t, recordedTokens)
}
})
}
}