diff --git a/fixtures/fixtures.go b/fixtures/fixtures.go index 905fbf9..ebd6137 100644 --- a/fixtures/fixtures.go +++ b/fixtures/fixtures.go @@ -54,6 +54,9 @@ var ( //go:embed openai/responses/blocking/builtin_tool.txtar OaiResponsesBlockingBuiltinTool []byte + //go:embed openai/responses/blocking/custom_tool.txtar + OaiResponsesBlockingCustomTool []byte + //go:embed openai/responses/blocking/conversation.txtar OaiResponsesBlockingConversation []byte diff --git a/fixtures/openai/responses/blocking/custom_tool.txtar b/fixtures/openai/responses/blocking/custom_tool.txtar new file mode 100644 index 0000000..a196593 --- /dev/null +++ b/fixtures/openai/responses/blocking/custom_tool.txtar @@ -0,0 +1,93 @@ +-- request -- +{ + "input": "Use the code_exec tool to print hello world to the console.", + "model": "gpt-5", + "tools": [ + { + "type": "custom", + "name": "code_exec", + "description": "Executes arbitrary Python code." + } + ] +} + +-- non-streaming -- +{ + "id": "resp_09c614364030cdf000696942589da081a0af07f5859acb7308", + "object": "response", + "created_at": 1768505944, + "status": "completed", + "background": false, + "billing": { + "payer": "developer" + }, + "completed_at": 1768505948, + "error": null, + "frequency_penalty": 0.0, + "incomplete_details": null, + "instructions": null, + "max_output_tokens": null, + "max_tool_calls": null, + "model": "gpt-5-2025-08-07", + "output": [ + { + "id": "rs_09c614364030cdf00069694258e45881a0b8d5f198cde47d58", + "type": "reasoning", + "summary": [] + }, + { + "id": "ctc_09c614364030cdf0006969425bf33481a09cc0f9522af2d980", + "type": "custom_tool_call", + "status": "completed", + "call_id": "call_haf8njtwrVZ1754Gm6fjAtuA", + "input": "print(\"hello world\")", + "name": "code_exec" + } + ], + "parallel_tool_calls": true, + "presence_penalty": 0.0, + "previous_response_id": null, + "prompt_cache_key": null, + "prompt_cache_retention": null, + "reasoning": { + "effort": "medium", + "summary": null + }, + "safety_identifier": 
null, + "service_tier": "default", + "store": true, + "temperature": 1.0, + "text": { + "format": { + "type": "text" + }, + "verbosity": "medium" + }, + "tool_choice": "auto", + "tools": [ + { + "type": "custom", + "description": "Executes arbitrary Python code.", + "format": { + "type": "text" + }, + "name": "code_exec" + } + ], + "top_logprobs": 0, + "top_p": 1.0, + "truncation": "disabled", + "usage": { + "input_tokens": 64, + "input_tokens_details": { + "cached_tokens": 0 + }, + "output_tokens": 148, + "output_tokens_details": { + "reasoning_tokens": 128 + }, + "total_tokens": 212 + }, + "user": null, + "metadata": {} +} \ No newline at end of file diff --git a/intercept/responses/base.go b/intercept/responses/base.go index 3aeb047..99b4020 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -24,6 +24,7 @@ import ( "github.com/google/uuid" "github.com/openai/openai-go/v3/option" "github.com/openai/openai-go/v3/responses" + oaiconst "github.com/openai/openai-go/v3/shared/constant" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "go.opentelemetry.io/otel/attribute" @@ -220,6 +221,50 @@ func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, respon } } +func (i *responsesInterceptionBase) recordToolUsage(ctx context.Context, response *responses.Response) { + if response == nil { + i.logger.Warn(ctx, "got empty response, skipping tool usage recording") + return + } + + for _, item := range response.Output { + var args recorder.ToolArgs + + // recording other function types to be considered: https://github.com/coder/aibridge/issues/121 + switch item.Type { + case string(oaiconst.ValueOf[oaiconst.FunctionCall]()): + args = i.parseFunctionCallJSONArgs(ctx, item.Arguments) + case string(oaiconst.ValueOf[oaiconst.CustomToolCall]()): + args = item.Input + default: + continue + } + + if err := i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ + InterceptionID: i.ID().String(), + MsgID: response.ID, + Tool: 
item.Name, + Args: args, + Injected: false, + }); err != nil { + i.logger.Warn(ctx, "failed to record tool usage", slog.Error(err), slog.F("tool", item.Name)) + } + } +} + +func (i *responsesInterceptionBase) parseFunctionCallJSONArgs(ctx context.Context, raw string) recorder.ToolArgs { + trimmed := strings.TrimSpace(raw) + if trimmed != "" { + var args recorder.ToolArgs + if err := json.Unmarshal([]byte(trimmed), &args); err != nil { + i.logger.Warn(ctx, "failed to unmarshal tool args", slog.Error(err)) + } else { + return args + } + } + return trimmed +} + // responseCopier helper struct to send original response to the client type responseCopier struct { buff deltaBuffer diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index 931f76d..9f723dd 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -2,11 +2,14 @@ package responses import ( "testing" + "time" "cdr.dev/slog/v3" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/internal/testutil" + "github.com/coder/aibridge/recorder" "github.com/google/uuid" + oairesponses "github.com/openai/openai-go/v3/responses" "github.com/stretchr/testify/require" ) @@ -194,3 +197,185 @@ func TestRecordPrompt(t *testing.T) { }) } } + +func TestRecordToolUsage(t *testing.T) { + t.Parallel() + + id := uuid.MustParse("11111111-1111-1111-1111-111111111111") + + tests := []struct { + name string + response *oairesponses.Response + expected []*recorder.ToolUsageRecord + }{ + { + name: "nil_response", + response: nil, + expected: nil, + }, + { + name: "empty_output", + response: &oairesponses.Response{ + ID: "resp_123", + }, + expected: nil, + }, + { + name: "empty_tool_args", + response: &oairesponses.Response{ + ID: "resp_456", + Output: []oairesponses.ResponseOutputItemUnion{ + { + Type: "function_call", + Name: "get_weather", + Arguments: "", + }, + }, + }, + expected: []*recorder.ToolUsageRecord{ + { + InterceptionID: id.String(), + MsgID: "resp_456", + 
Tool: "get_weather", + Args: "", + Injected: false, + }, + }, + }, + { + name: "multiple_tool_calls", + response: &oairesponses.Response{ + ID: "resp_789", + Output: []oairesponses.ResponseOutputItemUnion{ + { + Type: "function_call", + Name: "get_weather", + Arguments: `{"location": "NYC"}`, + }, + { + Type: "function_call", + Name: "bad_json_args", + Arguments: `{"bad": args`, + }, + { + Type: "message", + ID: "msg_1", + Role: "assistant", + }, + { + Type: "custom_tool_call", + Name: "search", + Input: `{\"query\": \"test\"}`, + }, + { + Type: "function_call", + Name: "calculate", + Arguments: `{"a": 1, "b": 2}`, + }, + }, + }, + expected: []*recorder.ToolUsageRecord{ + { + InterceptionID: id.String(), + MsgID: "resp_789", + Tool: "get_weather", + Args: map[string]any{"location": "NYC"}, + Injected: false, + }, + { + InterceptionID: id.String(), + MsgID: "resp_789", + Tool: "bad_json_args", + Args: `{"bad": args`, + Injected: false, + }, + { + InterceptionID: id.String(), + MsgID: "resp_789", + Tool: "search", + Args: `{\"query\": \"test\"}`, + Injected: false, + }, + { + InterceptionID: id.String(), + MsgID: "resp_789", + Tool: "calculate", + Args: map[string]any{"a": float64(1), "b": float64(2)}, + Injected: false, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rec := &testutil.MockRecorder{} + base := &responsesInterceptionBase{ + id: id, + recorder: rec, + logger: slog.Make(), + } + + base.recordToolUsage(t.Context(), tc.response) + + tools := rec.RecordedToolUsages() + require.Len(t, tools, len(tc.expected)) + for i, got := range tools { + got.CreatedAt = time.Time{} + require.Equal(t, tc.expected[i], got) + } + }) + } +} + +func TestParseJSONArgs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + raw string + expected recorder.ToolArgs + }{ + { + name: "empty_string", + raw: "", + expected: "", + }, + { + name: "whitespace_only", + raw: " \t\n ", + expected: "", + }, + { + name: 
"invalid_json", + raw: "{not valid json}", + expected: "{not valid json}", + }, + { + name: "nested_object_with_trailing_spaces", + raw: ` {"user": {"name": "alice", "settings": {"theme": "dark", "notifications": true}}, "count": 42} `, + expected: map[string]any{ + "user": map[string]any{ + "name": "alice", + "settings": map[string]any{ + "theme": "dark", + "notifications": true, + }, + }, + "count": float64(42), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + base := &responsesInterceptionBase{} + result := base.parseFunctionCallJSONArgs(t.Context(), tc.raw) + require.Equal(t, tc.expected, result) + }) + } +} diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index 56792a4..726701d 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -55,6 +55,7 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * // response could be nil eg. fixtures/openai/responses/blocking/wrong_response_format.txtar if response != nil { i.recordUserPrompt(ctx, response.ID) + i.recordToolUsage(ctx, response) } if upstreamErr != nil && !respCopy.responseReceived.Load() { diff --git a/responses_integration_test.go b/responses_integration_test.go index f4ba1c6..57421f0 100644 --- a/responses_integration_test.go +++ b/responses_integration_test.go @@ -13,6 +13,8 @@ import ( "testing" "time" + "github.com/coder/aibridge/config" + aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" @@ -33,80 +35,113 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { name string fixture []byte streaming bool + expectModel string expectPromptRecorded string + expectToolRecorded *recorder.ToolUsageRecord }{ { name: "blocking_simple", fixture: fixtures.OaiResponsesBlockingSimple, + expectModel: "gpt-4o-mini", expectPromptRecorded: "tell me a joke", }, { name: 
"blocking_builtin_tool", fixture: fixtures.OaiResponsesBlockingBuiltinTool, + expectModel: "gpt-4.1", expectPromptRecorded: "Is 3 + 5 a prime number? Use the add function to calculate the sum.", + expectToolRecorded: &recorder.ToolUsageRecord{ + MsgID: "resp_0da6045a8b68fa5200695fa23dcc2c81a19c849f627abf8a31", + Tool: "add", + Args: map[string]any{"a": float64(3), "b": float64(5)}, + Injected: false, + }, + }, + { + name: "blocking_custom_tool", + fixture: fixtures.OaiResponsesBlockingCustomTool, + expectModel: "gpt-5", + expectPromptRecorded: "Use the code_exec tool to print hello world to the console.", + expectToolRecorded: &recorder.ToolUsageRecord{ + MsgID: "resp_09c614364030cdf000696942589da081a0af07f5859acb7308", + Tool: "code_exec", + Args: "print(\"hello world\")", + Injected: false, + }, }, { name: "blocking_conversation", fixture: fixtures.OaiResponsesBlockingConversation, + expectModel: "gpt-4o-mini", expectPromptRecorded: "explain why this is funny.", }, { name: "blocking_prev_response_id", fixture: fixtures.OaiResponsesBlockingPrevResponseID, + expectModel: "gpt-4o-mini", expectPromptRecorded: "explain why this is funny.", }, { name: "streaming_simple", fixture: fixtures.OaiResponsesStreamingSimple, streaming: true, + expectModel: "gpt-4o-mini", expectPromptRecorded: "tell me a joke", }, { name: "streaming_codex", fixture: fixtures.OaiResponsesStreamingCodex, streaming: true, + expectModel: "gpt-5-codex", expectPromptRecorded: "hello", }, { name: "streaming_builtin_tool", fixture: fixtures.OaiResponsesStreamingBuiltinTool, streaming: true, + expectModel: "gpt-4.1", expectPromptRecorded: "Is 3 + 5 a prime number? 
Use the add function to calculate the sum.", }, { name: "streaming_conversation", fixture: fixtures.OaiResponsesStreamingConversation, streaming: true, + expectModel: "gpt-4o-mini", expectPromptRecorded: "explain why this is funny.", }, { name: "streaming_prev_response_id", fixture: fixtures.OaiResponsesStreamingPrevResponseID, streaming: true, + expectModel: "gpt-4o-mini", expectPromptRecorded: "explain why this is funny.", }, { name: "stream_error", fixture: fixtures.OaiResponsesStreamingStreamError, streaming: true, + expectModel: "gpt-6.7", expectPromptRecorded: "hello_stream_error", }, { name: "stream_failure", fixture: fixtures.OaiResponsesStreamingStreamFailure, streaming: true, + expectModel: "gpt-6.7", expectPromptRecorded: "hello_stream_failure", }, // Original status code and body is kept even with wrong json format { - name: "blocking_wrong_format", - fixture: fixtures.OaiResponsesBlockingWrongResponseFormat, + name: "blocking_wrong_format", + fixture: fixtures.OaiResponsesBlockingWrongResponseFormat, + expectModel: "gpt-6.7", }, { name: "streaming_wrong_format", fixture: fixtures.OaiResponsesStreamingWrongResponseFormat, streaming: true, + expectModel: "gpt-6.7", expectPromptRecorded: "hello_wrong_format", }, } @@ -125,6 +160,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) + ctx = aibcontext.AsActor(ctx, userID, nil) mockAPI := newMockServer(ctx, t, files, nil) t.Cleanup(mockAPI.Close) @@ -146,12 +182,30 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { require.NoError(t, err) require.Equal(t, string(files[fixtResp]), string(got)) + interceptions := mockRecorder.RecordedInterceptions() + require.Len(t, interceptions, 1) + intc := interceptions[0] + require.Equal(t, intc.InitiatorID, userID) + require.Equal(t, intc.Provider, config.ProviderOpenAI) + require.Equal(t, intc.Model, tc.expectModel) + + recordedPrompts := mockRecorder.RecordedPromptUsages() 
if tc.expectPromptRecorded != "" { - recordedPrompts := mockRecorder.RecordedPromptUsages() + require.Len(t, recordedPrompts, 1) promptEq := func(pur *recorder.PromptUsageRecord) bool { return pur.Prompt == tc.expectPromptRecorded } - require.Truef(t, slices.ContainsFunc(recordedPrompts, promptEq), "promnt not found, got: %v, want: %v", recordedPrompts, tc.expectPromptRecorded) + require.Truef(t, slices.ContainsFunc(recordedPrompts, promptEq), "prompt not found, got: %v, want: %v", recordedPrompts, tc.expectPromptRecorded) } else { - require.Empty(t, mockRecorder.RecordedPromptUsages()) + require.Empty(t, recordedPrompts) } + + recordedTools := mockRecorder.RecordedToolUsages() + if tc.expectToolRecorded != nil { + require.Len(t, recordedTools, 1) + recordedTools[0].InterceptionID = tc.expectToolRecorded.InterceptionID // ignore interception id (interception id is not constant and response doesn't contain it) + recordedTools[0].CreatedAt = tc.expectToolRecorded.CreatedAt // ignore time + require.Equal(t, tc.expectToolRecorded, recordedTools[0]) + } else { + require.Empty(t, recordedTools) + } }) }