@@ -2,11 +2,14 @@ package responses
22
33import (
44 "testing"
5+ "time"
56
67 "cdr.dev/slog/v3"
78 "github.com/coder/aibridge/fixtures"
89 "github.com/coder/aibridge/internal/testutil"
10+ "github.com/coder/aibridge/recorder"
911 "github.com/google/uuid"
12+ oairesponses "github.com/openai/openai-go/v3/responses"
1013 "github.com/stretchr/testify/require"
1114)
1215
@@ -194,3 +197,173 @@ func TestRecordPrompt(t *testing.T) {
194197 })
195198 }
196199}
200+
201+ func TestRecordToolUsage (t * testing.T ) {
202+ t .Parallel ()
203+
204+ id := uuid .MustParse ("11111111-1111-1111-1111-111111111111" )
205+
206+ tests := []struct {
207+ name string
208+ response * oairesponses.Response
209+ expected []* recorder.ToolUsageRecord
210+ }{
211+ {
212+ name : "nil_response" ,
213+ response : nil ,
214+ expected : nil ,
215+ },
216+ {
217+ name : "empty_output" ,
218+ response : & oairesponses.Response {
219+ ID : "resp_123" ,
220+ },
221+ expected : nil ,
222+ },
223+ {
224+ name : "empty_tool_args" ,
225+ response : & oairesponses.Response {
226+ ID : "resp_456" ,
227+ Output : []oairesponses.ResponseOutputItemUnion {
228+ {
229+ Type : "function_call" ,
230+ Name : "get_weather" ,
231+ Arguments : "" ,
232+ },
233+ },
234+ },
235+ expected : []* recorder.ToolUsageRecord {
236+ {
237+ InterceptionID : id .String (),
238+ MsgID : "resp_456" ,
239+ Tool : "get_weather" ,
240+ Args : nil ,
241+ Injected : false ,
242+ },
243+ },
244+ },
245+ {
246+ name : "multiple_tool_calls" ,
247+ response : & oairesponses.Response {
248+ ID : "resp_789" ,
249+ Output : []oairesponses.ResponseOutputItemUnion {
250+ {
251+ Type : "function_call" ,
252+ Name : "get_weather" ,
253+ Arguments : `{"location": "NYC"}` ,
254+ },
255+ {
256+ Type : "message" ,
257+ ID : "msg_1" ,
258+ Role : "assistant" ,
259+ },
260+ {
261+ Type : "custom_tool_call" ,
262+ Name : "search" ,
263+ Input : `{\"query\": \"test\"}` ,
264+ },
265+ {
266+ Type : "function_call" ,
267+ Name : "calculate" ,
268+ Arguments : `{"a": 1, "b": 2}` ,
269+ },
270+ },
271+ },
272+ expected : []* recorder.ToolUsageRecord {
273+ {
274+ InterceptionID : id .String (),
275+ MsgID : "resp_789" ,
276+ Tool : "get_weather" ,
277+ Args : map [string ]any {"location" : "NYC" },
278+ Injected : false ,
279+ },
280+ {
281+ InterceptionID : id .String (),
282+ MsgID : "resp_789" ,
283+ Tool : "search" ,
284+ Args : `{\"query\": \"test\"}` ,
285+ Injected : false ,
286+ },
287+ {
288+ InterceptionID : id .String (),
289+ MsgID : "resp_789" ,
290+ Tool : "calculate" ,
291+ Args : map [string ]any {"a" : float64 (1 ), "b" : float64 (2 )},
292+ Injected : false ,
293+ },
294+ },
295+ },
296+ }
297+
298+ for _ , tc := range tests {
299+ t .Run (tc .name , func (t * testing.T ) {
300+ t .Parallel ()
301+
302+ rec := & testutil.MockRecorder {}
303+ base := & responsesInterceptionBase {
304+ id : id ,
305+ recorder : rec ,
306+ logger : slog .Make (),
307+ }
308+
309+ base .recordToolUsage (t .Context (), tc .response )
310+
311+ tools := rec .RecordedToolUsages ()
312+ require .Len (t , tools , len (tc .expected ))
313+ for i , got := range tools {
314+ got .CreatedAt = time.Time {}
315+ require .Equal (t , tc .expected [i ], got )
316+ }
317+ })
318+ }
319+ }
320+
321+ func TestParseJSONArgs (t * testing.T ) {
322+ t .Parallel ()
323+
324+ tests := []struct {
325+ name string
326+ raw string
327+ expected recorder.ToolArgs
328+ }{
329+ {
330+ name : "empty_string" ,
331+ raw : "" ,
332+ expected : nil ,
333+ },
334+ {
335+ name : "whitespace_only" ,
336+ raw : " \t \n " ,
337+ expected : nil ,
338+ },
339+ {
340+ name : "invalid_json" ,
341+ raw : "{not valid json}" ,
342+ expected : nil ,
343+ },
344+ {
345+ name : "nested_object" ,
346+ raw : ` {"user": {"name": "alice", "settings": {"theme": "dark", "notifications": true}}, "count": 42} ` ,
347+ expected : map [string ]any {
348+ "user" : map [string ]any {
349+ "name" : "alice" ,
350+ "settings" : map [string ]any {
351+ "theme" : "dark" ,
352+ "notifications" : true ,
353+ },
354+ },
355+ "count" : float64 (42 ),
356+ },
357+ },
358+ }
359+
360+ for _ , tc := range tests {
361+ t .Run (tc .name , func (t * testing.T ) {
362+ t .Parallel ()
363+
364+ base := & responsesInterceptionBase {}
365+ result := base .parseJSONArgs (tc .raw )
366+ require .Equal (t , tc .expected , result )
367+ })
368+ }
369+ }
0 commit comments