
Commit e0cbd6d

fix: set index of tool calls in stream mode (#1468)
**Description**

There is an `index` field (https://github.com/openai/openai-python/blob/4e8856576211064b09c0cc4a1ed35b82b169abe2/src/openai/types/chat/chat_completion_chunk.py#L48) in the tool-call definition for stream mode, and it is not optional. Third-party libraries such as OpenTelemetry (https://github.com/open-telemetry/opentelemetry-python-contrib/blob/447aac2b0fbe88998e970941c1160147f355ec73/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/patch.py#L285) are implemented based on this assumption.

The `index` field should be set according to the number of tool calls seen in the streaming events, which can be tracked by watching the message start/stop event types:

- AWS Bedrock: track the tool index via contentBlockStart/Stop events.
- GCP Anthropic: track the tool index via content_block_start/stop events.

**Discuss/suggestion**

- `ChatCompletionMessageToolCallParam` should not be reused in the response type.
- For both the Converse API and the Anthropic API, the stream chunks have strong dependencies on each other, so we should use complete responses to write unit tests.

---------

Signed-off-by: yxia216 <[email protected]>
Signed-off-by: secustor <[email protected]>
Signed-off-by: ydu208 <[email protected]>
Signed-off-by: Takeshi Yoneda <[email protected]>
Signed-off-by: Dan Sun <[email protected]>
Co-authored-by: Sebastian Poxhofer <[email protected]>
Co-authored-by: yuhongd <[email protected]>
Co-authored-by: ydu208 <[email protected]>
Co-authored-by: Dan Sun <[email protected]>
Co-authored-by: Takeshi Yoneda <[email protected]>
1 parent c87e367 commit e0cbd6d
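
For context on why the `index` field matters, the sketch below shows how a streaming client might reassemble complete tool calls from chunk deltas by merging fragments that share the same index. This is a minimal, hypothetical Go example written for this write-up, not code from this repository or from the OpenTelemetry instrumentation; all type and function names in it are illustrative.

package main

import "fmt"

// deltaToolCall is a hypothetical, minimal mirror of one entry of
// `choices[].delta.tool_calls[]` in an OpenAI streaming chunk.
type deltaToolCall struct {
	Index     int64  // required in stream mode; fragments are merged by this value
	ID        string // typically present only on the first fragment of a tool call
	Name      string // typically present only on the first fragment of a tool call
	Arguments string // partial JSON, concatenated across fragments
}

// accumulate merges streamed fragments into complete tool calls keyed by Index.
// Without a reliable Index, fragments from concurrent tool calls cannot be
// attributed to the right call.
func accumulate(chunks [][]deltaToolCall) map[int64]*deltaToolCall {
	calls := map[int64]*deltaToolCall{}
	for _, chunk := range chunks {
		for _, tc := range chunk {
			cur, ok := calls[tc.Index]
			if !ok {
				cur = &deltaToolCall{Index: tc.Index}
				calls[tc.Index] = cur
			}
			if tc.ID != "" {
				cur.ID = tc.ID
			}
			if tc.Name != "" {
				cur.Name = tc.Name
			}
			cur.Arguments += tc.Arguments // argument JSON arrives in pieces
		}
	}
	return calls
}

func main() {
	chunks := [][]deltaToolCall{
		{{Index: 0, ID: "call_0", Name: "cosine"}},
		{{Index: 0, Arguments: `{"x":`}, {Index: 1, ID: "call_1", Name: "sine"}},
		{{Index: 0, Arguments: `1}`}, {Index: 1, Arguments: `{"x":2}`}},
	}
	for i, c := range accumulate(chunks) {
		fmt.Printf("tool call %d: %s(%s)\n", i, c.Name, c.Arguments)
	}
}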

File tree

12 files changed: +737 additions, -169 deletions

internal/apischema/awsbedrock/awsbedrock.go

Lines changed: 35 additions & 0 deletions
@@ -403,9 +403,44 @@ type TokenUsage struct {
     CacheWriteInputTokens *int `json:"cacheWriteInputTokens,omitempty"`
 }
 
+// ConverseStreamEventType represents a distinct event type received from the Bedrock ConverseStream API.
+// Using a string type provides a clear, human-readable value, which is often helpful
+// when logging or debugging stream events.
+type ConverseStreamEventType string
+
+const (
+    // ConverseStreamEventTypeUnknown is the zero value and represents an uninitialized or unknown event type.
+    // This is a Go best practice to ensure the zero value is not a valid constant.
+    ConverseStreamEventTypeUnknown ConverseStreamEventType = ""
+
+    // ConverseStreamEventTypeMessageStart signals the start of the assistant's message.
+    ConverseStreamEventTypeMessageStart ConverseStreamEventType = "messageStart"
+
+    // ConverseStreamEventTypeContentBlockStart signals the start of a content block within the message (e.g., text or tool use).
+    ConverseStreamEventTypeContentBlockStart ConverseStreamEventType = "contentBlockStart"
+
+    // ConverseStreamEventTypeContentBlockDelta contains a chunk of content (e.g., text or tool input).
+    ConverseStreamEventTypeContentBlockDelta ConverseStreamEventType = "contentBlockDelta"
+
+    // ConverseStreamEventTypeContentBlockStop signals the end of a content block.
+    ConverseStreamEventTypeContentBlockStop ConverseStreamEventType = "contentBlockStop"
+
+    // ConverseStreamEventTypeMessageStop signals the end of the entire message.
+    ConverseStreamEventTypeMessageStop ConverseStreamEventType = "messageStop"
+
+    // ConverseStreamEventTypeMetadata contains usage and latency information.
+    ConverseStreamEventTypeMetadata ConverseStreamEventType = "metadata"
+)
+
+// String implements the fmt.Stringer interface for better logging/printing.
+func (c ConverseStreamEventType) String() string {
+    return string(c)
+}
+
 // ConverseStreamEvent is the union of all possible event types in the AWS Bedrock API:
 // https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html
 type ConverseStreamEvent struct {
+    EventType         string                                `json:"eventType"`
     ContentBlockIndex int                                   `json:"contentBlockIndex,omitempty"`
     Delta             *ConverseStreamEventContentBlockDelta `json:"delta,omitempty"`
     Role              *string                               `json:"role,omitempty"`
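
As a rough illustration of how the new EventType field and these constants can drive index assignment, the following stand-alone sketch counts tool-use content blocks by watching contentBlockStart/contentBlockStop events. It is a simplified stand-in written for this write-up, not the translator's actual implementation; the bedrockEvent struct and its ToolUse flag are assumptions made for brevity.

package main

import "fmt"

// bedrockEvent is a simplified stand-in for awsbedrock.ConverseStreamEvent.
type bedrockEvent struct {
	EventType string // e.g. "contentBlockStart", "contentBlockStop"
	ToolUse   bool   // true when a contentBlockStart opens a toolUse block (assumed flag)
}

// toolCallIndexes walks the event stream and assigns a 0-based index to every
// tool-use content block, incrementing once the block closes so the next tool
// call in the same message receives the next index.
func toolCallIndexes(events []bedrockEvent) []int64 {
	var indexes []int64
	var next int64
	inToolBlock := false
	for _, ev := range events {
		switch ev.EventType {
		case "contentBlockStart":
			if ev.ToolUse {
				inToolBlock = true
				indexes = append(indexes, next)
			}
		case "contentBlockStop":
			if inToolBlock {
				inToolBlock = false
				next++ // the next tool-use block belongs to a new tool call
			}
		}
	}
	return indexes
}

func main() {
	events := []bedrockEvent{
		{EventType: "messageStart"},
		{EventType: "contentBlockStart", ToolUse: true},
		{EventType: "contentBlockStop"},
		{EventType: "contentBlockStart", ToolUse: true},
		{EventType: "contentBlockStop"},
		{EventType: "messageStop"},
	}
	fmt.Println(toolCallIndexes(events)) // [0 1]
}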

internal/apischema/openai/openai.go

Lines changed: 15 additions & 7 deletions
@@ -546,8 +546,6 @@ type ChatCompletionMessageToolCallFunctionParam struct {
 }
 
 type ChatCompletionMessageToolCallParam struct {
-    // Add this Index field. It is required for streaming.
-    Index *int `json:"index,omitempty"`
     // The ID of the tool call.
     ID *string `json:"id"`
     // The function that the model called.
@@ -1350,14 +1348,24 @@ type ChatCompletionResponseChunkChoice struct {
     FinishReason ChatCompletionChoicesFinishReason `json:"finish_reason,omitempty"`
 }
 
+type ChatCompletionChunkChoiceDeltaToolCall struct {
+    Index int64 `json:"index"`
+    // The ID of the tool call.
+    ID *string `json:"id"`
+    // The function that the model called.
+    Function ChatCompletionMessageToolCallFunctionParam `json:"function"`
+    // The type of the tool. Currently, only `function` is supported.
+    Type ChatCompletionMessageToolCallType `json:"type,omitempty"`
+}
+
 // ChatCompletionResponseChunkChoiceDelta is described in the OpenAI API documentation:
 // https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-choices
 type ChatCompletionResponseChunkChoiceDelta struct {
-    Content          *string                              `json:"content,omitempty"`
-    Role             string                               `json:"role,omitempty"`
-    ToolCalls        []ChatCompletionMessageToolCallParam `json:"tool_calls,omitempty"`
-    Annotations      *[]Annotation                        `json:"annotations,omitempty"`
-    ReasoningContent *AWSBedrockStreamReasoningContent    `json:"reasoning_content,omitempty"`
+    Content          *string                                  `json:"content,omitempty"`
+    Role             string                                   `json:"role,omitempty"`
+    ToolCalls        []ChatCompletionChunkChoiceDeltaToolCall `json:"tool_calls,omitempty"`
+    Annotations      *[]Annotation                            `json:"annotations,omitempty"`
+    ReasoningContent *AWSBedrockStreamReasoningContent        `json:"reasoning_content,omitempty"`
 }
 
 // Error is described in the OpenAI API documentation
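
The behavioral difference from the removed pointer field with `omitempty` is that the new plain int64 always serializes, so a first tool call with index 0 is no longer dropped from the JSON. The stand-alone snippet below demonstrates this with local copies of just the relevant fields; it does not import this package, and the struct names are illustrative.

package main

import (
	"encoding/json"
	"fmt"
)

// oldToolCall mirrors the removed layout: a pointer with omitempty drops index 0 entirely.
type oldToolCall struct {
	Index *int   `json:"index,omitempty"`
	ID    string `json:"id"`
}

// newToolCall mirrors the new layout: a plain int64 always serializes, even when it is 0.
type newToolCall struct {
	Index int64  `json:"index"`
	ID    string `json:"id"`
}

func main() {
	o, _ := json.Marshal(oldToolCall{ID: "call_0"}) // Index nil -> key omitted
	n, _ := json.Marshal(newToolCall{ID: "call_0"}) // Index 0   -> key emitted
	fmt.Println(string(o)) // {"id":"call_0"}
	fmt.Println(string(n)) // {"index":0,"id":"call_0"}
}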

internal/apischema/openai/openai_test.go

Lines changed: 3 additions & 2 deletions
@@ -1288,19 +1288,20 @@ func TestChatCompletionResponseChunkChoice(t *testing.T) {
             Index: 0,
             Delta: &ChatCompletionResponseChunkChoiceDelta{
                 Role: "assistant",
-                ToolCalls: []ChatCompletionMessageToolCallParam{
+                ToolCalls: []ChatCompletionChunkChoiceDeltaToolCall{
                     {
                         ID:   ptr.To("tooluse_QklrEHKjRu6Oc4BQUfy7ZQ"),
                         Type: "function",
                         Function: ChatCompletionMessageToolCallFunctionParam{
                             Name:      "cosine",
                             Arguments: "",
                         },
+                        Index: 0,
                     },
                 },
             },
         },
-        expected: `{"index":0,"delta":{"role":"assistant","tool_calls":[{"id":"tooluse_QklrEHKjRu6Oc4BQUfy7ZQ","function":{"arguments":"","name":"cosine"},"type":"function"}]}}`,
+        expected: `{"index":0,"delta":{"role":"assistant","tool_calls":[{"id":"tooluse_QklrEHKjRu6Oc4BQUfy7ZQ","function":{"arguments":"","name":"cosine"},"type":"function", "index": 0}]}}`,
     },
     {
         name: "streaming chunk with annotations",

internal/extproc/translator/gemini_helper.go

Lines changed: 43 additions & 1 deletion
@@ -647,6 +647,48 @@ func extractToolCallsFromGeminiParts(parts []*genai.Part) ([]openai.ChatCompleti
     return toolCalls, nil
 }
 
+// extractToolCallsFromGeminiPartsStream extracts tool calls from Gemini parts for streaming responses.
+// Each tool call is assigned an incremental index starting from 0, matching OpenAI's streaming protocol.
+// Returns ChatCompletionChunkChoiceDeltaToolCall types suitable for streaming responses, or nil if no tool calls are found.
+func extractToolCallsFromGeminiPartsStream(parts []*genai.Part) ([]openai.ChatCompletionChunkChoiceDeltaToolCall, error) {
+    var toolCalls []openai.ChatCompletionChunkChoiceDeltaToolCall
+    toolCallIndex := int64(0)
+
+    for _, part := range parts {
+        if part == nil || part.FunctionCall == nil {
+            continue
+        }
+
+        // Convert function call arguments to JSON string.
+        args, err := json.Marshal(part.FunctionCall.Args)
+        if err != nil {
+            return nil, fmt.Errorf("failed to marshal function arguments: %w", err)
+        }
+
+        // Generate a random ID for the tool call.
+        toolCallID := uuid.New().String()
+
+        toolCall := openai.ChatCompletionChunkChoiceDeltaToolCall{
+            ID:   &toolCallID,
+            Type: openai.ChatCompletionMessageToolCallTypeFunction,
+            Function: openai.ChatCompletionMessageToolCallFunctionParam{
+                Name:      part.FunctionCall.Name,
+                Arguments: string(args),
+            },
+            Index: toolCallIndex,
+        }
+
+        toolCalls = append(toolCalls, toolCall)
+        toolCallIndex++
+    }
+
+    if len(toolCalls) == 0 {
+        return nil, nil
+    }
+
+    return toolCalls, nil
+}
+
 // geminiUsageToOpenAIUsage converts Gemini usage metadata to OpenAI usage.
 func geminiUsageToOpenAIUsage(metadata *genai.GenerateContentResponseUsageMetadata) openai.Usage {
     if metadata == nil {
@@ -746,7 +788,7 @@ func geminiCandidatesToOpenAIStreamingChoices(candidates []*genai.Candidate, res
     }
 
     // Extract tool calls if any.
-    toolCalls, err := extractToolCallsFromGeminiParts(candidate.Content.Parts)
+    toolCalls, err := extractToolCallsFromGeminiPartsStream(candidate.Content.Parts)
     if err != nil {
         return nil, fmt.Errorf("error extracting tool calls: %w", err)
     }
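
To make the indexing contract of extractToolCallsFromGeminiPartsStream concrete, here is a small stand-alone sketch that mimics its behavior: every function call in a chunk receives a freshly generated UUID and a 0-based, strictly increasing index. The functionCall and streamToolCall types below are simplified stand-ins invented for this example, not the genai or openai types the real function uses; only the uuid dependency matches the diff above.

package main

import (
	"fmt"

	"github.com/google/uuid"
)

// functionCall is a simplified stand-in for a Gemini function-call part.
type functionCall struct {
	Name string
}

// streamToolCall mirrors the fields the extractor fills in for each call.
type streamToolCall struct {
	Index int64
	ID    string
	Name  string
}

// assignIndexes mimics the extractor's indexing: a fresh random ID per call
// and a 0-based index that increases with each function call in the chunk.
func assignIndexes(calls []functionCall) []streamToolCall {
	out := make([]streamToolCall, 0, len(calls))
	for i, c := range calls {
		out = append(out, streamToolCall{
			Index: int64(i),
			ID:    uuid.New().String(),
			Name:  c.Name,
		})
	}
	return out
}

func main() {
	calls := []functionCall{{Name: "cosine"}, {Name: "sine"}}
	for _, tc := range assignIndexes(calls) {
		fmt.Printf("index=%d id=%s name=%s\n", tc.Index, tc.ID, tc.Name)
	}
}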
