From 30c3e4db691d3697c612e224df46d080dbebd01b Mon Sep 17 00:00:00 2001
From: yxia216
Date: Wed, 3 Dec 2025 01:31:38 -0500
Subject: [PATCH 1/2] init

Signed-off-by: yxia216
---
 internal/apischema/openai/openai.go           |  42 +++-
 internal/translator/openai_awsbedrock.go      |  17 +-
 internal/translator/openai_awsbedrock_test.go |   2 +-
 internal/translator/openai_gcpanthropic.go    |  72 +++++--
 .../translator/openai_gcpanthropic_test.go    | 198 +++++++++++++++++-
 5 files changed, 307 insertions(+), 24 deletions(-)

diff --git a/internal/apischema/openai/openai.go b/internal/apischema/openai/openai.go
index d1dd9c3a7..a3cad88ee 100644
--- a/internal/apischema/openai/openai.go
+++ b/internal/apischema/openai/openai.go
@@ -10,6 +10,7 @@ package openai

 import (
 	"bytes"
+	"encoding/base64"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -515,8 +516,8 @@ type ChatCompletionAssistantMessageParamContent struct {
 	Text *string `json:"text,omitempty"`

 	// The signature for a thinking block.
-	Signature *string `json:"signature,omitempty"`
-	RedactedContent []byte `json:"redactedContent,omitempty"`
+	Signature       *string               `json:"signature,omitempty"`
+	RedactedContent *RedactedContentUnion `json:"redactedContent,omitempty"`

 	*AnthropicContentFields `json:",inline,omitempty"`
 }
@@ -1583,6 +1584,43 @@ func (e EmbeddingUnion) MarshalJSON() ([]byte, error) {
 	return json.Marshal(e.Value)
 }

+// RedactedContentUnion is a union type that can handle both []byte and string formats.
+// AWS Bedrock uses []byte while GCP Anthropic uses string.
+type RedactedContentUnion struct {
+	Value any
+}
+
+// UnmarshalJSON implements json.Unmarshaler to handle both []byte and string formats.
+func (r *RedactedContentUnion) UnmarshalJSON(data []byte) error {
+	// JSON carries both formats as a string, so unmarshal a string and then decide which form it is.
+	var str string
+	if err := json.Unmarshal(data, &str); err == nil {
+		// If the string decodes as base64, treat it as []byte (the AWS Bedrock form).
+		if decoded, err := base64.StdEncoding.DecodeString(str); err == nil {
+			r.Value = decoded
+			return nil
+		}
+		// Otherwise keep it as a plain string (the GCP Anthropic form).
+		r.Value = str
+		return nil
+	}
+
+	return errors.New("redactedContent must be either []byte (base64 encoded) or string")
+}
+
+// MarshalJSON implements json.Marshaler.
+func (r RedactedContentUnion) MarshalJSON() ([]byte, error) {
+	switch v := r.Value.(type) {
+	case []byte:
+		// Encode []byte as a base64 string.
+		return json.Marshal(base64.StdEncoding.EncodeToString(v))
+	case string:
+		return json.Marshal(v)
+	default:
+		return json.Marshal(r.Value)
+	}
+}
+
 // EmbeddingUsage represents the usage information for an embeddings request.
// https://platform.openai.com/docs/api-reference/embeddings/object#embeddings/object-usage type EmbeddingUsage struct { diff --git a/internal/translator/openai_awsbedrock.go b/internal/translator/openai_awsbedrock.go index d8e454e1a..63b4dd1d4 100644 --- a/internal/translator/openai_awsbedrock.go +++ b/internal/translator/openai_awsbedrock.go @@ -312,11 +312,18 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) openAIMessageToBedrockMes } case openai.ChatCompletionAssistantMessageParamContentTypeRedactedThinking: if content.RedactedContent != nil { - contentBlocks = append(contentBlocks, &awsbedrock.ContentBlock{ - ReasoningContent: &awsbedrock.ReasoningContentBlock{ - RedactedContent: content.RedactedContent, - }, - }) + switch v := content.RedactedContent.Value.(type) { + case []byte: + contentBlocks = append(contentBlocks, &awsbedrock.ContentBlock{ + ReasoningContent: &awsbedrock.ReasoningContentBlock{ + RedactedContent: v, + }, + }) + case string: + return nil, fmt.Errorf("AWS Bedrock does not support string format for RedactedContent, expected []byte") + default: + return nil, fmt.Errorf("unsupported RedactedContent type: %T, expected []byte", v) + } } case openai.ChatCompletionAssistantMessageParamContentTypeRefusal: if content.Refusal != nil { diff --git a/internal/translator/openai_awsbedrock_test.go b/internal/translator/openai_awsbedrock_test.go index 01d479775..ecf56730a 100644 --- a/internal/translator/openai_awsbedrock_test.go +++ b/internal/translator/openai_awsbedrock_test.go @@ -986,7 +986,7 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T) Value: []openai.ChatCompletionAssistantMessageParamContent{ { Type: openai.ChatCompletionAssistantMessageParamContentTypeRedactedThinking, - RedactedContent: []byte{104, 101, 108, 108, 111}, + RedactedContent: &openai.RedactedContentUnion{Value: []byte{104, 101, 108, 108, 111}}, }, }, }, diff --git a/internal/translator/openai_gcpanthropic.go b/internal/translator/openai_gcpanthropic.go index 60873b075..60550819e 100644 --- a/internal/translator/openai_gcpanthropic.go +++ b/internal/translator/openai_gcpanthropic.go @@ -375,6 +375,46 @@ func anthropicRoleToOpenAIRole(role anthropic.MessageParamRole) (string, error) } } +// processAssistantContent processes a single ChatCompletionAssistantMessageParamContent and returns the corresponding Anthropic content block. 
+func processAssistantContent(content openai.ChatCompletionAssistantMessageParamContent) (*anthropic.ContentBlockParamUnion, error) {
+	switch content.Type {
+	case openai.ChatCompletionAssistantMessageParamContentTypeRefusal:
+		if content.Refusal != nil {
+			block := anthropic.NewTextBlock(*content.Refusal)
+			return &block, nil
+		}
+	case openai.ChatCompletionAssistantMessageParamContentTypeText:
+		if content.Text != nil {
+			textBlock := anthropic.NewTextBlock(*content.Text)
+			if isCacheEnabled(content.AnthropicContentFields) {
+				textBlock.OfText.CacheControl = content.CacheControl
+			}
+			return &textBlock, nil
+		}
+	case openai.ChatCompletionAssistantMessageParamContentTypeThinking:
+		// Thinking blocks cannot be cached: https://platform.claude.com/docs/en/build-with-claude/prompt-caching
+		if content.Text != nil && content.Signature != nil {
+			thinkBlock := anthropic.NewThinkingBlock(*content.Text, *content.Signature)
+			return &thinkBlock, nil
+		}
+	case openai.ChatCompletionAssistantMessageParamContentTypeRedactedThinking:
+		if content.RedactedContent != nil {
+			switch v := content.RedactedContent.Value.(type) {
+			case string:
+				redactedThinkingBlock := anthropic.NewRedactedThinkingBlock(v)
+				return &redactedThinkingBlock, nil
+			case []byte:
+				return nil, fmt.Errorf("GCP Anthropic does not support []byte format for RedactedContent, expected string")
+			default:
+				return nil, fmt.Errorf("unsupported RedactedContent type: %T, expected string", v)
+			}
+		}
+	default:
+		return nil, fmt.Errorf("content type not supported: %v", content.Type)
+	}
+	return nil, nil
+}
+
 // openAIMessageToAnthropicMessageRoleAssistant converts an OpenAI assistant message to Anthropic content blocks.
 // The tool_use content is appended to the Anthropic message content list if tool_calls are present.
func openAIMessageToAnthropicMessageRoleAssistant(openAiMessage *openai.ChatCompletionAssistantMessageParam) (anthropicMsg anthropic.MessageParam, err error) { @@ -382,22 +422,24 @@ func openAIMessageToAnthropicMessageRoleAssistant(openAiMessage *openai.ChatComp if v, ok := openAiMessage.Content.Value.(string); ok && len(v) > 0 { contentBlocks = append(contentBlocks, anthropic.NewTextBlock(v)) } else if content, ok := openAiMessage.Content.Value.(openai.ChatCompletionAssistantMessageParamContent); ok { - switch content.Type { - case openai.ChatCompletionAssistantMessageParamContentTypeRefusal: - if content.Refusal != nil { - contentBlocks = append(contentBlocks, anthropic.NewTextBlock(*content.Refusal)) - } - case openai.ChatCompletionAssistantMessageParamContentTypeText: - if content.Text != nil { - textBlock := anthropic.NewTextBlock(*content.Text) - if isCacheEnabled(content.AnthropicContentFields) { - textBlock.OfText.CacheControl = content.CacheControl - } - contentBlocks = append(contentBlocks, textBlock) + // Handle single content object + var block *anthropic.ContentBlockParamUnion + block, err = processAssistantContent(content) + if err != nil { + return anthropicMsg, err + } else if block != nil { + contentBlocks = append(contentBlocks, *block) + } + } else if contents, ok := openAiMessage.Content.Value.([]openai.ChatCompletionAssistantMessageParamContent); ok { + // Handle array of content objects + for _, content := range contents { + var block *anthropic.ContentBlockParamUnion + block, err = processAssistantContent(content) + if err != nil { + return anthropicMsg, err + } else if block != nil { + contentBlocks = append(contentBlocks, *block) } - default: - err = fmt.Errorf("content type not supported: %v", content.Type) - return } } diff --git a/internal/translator/openai_gcpanthropic_test.go b/internal/translator/openai_gcpanthropic_test.go index 3044eac5f..740f9b188 100644 --- a/internal/translator/openai_gcpanthropic_test.go +++ b/internal/translator/openai_gcpanthropic_test.go @@ -7,6 +7,7 @@ package translator import ( "bytes" + "encoding/base64" "encoding/json" "fmt" "io" @@ -788,6 +789,147 @@ func TestMessageTranslation(t *testing.T) { }, }, }, + { + name: "assistant message with thinking content", + inputMessages: []openai.ChatCompletionMessageParamUnion{ + { + OfAssistant: &openai.ChatCompletionAssistantMessageParam{ + Content: openai.StringOrAssistantRoleContentUnion{ + Value: []openai.ChatCompletionAssistantMessageParamContent{ + { + Type: openai.ChatCompletionAssistantMessageParamContentTypeThinking, + Text: ptr.To("Let me think about this step by step..."), + Signature: ptr.To("signature-123"), + }, + }, + }, + Role: openai.ChatMessageRoleAssistant, + }, + }, + }, + expectedAnthropicMsgs: []anthropic.MessageParam{ + { + Role: anthropic.MessageParamRoleAssistant, + Content: []anthropic.ContentBlockParamUnion{ + anthropic.NewThinkingBlock("Let me think about this step by step...", "signature-123"), + }, + }, + }, + }, + { + name: "assistant message with thinking content missing signature", + inputMessages: []openai.ChatCompletionMessageParamUnion{ + { + OfAssistant: &openai.ChatCompletionAssistantMessageParam{ + Content: openai.StringOrAssistantRoleContentUnion{ + Value: []openai.ChatCompletionAssistantMessageParamContent{ + { + Type: openai.ChatCompletionAssistantMessageParamContentTypeThinking, + Text: ptr.To("Let me think about this step by step..."), + // Missing signature - should not create thinking block + }, + }, + }, + Role: openai.ChatMessageRoleAssistant, 
+ }, + }, + }, + expectedAnthropicMsgs: []anthropic.MessageParam{ + { + Role: anthropic.MessageParamRoleAssistant, + Content: []anthropic.ContentBlockParamUnion{}, + }, + }, + }, + { + name: "assistant message with thinking content missing text", + inputMessages: []openai.ChatCompletionMessageParamUnion{ + { + OfAssistant: &openai.ChatCompletionAssistantMessageParam{ + Content: openai.StringOrAssistantRoleContentUnion{ + Value: []openai.ChatCompletionAssistantMessageParamContent{ + { + Type: openai.ChatCompletionAssistantMessageParamContentTypeThinking, + Signature: ptr.To("signature-123"), + // Missing text - should not create thinking block + }, + }, + }, + Role: openai.ChatMessageRoleAssistant, + }, + }, + }, + expectedAnthropicMsgs: []anthropic.MessageParam{ + { + Role: anthropic.MessageParamRoleAssistant, + Content: []anthropic.ContentBlockParamUnion{}, + }, + }, + }, + { + name: "assistant message with redacted thinking content (string)", + inputMessages: []openai.ChatCompletionMessageParamUnion{ + { + OfAssistant: &openai.ChatCompletionAssistantMessageParam{ + Content: openai.StringOrAssistantRoleContentUnion{ + Value: []openai.ChatCompletionAssistantMessageParamContent{ + { + Type: openai.ChatCompletionAssistantMessageParamContentTypeRedactedThinking, + RedactedContent: &openai.RedactedContentUnion{Value: "redacted content as string"}, + }, + }, + }, + Role: openai.ChatMessageRoleAssistant, + }, + }, + }, + expectedAnthropicMsgs: []anthropic.MessageParam{ + { + Role: anthropic.MessageParamRoleAssistant, + Content: []anthropic.ContentBlockParamUnion{ + anthropic.NewRedactedThinkingBlock("redacted content as string"), + }, + }, + }, + }, + { + name: "assistant message with redacted thinking content ([]byte) - should fail", + inputMessages: []openai.ChatCompletionMessageParamUnion{ + { + OfAssistant: &openai.ChatCompletionAssistantMessageParam{ + Content: openai.StringOrAssistantRoleContentUnion{ + Value: []openai.ChatCompletionAssistantMessageParamContent{ + { + Type: openai.ChatCompletionAssistantMessageParamContentTypeRedactedThinking, + RedactedContent: &openai.RedactedContentUnion{Value: []byte("redacted content as bytes")}, + }, + }, + }, + Role: openai.ChatMessageRoleAssistant, + }, + }, + }, + expectErr: true, + }, + { + name: "assistant message with redacted thinking content (unsupported type) - should fail", + inputMessages: []openai.ChatCompletionMessageParamUnion{ + { + OfAssistant: &openai.ChatCompletionAssistantMessageParam{ + Content: openai.StringOrAssistantRoleContentUnion{ + Value: []openai.ChatCompletionAssistantMessageParamContent{ + { + Type: openai.ChatCompletionAssistantMessageParamContentTypeRedactedThinking, + RedactedContent: &openai.RedactedContentUnion{Value: 123}, + }, + }, + }, + Role: openai.ChatMessageRoleAssistant, + }, + }, + }, + expectErr: true, + }, } for _, tt := range tests { @@ -848,7 +990,61 @@ func TestMessageTranslation(t *testing.T) { } } -func TestOpenAIToGCPAnthropicTranslator_ResponseError(t *testing.T) { +// TestRedactedContentUnionSerialization tests the JSON marshaling/unmarshaling of RedactedContentUnion +func TestRedactedContentUnionSerialization(t *testing.T) { + tests := []struct { + name string + input string + expectedValue any + expectError bool + }{ + { + name: "string value", + input: `"plain string"`, + expectedValue: "plain string", + }, + { + name: "base64 encoded bytes", + input: `"aGVsbG8gd29ybGQ="`, // "hello world" in base64 + expectedValue: []byte("hello world"), + }, + { + name: "invalid json", + input: `{invalid}`, + 
expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var union openai.RedactedContentUnion + err := json.Unmarshal([]byte(tt.input), &union) + + if tt.expectError { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, tt.expectedValue, union.Value) + + // Test marshaling back + marshaled, err := json.Marshal(union) + require.NoError(t, err) + + // For byte arrays, check they're base64 encoded + if bytes, ok := tt.expectedValue.([]byte); ok { + expected := base64.StdEncoding.EncodeToString(bytes) + require.Equal(t, `"`+expected+`"`, string(marshaled)) + } else { + // For strings, check round-trip + require.Equal(t, tt.input, string(marshaled)) + } + }) + } +} + +func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_ResponseError(t *testing.T) { tests := []struct { name string responseHeaders map[string]string From e09547c4800c9265fdb7d8b0fba4658a3b774d9e Mon Sep 17 00:00:00 2001 From: yxia216 Date: Thu, 4 Dec 2025 15:20:07 -0500 Subject: [PATCH 2/2] update Signed-off-by: yxia216 --- cmd/aigw/run.go | 17 +- internal/translator/openai_gcpanthropic.go | 45 ++++- .../translator/openai_gcpanthropic_stream.go | 40 +++-- .../openai_gcpanthropic_stream_test.go | 156 +++++++++++++++- .../translator/openai_gcpanthropic_test.go | 88 +++++++++ tests/bench/aigw.yaml | 123 +++++++++++++ tests/bench/bench_test.go | 167 ++++++++++++++++++ tests/internal/testmcp/server.go | 4 + 8 files changed, 604 insertions(+), 36 deletions(-) create mode 100644 tests/bench/aigw.yaml create mode 100644 tests/bench/bench_test.go diff --git a/cmd/aigw/run.go b/cmd/aigw/run.go index d430196d1..de7dbdbe8 100644 --- a/cmd/aigw/run.go +++ b/cmd/aigw/run.go @@ -136,14 +136,15 @@ func run(ctx context.Context, c cmdRun, o *runOpts, stdout, stderr io.Writer) er // Do the translation of the given AI Gateway resources Yaml into Envoy Gateway resources and write them to the file. resourcesBuf := &bytes.Buffer{} runCtx := &runCmdContext{ - isDebug: c.Debug, - envoyGatewayResourcesOut: resourcesBuf, - stderrLogger: debugLogger, - stderr: stderr, - tmpdir: filepath.Dir(o.logPath), // runDir - udsPath: o.extprocUDSPath, - adminPort: c.AdminPort, - extProcLauncher: o.extProcLauncher, + isDebug: c.Debug, + envoyGatewayResourcesOut: resourcesBuf, + stderrLogger: debugLogger, + stderr: stderr, + tmpdir: filepath.Dir(o.logPath), // runDir + udsPath: o.extprocUDSPath, + adminPort: c.AdminPort, + extProcLauncher: o.extProcLauncher, + mcpSessionEncryptionIterations: c.MCPSessionEncryptionIterations, } // If any of the configured MCP servers is using stdio, set up the streamable HTTP proxies for them if err = proxyStdioMCPServers(ctx, debugLogger, c.mcpConfig); err != nil { diff --git a/internal/translator/openai_gcpanthropic.go b/internal/translator/openai_gcpanthropic.go index 60550819e..291fbcdb9 100644 --- a/internal/translator/openai_gcpanthropic.go +++ b/internal/translator/openai_gcpanthropic.go @@ -22,6 +22,7 @@ import ( openAIconstant "github.com/openai/openai-go/shared/constant" "github.com/tidwall/sjson" + "github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/metrics" @@ -859,15 +860,43 @@ func (o *openAIToGCPAnthropicTranslatorV1ChatCompletion) ResponseBody(_ map[stri for i := range anthropicResp.Content { // NOTE: Content structure is massive, do not range over values. 
output := &anthropicResp.Content[i] - if output.Type == string(constant.ValueOf[constant.ToolUse]()) && output.ID != "" { - toolCalls, toolErr := anthropicToolUseToOpenAICalls(output) - if toolErr != nil { - return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to convert anthropic tool use to openai tool call: %w", toolErr) + switch output.Type { + case string(constant.ValueOf[constant.ToolUse]()): + if output.ID != "" { + toolCalls, toolErr := anthropicToolUseToOpenAICalls(output) + if toolErr != nil { + return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to convert anthropic tool use to openai tool call: %w", toolErr) + } + choice.Message.ToolCalls = append(choice.Message.ToolCalls, toolCalls...) + } + case string(constant.ValueOf[constant.Text]()): + if output.Text != "" { + if choice.Message.Content == nil { + choice.Message.Content = &output.Text + } } - choice.Message.ToolCalls = append(choice.Message.ToolCalls, toolCalls...) - } else if output.Type == string(constant.ValueOf[constant.Text]()) && output.Text != "" { - if choice.Message.Content == nil { - choice.Message.Content = &output.Text + case string(constant.ValueOf[constant.Thinking]()): + if output.Thinking != "" { + choice.Message.ReasoningContent = &openai.ReasoningContentUnion{ + Value: &openai.ReasoningContent{ + ReasoningContent: &awsbedrock.ReasoningContentBlock{ + ReasoningText: &awsbedrock.ReasoningTextBlock{ + Text: output.Thinking, + Signature: output.Signature, + }, + }, + }, + } + } + case string(constant.ValueOf[constant.RedactedThinking]()): + if output.Data != "" { + choice.Message.ReasoningContent = &openai.ReasoningContentUnion{ + Value: &openai.ReasoningContent{ + ReasoningContent: &awsbedrock.ReasoningContentBlock{ + RedactedContent: []byte(output.Data), + }, + }, + } } } } diff --git a/internal/translator/openai_gcpanthropic_stream.go b/internal/translator/openai_gcpanthropic_stream.go index d43679ce1..0f9a841ba 100644 --- a/internal/translator/openai_gcpanthropic_stream.go +++ b/internal/translator/openai_gcpanthropic_stream.go @@ -14,7 +14,6 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/shared/constant" - "k8s.io/utils/ptr" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" "github.com/envoyproxy/ai-gateway/internal/internalapi" @@ -22,10 +21,7 @@ import ( tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" ) -var ( - sseEventPrefix = []byte("event:") - emptyStrPtr = ptr.To("") -) +var sseEventPrefix = []byte("event:") // streamingToolCall holds the state for a single tool call that is being streamed. type streamingToolCall struct { @@ -265,16 +261,7 @@ func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, dat } return p.constructOpenAIChatCompletionChunk(delta, ""), nil } - if event.ContentBlock.Type == string(constant.ValueOf[constant.Thinking]()) { - delta := openai.ChatCompletionResponseChunkChoiceDelta{Content: emptyStrPtr} - return p.constructOpenAIChatCompletionChunk(delta, ""), nil - } - - if event.ContentBlock.Type == string(constant.ValueOf[constant.RedactedThinking]()) { - // This is a latency-hiding event, ignore it. 
-			return nil, nil
-		}
-
+		// No need to emit an empty content chunk for a thinking content_block_start.
 		return nil, nil

 	case string(constant.ValueOf[constant.MessageDelta]()):
@@ -304,10 +291,28 @@ func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, dat
 			return nil, fmt.Errorf("unmarshal content_block_delta: %w", err)
 		}
 		switch event.Delta.Type {
-		case string(constant.ValueOf[constant.TextDelta]()), string(constant.ValueOf[constant.ThinkingDelta]()):
-			// Treat thinking_delta just like a text_delta.
+		case string(constant.ValueOf[constant.TextDelta]()):
 			delta := openai.ChatCompletionResponseChunkChoiceDelta{Content: &event.Delta.Text}
 			return p.constructOpenAIChatCompletionChunk(delta, ""), nil
+
+		case string(constant.ValueOf[constant.ThinkingDelta]()):
+			// Per the streaming docs, this delta type should also cover redacted thinking:
+			// https://platform.claude.com/docs/en/build-with-claude/streaming#content-block-delta-types
+			reasoningDelta := &openai.StreamReasoningContent{}
+
+			// Map all relevant fields from the Anthropic delta to our flattened OpenAI delta struct.
+			if event.Delta.Thinking != "" {
+				reasoningDelta.Text = event.Delta.Thinking
+			}
+			if event.Delta.Signature != "" {
+				reasoningDelta.Signature = event.Delta.Signature
+			}
+
+			delta := openai.ChatCompletionResponseChunkChoiceDelta{
+				ReasoningContent: reasoningDelta,
+			}
+			return p.constructOpenAIChatCompletionChunk(delta, ""), nil
+
 		case string(constant.ValueOf[constant.InputJSONDelta]()):
 			tool, ok := p.activeToolCalls[p.toolIndex]
 			if !ok {
@@ -326,6 +331,7 @@ func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, dat
 			tool.inputJSON += event.Delta.PartialJSON
 			return p.constructOpenAIChatCompletionChunk(delta, ""), nil
 		}
+		// Redacted thinking has no dedicated delta type that we could find, so it is not handled separately here.

 	case string(constant.ValueOf[constant.ContentBlockStop]()):
 		// This event is for state cleanup, no chunk is sent.
diff --git a/internal/translator/openai_gcpanthropic_stream_test.go b/internal/translator/openai_gcpanthropic_stream_test.go
index 3206e3f7b..bc0333083 100644
--- a/internal/translator/openai_gcpanthropic_stream_test.go
+++ b/internal/translator/openai_gcpanthropic_stream_test.go
@@ -539,7 +539,7 @@ event: content_block_start
 data: {"type": "content_block_start", "index": 0, "content_block": {"type": "thinking", "name": "web_searcher"}}

 event: content_block_delta
-data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "text": "Searching for information..."}}
+data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "thinking": "Searching for information..."}}

 event: content_block_stop
 data: {"type": "content_block_stop", "index": 0}
@@ -564,6 +564,7 @@ data: {"type": "message_stop"}
 	bodyStr := string(bm)

 	var contentDeltas []string
+	var reasoningTexts []string
 	var foundToolCallWithArgs bool
 	var finalFinishReason openai.ChatCompletionChoicesFinishReason

@@ -586,6 +587,11 @@ data: {"type": "message_stop"}
 		if choice.Delta.Content != nil {
 			contentDeltas = append(contentDeltas, *choice.Delta.Content)
 		}
+		if choice.Delta.ReasoningContent != nil {
+			if choice.Delta.ReasoningContent.Text != "" {
+				reasoningTexts = append(reasoningTexts, choice.Delta.ReasoningContent.Text)
+			}
+		}
 		if len(choice.Delta.ToolCalls) > 0 {
 			toolCall := choice.Delta.ToolCalls[0]
 			// Check if this is the tool chunk that contains the arguments.
@@ -607,11 +613,155 @@ data: {"type": "message_stop"} } } - fullContent := strings.Join(contentDeltas, "") - assert.Contains(t, fullContent, "Searching for information...") + fullReasoning := strings.Join(reasoningTexts, "") + + assert.Contains(t, fullReasoning, "Searching for information...") require.True(t, foundToolCallWithArgs, "Did not find a tool call chunk with arguments to assert against") assert.Equal(t, openai.ChatCompletionChoicesFinishReasonToolCalls, finalFinishReason, "Final finish reason should be 'tool_calls'") }) + + t.Run("handles thinking delta stream with text only", func(t *testing.T) { + sseStream := ` +event: message_start +data: {"type": "message_start", "message": {"id": "msg_thinking_1", "type": "message", "role": "assistant", "usage": {"input_tokens": 20, "output_tokens": 1}}} + +event: content_block_start +data: {"type": "content_block_start", "index": 0, "content_block": {"type": "thinking"}} + +event: content_block_delta +data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "thinking": "Let me think about this problem step by step."}} + +event: content_block_delta +data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "thinking": " First, I need to understand the requirements."}} + +event: content_block_stop +data: {"type": "content_block_stop", "index": 0} + +event: message_delta +data: {"type": "message_delta", "delta": {"stop_reason": "end_turn"}, "usage": {"output_tokens": 15}} + +event: message_stop +data: {"type": "message_stop"} +` + openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} + translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) + _, _, err := translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + + _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) + require.NoError(t, err) + require.NotNil(t, bm) + bodyStr := string(bm) + + var reasoningTexts []string + var foundFinishReason bool + + lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") + for line := range lines { + if !strings.HasPrefix(line, "data: ") || strings.Contains(line, "[DONE]") { + continue + } + jsonBody := strings.TrimPrefix(line, "data: ") + + var chunk openai.ChatCompletionResponseChunk + err = json.Unmarshal([]byte(jsonBody), &chunk) + require.NoError(t, err, "Failed to unmarshal chunk: %s", jsonBody) + + if len(chunk.Choices) == 0 { + continue + } + choice := chunk.Choices[0] + if choice.Delta != nil && choice.Delta.ReasoningContent != nil { + if choice.Delta.ReasoningContent.Text != "" { + reasoningTexts = append(reasoningTexts, choice.Delta.ReasoningContent.Text) + } + } + if choice.FinishReason == openai.ChatCompletionChoicesFinishReasonStop { + foundFinishReason = true + } + } + + fullReasoning := strings.Join(reasoningTexts, "") + assert.Contains(t, fullReasoning, "Let me think about this problem step by step.") + assert.Contains(t, fullReasoning, " First, I need to understand the requirements.") + require.True(t, foundFinishReason, "Should find stop finish reason") + }) + + t.Run("handles thinking delta stream with text and signature", func(t *testing.T) { + sseStream := ` +event: message_start +data: {"type": "message_start", "message": {"id": "msg_thinking_2", "type": "message", "role": "assistant", "usage": {"input_tokens": 25, "output_tokens": 1}}} + +event: content_block_start +data: {"type": 
"content_block_start", "index": 0, "content_block": {"type": "thinking"}} + +event: content_block_delta +data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "thinking": "Processing request...", "signature": "sig_abc123"}} + +event: content_block_delta +data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "thinking": " Analyzing data...", "signature": "sig_def456"}} + +event: content_block_stop +data: {"type": "content_block_stop", "index": 0} + +event: message_delta +data: {"type": "message_delta", "delta": {"stop_reason": "end_turn"}, "usage": {"output_tokens": 20}} + +event: message_stop +data: {"type": "message_stop"} +` + openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} + translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) + _, _, err := translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + + _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) + require.NoError(t, err) + require.NotNil(t, bm) + bodyStr := string(bm) + + var reasoningTexts []string + var signatures []string + var foundFinishReason bool + + lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") + for line := range lines { + if !strings.HasPrefix(line, "data: ") || strings.Contains(line, "[DONE]") { + continue + } + jsonBody := strings.TrimPrefix(line, "data: ") + + var chunk openai.ChatCompletionResponseChunk + err = json.Unmarshal([]byte(jsonBody), &chunk) + require.NoError(t, err, "Failed to unmarshal chunk: %s", jsonBody) + + if len(chunk.Choices) == 0 { + continue + } + choice := chunk.Choices[0] + if choice.Delta != nil && choice.Delta.ReasoningContent != nil { + if choice.Delta.ReasoningContent.Text != "" { + reasoningTexts = append(reasoningTexts, choice.Delta.ReasoningContent.Text) + } + if choice.Delta.ReasoningContent.Signature != "" { + signatures = append(signatures, choice.Delta.ReasoningContent.Signature) + } + } + if choice.FinishReason == openai.ChatCompletionChoicesFinishReasonStop { + foundFinishReason = true + } + } + + fullReasoning := strings.Join(reasoningTexts, "") + assert.Contains(t, fullReasoning, "Processing request...") + assert.Contains(t, fullReasoning, " Analyzing data...") + + allSignatures := strings.Join(signatures, ",") + assert.Contains(t, allSignatures, "sig_abc123") + assert.Contains(t, allSignatures, "sig_def456") + + require.True(t, foundFinishReason, "Should find stop finish reason") + }) } func TestAnthropicStreamParser_EventTypes(t *testing.T) { diff --git a/internal/translator/openai_gcpanthropic_test.go b/internal/translator/openai_gcpanthropic_test.go index 740f9b188..395f3e790 100644 --- a/internal/translator/openai_gcpanthropic_test.go +++ b/internal/translator/openai_gcpanthropic_test.go @@ -26,6 +26,7 @@ import ( "github.com/tidwall/gjson" "k8s.io/utils/ptr" + "github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" ) @@ -485,6 +486,93 @@ func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_ResponseBody(t *testing. 
}, }, }, + { + name: "response with thinking content", + inputResponse: &anthropic.Message{ + ID: "msg_01XYZ456", + Model: "claude-3-5-sonnet-20241022", + Role: constant.Assistant(anthropic.MessageParamRoleAssistant), + Content: []anthropic.ContentBlockUnion{{Type: "thinking", Thinking: "Let me think about this...", Signature: "signature_123"}}, + StopReason: anthropic.StopReasonEndTurn, + Usage: anthropic.Usage{InputTokens: 15, OutputTokens: 25, CacheReadInputTokens: 3}, + }, + respHeaders: map[string]string{statusHeaderName: "200"}, + expectedOpenAIResponse: openai.ChatCompletionResponse{ + ID: "msg_01XYZ456", + Model: "claude-3-5-sonnet-20241022", + Created: openai.JSONUNIXTime(time.Unix(ReleaseDateUnix, 0)), + Object: "chat.completion", + Usage: openai.Usage{ + PromptTokens: 18, + CompletionTokens: 25, + TotalTokens: 43, + PromptTokensDetails: &openai.PromptTokensDetails{ + CachedTokens: 3, + }, + }, + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + Message: openai.ChatCompletionResponseChoiceMessage{ + Role: "assistant", + ReasoningContent: &openai.ReasoningContentUnion{ + Value: &openai.ReasoningContent{ + ReasoningContent: &awsbedrock.ReasoningContentBlock{ + ReasoningText: &awsbedrock.ReasoningTextBlock{ + Text: "Let me think about this...", + Signature: "signature_123", + }, + }, + }, + }, + }, + FinishReason: openai.ChatCompletionChoicesFinishReasonStop, + }, + }, + }, + }, + { + name: "response with redacted thinking content", + inputResponse: &anthropic.Message{ + ID: "msg_01XYZ789", + Model: "claude-3-5-sonnet-20241022", + Role: constant.Assistant(anthropic.MessageParamRoleAssistant), + Content: []anthropic.ContentBlockUnion{{Type: "redacted_thinking", Data: "redacted_data_content"}}, + StopReason: anthropic.StopReasonEndTurn, + Usage: anthropic.Usage{InputTokens: 12, OutputTokens: 18, CacheReadInputTokens: 1}, + }, + respHeaders: map[string]string{statusHeaderName: "200"}, + expectedOpenAIResponse: openai.ChatCompletionResponse{ + ID: "msg_01XYZ789", + Model: "claude-3-5-sonnet-20241022", + Created: openai.JSONUNIXTime(time.Unix(ReleaseDateUnix, 0)), + Object: "chat.completion", + Usage: openai.Usage{ + PromptTokens: 13, + CompletionTokens: 18, + TotalTokens: 31, + PromptTokensDetails: &openai.PromptTokensDetails{ + CachedTokens: 1, + }, + }, + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + Message: openai.ChatCompletionResponseChoiceMessage{ + Role: "assistant", + ReasoningContent: &openai.ReasoningContentUnion{ + Value: &openai.ReasoningContent{ + ReasoningContent: &awsbedrock.ReasoningContentBlock{ + RedactedContent: []byte("redacted_data_content"), + }, + }, + }, + }, + FinishReason: openai.ChatCompletionChoicesFinishReasonStop, + }, + }, + }, + }, } for _, tt := range tests { diff --git a/tests/bench/aigw.yaml b/tests/bench/aigw.yaml new file mode 100644 index 000000000..6bf8efa54 --- /dev/null +++ b/tests/bench/aigw.yaml @@ -0,0 +1,123 @@ +# Copyright Envoy AI Gateway Authors +# SPDX-License-Identifier: Apache-2.0 +# The full text of the Apache license is available in the LICENSE file at +# the root of the repo. 
+ +--- +apiVersion: aigateway.envoyproxy.io/v1alpha1 +kind: MCPRoute +metadata: + name: mcp-route + namespace: default +spec: + parentRefs: + - name: aigw-run + kind: Gateway + group: gateway.networking.k8s.io + path: "/mcp" + backendRefs: + - name: test + kind: Backend + group: gateway.envoyproxy.io + path: "/mcp" +--- +################################################################################### +############################### Backend Definitions ############################### +################################################################################### +apiVersion: gateway.envoyproxy.io/v1alpha1 +kind: Backend +metadata: + name: test + namespace: default +spec: + endpoints: + - ip: + address: 127.0.0.1 + port: 8080 +--- +################################################################################### +############################### Gateway Definitions ############################### +################################################################################### +apiVersion: gateway.networking.k8s.io/v1 +kind: GatewayClass +metadata: + name: aigw-run +spec: + controllerName: gateway.envoyproxy.io/gatewayclass-controller +--- +apiVersion: gateway.networking.k8s.io/v1 +kind: Gateway +metadata: + name: aigw-run + namespace: default +spec: + gatewayClassName: aigw-run + listeners: + - name: http + protocol: HTTP + port: 1975 + infrastructure: + parametersRef: + group: gateway.envoyproxy.io + kind: EnvoyProxy + name: envoy-ai-gateway +--- +apiVersion: gateway.envoyproxy.io/v1alpha1 +kind: EnvoyProxy +metadata: + name: envoy-ai-gateway + namespace: default +spec: + logging: + level: + default: error + bootstrap: + type: Merge + value: |- + admin: + address: + socket_address: + address: 127.0.0.1 + port_value: 9901 + telemetry: + accessLog: + settings: + - sinks: + - type: File + file: + path: /dev/stdout + format: + type: JSON + json: + # MCP specific fields + mcp_request_id: "%DYNAMIC_METADATA(io.envoy.ai_gateway:mcp_request_id)%" + mcp_session_id: "%REQ(MCP-SESSION-ID)%" + mcp_method: "%DYNAMIC_METADATA(io.envoy.ai_gateway:mcp_method)%" + mcp_backend: "%DYNAMIC_METADATA(io.envoy.ai_gateway:mcp_backend)%" + # Default fields + start_time: "%START_TIME%" + method: "%REQ(:METHOD)%" + x-envoy-origin-path: "%REQ(X-ENVOY-ORIGINAL-PATH?:PATH)%" + protocol: "%PROTOCOL%" + response_code: "%RESPONSE_CODE%" + response_flags: "%RESPONSE_FLAGS%" + response_code_details: "%RESPONSE_CODE_DETAILS%" + connection_termination_details: "%CONNECTION_TERMINATION_DETAILS%" + upstream_transport_failure_reason: "%UPSTREAM_TRANSPORT_FAILURE_REASON%" + bytes_received: "%BYTES_RECEIVED%" + bytes_sent: "%BYTES_SENT%" + duration: "%DURATION%" + x-envoy-upstream-service-time: "%RESP(X-ENVOY-UPSTREAM-SERVICE-TIME)%" + x-forwarded-for: "%REQ(X-FORWARDED-FOR)%" + user-agent: "%REQ(USER-AGENT)%" + x-request-id: "%REQ(X-REQUEST-ID)%" + ":authority": "%REQ(:AUTHORITY)%" + upstream_host: "%UPSTREAM_HOST%" + upstream_cluster: "%UPSTREAM_CLUSTER%" + upstream_local_address: "%UPSTREAM_LOCAL_ADDRESS%" + downstream_local_address: "%DOWNSTREAM_LOCAL_ADDRESS%" + downstream_remote_address: "%DOWNSTREAM_REMOTE_ADDRESS%" + requested_server_name: "%REQUESTED_SERVER_NAME%" + route_name: "%ROUTE_NAME%" +--- + diff --git a/tests/bench/bench_test.go b/tests/bench/bench_test.go new file mode 100644 index 000000000..61b45dfde --- /dev/null +++ b/tests/bench/bench_test.go @@ -0,0 +1,167 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in 
the LICENSE file at +// the root of the repo. + +// 1. Build AIGW +// make clean build.aigw +// 2. Run the bench test +// go test -timeout=15m -run='^$' -bench=. ./tests/bench/... + +package bench + +import ( + "context" + _ "embed" + "fmt" + "net" + "net/url" + "os/exec" + "runtime" + "strings" + "syscall" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/tests/internal/testmcp" +) + +const ( + writeTimeout = 120 * time.Second + mcpServerPort = 8080 + aigwPort = 1975 +) + +var aigwBinary = fmt.Sprintf("../../out/aigw-%s-%s", runtime.GOOS, runtime.GOARCH) + +type MCPBenchCase struct { + Name string + ProxyBinary string + ProxyArgs []string + TestAddr string +} + +// setupBenchmark sets up the client connection. +func setupBenchmark(b *testing.B) []MCPBenchCase { + b.Helper() // Treat this as a helper function + + // setup MCP server + mcpServer := testmcp.NewServer(&testmcp.Options{ + Port: mcpServerPort, + ForceJSONResponse: false, + DumbEchoServer: true, + WriteTimeout: writeTimeout, + DisableLog: true, + }) + b.Cleanup(func() { + _ = mcpServer.Close() + }) + + return []MCPBenchCase{ + { + Name: "Baseline_NoProxy", + TestAddr: fmt.Sprintf("http://localhost:%d", mcpServerPort), + }, + { + Name: "EAIGW_Default", + TestAddr: fmt.Sprintf("http://localhost:%d/mcp", aigwPort), + ProxyBinary: aigwBinary, + ProxyArgs: []string{"run", "./aigw.yaml"}, + }, + { + Name: "EAIGW_Config_100", + TestAddr: fmt.Sprintf("http://localhost:%d/mcp", aigwPort), + ProxyBinary: aigwBinary, + ProxyArgs: []string{"run", "./aigw.yaml", "--mcp-session-encryption-iterations=100"}, + }, + { + Name: "EAIGW_Inline_100", + TestAddr: fmt.Sprintf("http://localhost:%d/mcp", aigwPort), + ProxyBinary: aigwBinary, + ProxyArgs: []string{ + "run", + "--mcp-session-encryption-iterations=100", + `--mcp-json={"mcpServers":{"aigw":{"type":"http","url":"http://localhost:8080/mcp"}}}`, + }, + }, + } +} + +func BenchmarkMCP(b *testing.B) { + cases := setupBenchmark(b) + for _, tc := range cases { + var proxy *exec.Cmd + if tc.ProxyBinary != "" { + proxy = startProxy(b, &tc) + } + + b.Run(tc.Name, func(b *testing.B) { + mcpClient := mcp.NewClient(&mcp.Implementation{Name: "bench-http-client", Version: "0.1.0"}, nil) + cs, err := mcpClient.Connect(b.Context(), &mcp.StreamableClientTransport{Endpoint: tc.TestAddr}, nil) + if err != nil { + b.Fatalf("Failed to connect server: %v", err) + } + + tools, err := cs.ListTools(b.Context(), &mcp.ListToolsParams{}) + if err != nil { + b.Fatalf("Failed to list tools: %v", err) + } + var toolName string + for _, t := range tools.Tools { + if strings.Contains(t.Name, "echo") { + toolName = t.Name + break + } + } + if toolName == "" { + b.Fatalf("no echo tool found") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ctx, cancel := context.WithTimeout(b.Context(), 5*time.Second) + res, err := cs.CallTool(ctx, &mcp.CallToolParams{ + Name: toolName, + Arguments: testmcp.ToolEchoArgs{Text: "hello MCP"}, + }) + cancel() + if err != nil { + b.Fatalf("MCP Tool call name %s failed at iteration %d: %v", toolName, i, err) + } + + txt, ok := res.Content[0].(*mcp.TextContent) + if !ok { + b.Fatalf("unexpected content type") + } + if txt.Text != "dumb echo: hello MCP" { + b.Fatalf("unexpected text: %q", txt.Text) + } + } + }) + + if proxy != nil && proxy.Process != nil { + _ = syscall.Kill(-proxy.Process.Pid, syscall.SIGKILL) + } + } +} + +func startProxy(b testing.TB, tc *MCPBenchCase) *exec.Cmd { + addr, err := 
url.Parse(tc.TestAddr) + require.NoError(b, err) + + cmd := exec.CommandContext(b.Context(), tc.ProxyBinary, tc.ProxyArgs...) // nolint: gosec + // put into new process group so we can kill the entire process tree (and children) + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + require.NoError(b, cmd.Start()) + + // Wait until we can connect to the proxy + require.Eventually(b, func() bool { + _, err = (&net.Dialer{}).DialContext(b.Context(), "tcp", addr.Host) + return err == nil + }, 30*time.Second, 500*time.Millisecond, "proxy %s did not become ready in time", tc.Name) + + return cmd +} diff --git a/tests/internal/testmcp/server.go b/tests/internal/testmcp/server.go index ed1e17138..c8f811205 100644 --- a/tests/internal/testmcp/server.go +++ b/tests/internal/testmcp/server.go @@ -26,6 +26,7 @@ type Options struct { Port int ForceJSONResponse, DumbEchoServer bool WriteTimeout time.Duration + DisableLog bool } // NewServer starts a demo MCP server with two tools: echo and sum. @@ -130,6 +131,9 @@ func NewServer(opts *Options) *http.Server { WriteTimeout: opts.WriteTimeout, Handler: handler, ConnState: func(conn net.Conn, state http.ConnState) { + if opts.DisableLog { + return + } log.Printf("MCP SERVER connection [%s] %s -> %s\n", state, conn.RemoteAddr(), conn.LocalAddr()) }, }
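
Reviewer note (illustrative, not part of the patch): RedactedContentUnion decodes a JSON
string by attempting base64 first, so any value that happens to be valid base64 is
surfaced as []byte; a plain string is preserved only when base64 decoding fails. Below is
a minimal sketch of the round-trip behavior, written as an example test that would live
inside this module (the openai package is internal, so it is not importable from outside
the repo):

package openai_test

import (
	"encoding/json"
	"fmt"

	"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
)

func Example_redactedContentUnion() {
	// GCP Anthropic form: a plain string. "redacted!" survives as a string only
	// because '!' is not in the base64 alphabet; a value like "redacted"
	// (coincidentally valid base64) would decode to []byte instead.
	var u openai.RedactedContentUnion
	_ = json.Unmarshal([]byte(`"redacted!"`), &u)
	fmt.Printf("%T\n", u.Value)

	// AWS Bedrock form: []byte, carried over JSON as a base64 string.
	out, _ := json.Marshal(openai.RedactedContentUnion{Value: []byte("hello")})
	fmt.Println(string(out))

	// Output:
	// string
	// "aGVsbG8="
}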