Skip to content

Commit 3f4ec16

Browse files
authored
fix: do not relay partial tool call events to openai client (#8)
Signed-off-by: Danny Kopping <[email protected]>
1 parent c5af652 commit 3f4ec16

File tree

2 files changed

+260
-101
lines changed

2 files changed

+260
-101
lines changed

bridge_integration_test.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,18 +688,44 @@ func TestOpenAIInjectedTools(t *testing.T) {
688688
decoder := oaissestream.NewDecoder(resp)
689689
stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil)
690690
var acc openai.ChatCompletionAccumulator
691+
detectedToolCalls := make(map[string]struct{})
691692
for stream.Next() {
692693
chunk := stream.Current()
693694
acc.AddChunk(chunk)
695+
696+
if len(chunk.Choices) == 0 {
697+
continue
698+
}
699+
700+
for _, c := range chunk.Choices {
701+
if len(c.Delta.ToolCalls) == 0 {
702+
continue
703+
}
704+
705+
for _, t := range c.Delta.ToolCalls {
706+
if t.Function.Name == "" {
707+
continue
708+
}
709+
710+
detectedToolCalls[t.Function.Name] = struct{}{}
711+
}
712+
}
694713
}
695714

715+
// Verify that no injected tool call events (or partials thereof) were sent to the client.
716+
require.Len(t, detectedToolCalls, 0)
717+
696718
message = acc.ChatCompletion
697719
require.NoError(t, stream.Err(), "stream error")
698720
} else {
699721
// Parse & unmarshal the response.
700722
body, err := io.ReadAll(resp.Body)
701723
require.NoError(t, err, "read response body")
702724
require.NoError(t, json.Unmarshal(body, &message), "unmarshal response")
725+
726+
// Verify that no injected tools were sent to the client.
727+
require.GreaterOrEqual(t, len(message.Choices), 1)
728+
require.Len(t, message.Choices[0].Message.ToolCalls, 0)
703729
}
704730

705731
require.GreaterOrEqual(t, len(message.Choices), 1)
@@ -796,7 +822,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
796822
// We must ALWAYS have 2 calls to the bridge for injected tool tests.
797823
require.Eventually(t, func() bool {
798824
return mockSrv.callCount.Load() == 2
799-
}, time.Second*25, time.Millisecond*50)
825+
}, time.Second*10, time.Millisecond*50)
800826

801827
return recorderClient, resp
802828
}

0 commit comments

Comments
 (0)