Skip to content

Commit 11d3c46

Browse files
committed
review changes
1 parent ce0d435 commit 11d3c46

File tree

3 files changed

+32
-16
lines changed

3 files changed

+32
-16
lines changed

intercept/responses/base.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/google/uuid"
2525
"github.com/openai/openai-go/v3/option"
2626
"github.com/openai/openai-go/v3/responses"
27+
oaiconst "github.com/openai/openai-go/v3/shared/constant"
2728
"github.com/tidwall/gjson"
2829
"github.com/tidwall/sjson"
2930
"go.opentelemetry.io/otel/attribute"
@@ -222,17 +223,18 @@ func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, respon
222223

223224
func (i *responsesInterceptionBase) recordToolUsage(ctx context.Context, response *responses.Response) {
224225
if response == nil {
226+
i.logger.Warn(ctx, "got empty response, skipping tool usage recording")
225227
return
226228
}
227229

228230
for _, item := range response.Output {
229231
var args recorder.ToolArgs
230232

231-
// recodring other function types to be considered: https://github.com/coder/aibridge/issues/121
233+
// recording other function types to be considered: https://github.com/coder/aibridge/issues/121
232234
switch item.Type {
233-
case "function_call":
234-
args = i.parseJSONArgs(item.Arguments)
235-
case "custom_tool_call":
235+
case string(oaiconst.ValueOf[oaiconst.FunctionCall]()):
236+
args = i.parseFunctionCallJSONArgs(ctx, item.Arguments)
237+
case string(oaiconst.ValueOf[oaiconst.CustomToolCall]()):
236238
args = item.Input
237239
default:
238240
continue
@@ -250,15 +252,17 @@ func (i *responsesInterceptionBase) recordToolUsage(ctx context.Context, respons
250252
}
251253
}
252254

253-
func (i *responsesInterceptionBase) parseJSONArgs(raw string) recorder.ToolArgs {
254-
if trimmed := strings.TrimSpace(raw); trimmed != "" {
255+
func (i *responsesInterceptionBase) parseFunctionCallJSONArgs(ctx context.Context, raw string) recorder.ToolArgs {
256+
trimmed := strings.TrimSpace(raw)
257+
if trimmed != "" {
255258
var args recorder.ToolArgs
256-
if err := json.Unmarshal([]byte(trimmed), &args); err == nil {
259+
if err := json.Unmarshal([]byte(trimmed), &args); err != nil {
260+
i.logger.Warn(ctx, "failed to unmarshal tool args", slog.Error(err))
261+
} else {
257262
return args
258263
}
259264
}
260-
261-
return nil
265+
return trimmed
262266
}
263267

264268
// responseCopier helper struct to send original response to the client

intercept/responses/base_test.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ func TestRecordToolUsage(t *testing.T) {
237237
InterceptionID: id.String(),
238238
MsgID: "resp_456",
239239
Tool: "get_weather",
240-
Args: nil,
240+
Args: "",
241241
Injected: false,
242242
},
243243
},
@@ -252,6 +252,11 @@ func TestRecordToolUsage(t *testing.T) {
252252
Name: "get_weather",
253253
Arguments: `{"location": "NYC"}`,
254254
},
255+
{
256+
Type: "function_call",
257+
Name: "bad_json_args",
258+
Arguments: `{"bad": args`,
259+
},
255260
{
256261
Type: "message",
257262
ID: "msg_1",
@@ -277,6 +282,13 @@ func TestRecordToolUsage(t *testing.T) {
277282
Args: map[string]any{"location": "NYC"},
278283
Injected: false,
279284
},
285+
{
286+
InterceptionID: id.String(),
287+
MsgID: "resp_789",
288+
Tool: "bad_json_args",
289+
Args: `{"bad": args`,
290+
Injected: false,
291+
},
280292
{
281293
InterceptionID: id.String(),
282294
MsgID: "resp_789",
@@ -329,20 +341,20 @@ func TestParseJSONArgs(t *testing.T) {
329341
{
330342
name: "empty_string",
331343
raw: "",
332-
expected: nil,
344+
expected: "",
333345
},
334346
{
335347
name: "whitespace_only",
336348
raw: " \t\n ",
337-
expected: nil,
349+
expected: "",
338350
},
339351
{
340352
name: "invalid_json",
341353
raw: "{not valid json}",
342-
expected: nil,
354+
expected: "{not valid json}",
343355
},
344356
{
345-
name: "nested_object",
357+
name: "nested_object_with_trailing_spaces",
346358
raw: ` {"user": {"name": "alice", "settings": {"theme": "dark", "notifications": true}}, "count": 42} `,
347359
expected: map[string]any{
348360
"user": map[string]any{
@@ -362,7 +374,7 @@ func TestParseJSONArgs(t *testing.T) {
362374
t.Parallel()
363375

364376
base := &responsesInterceptionBase{}
365-
result := base.parseJSONArgs(tc.raw)
377+
result := base.parseFunctionCallJSONArgs(t.Context(), tc.raw)
366378
require.Equal(t, tc.expected, result)
367379
})
368380
}

responses_integration_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) {
201201
recordedTools := mockRecorder.RecordedToolUsages()
202202
if tc.expectToolRecorded != nil {
203203
require.Len(t, recordedTools, 1)
204-
recordedTools[0].InterceptionID = tc.expectToolRecorded.InterceptionID // ignore interception id
204+
recordedTools[0].InterceptionID = tc.expectToolRecorded.InterceptionID // ignore interception id (interception id is not constant and response doesn't contain it)
205205
recordedTools[0].CreatedAt = tc.expectToolRecorded.CreatedAt // ignore time
206206
require.Equal(t, tc.expectToolRecorded, recordedTools[0])
207207
} else {

0 commit comments

Comments
 (0)