From edc9b6148445c27f7fca2b8729643fbb25b23ae9 Mon Sep 17 00:00:00 2001
From: yxia216
Date: Sat, 29 Nov 2025 17:37:53 -0500
Subject: [PATCH 1/2] init

Signed-off-by: yxia216
---
 internal/translator/openai_awsopenai.go | 58 +++++++++++++++++++++++++
 1 file changed, 58 insertions(+)
 create mode 100644 internal/translator/openai_awsopenai.go

diff --git a/internal/translator/openai_awsopenai.go b/internal/translator/openai_awsopenai.go
new file mode 100644
index 000000000..741a16b5d
--- /dev/null
+++ b/internal/translator/openai_awsopenai.go
@@ -0,0 +1,58 @@
+// Copyright Envoy AI Gateway Authors
+// SPDX-License-Identifier: Apache-2.0
+// The full text of the Apache license is available in the LICENSE file at
+// the root of the repo.
+
+package translator
+
+import (
+	"fmt"
+	"strconv"
+
+	"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
+	"github.com/envoyproxy/ai-gateway/internal/internalapi"
+)
+
+// NewChatCompletionOpenAIToAwsOpenAITranslator implements [Factory] for OpenAI to AWS OpenAI translations.
+func NewChatCompletionOpenAIToAwsOpenAITranslator(apiVersion string, modelNameOverride internalapi.ModelNameOverride) OpenAIChatCompletionTranslator {
+	return &openAIToAwsOpenAITranslatorV1ChatCompletion{
+		apiVersion: apiVersion,
+		openAIToOpenAITranslatorV1ChatCompletion: openAIToOpenAITranslatorV1ChatCompletion{
+			modelNameOverride: modelNameOverride,
+		},
+	}
+}
+
+// openAIToAwsOpenAITranslatorV1ChatCompletion adapts OpenAI requests for the AWS OpenAI service.
+// This initial version mirrors the Azure OpenAI translator, which ignores the model field in the request body and uses the deployment name from the URI path instead:
+// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/reference#chat-completions
+type openAIToAwsOpenAITranslatorV1ChatCompletion struct {
+	openAIToOpenAITranslatorV1ChatCompletion
+	apiVersion string
+}
+
+func (o *openAIToAwsOpenAITranslatorV1ChatCompletion) RequestBody(raw []byte, req *openai.ChatCompletionRequest, forceBodyMutation bool) (
+	newHeaders []internalapi.Header, newBody []byte, err error,
+) {
+	modelName := req.Model
+	if o.modelNameOverride != "" {
+		// If modelNameOverride is set, we override the model to be used for the request.
+		modelName = o.modelNameOverride
+	}
+	// Ensure the response includes a model. This is set to accommodate test or
+	// misimplemented backends.
+	o.requestModel = modelName
+
+	// Azure OpenAI uses a {deployment-id} that may match the deployed model's name.
+	// We use the routed model as the deployment, stored in the path.
+	pathTemplate := "/openai/deployments/%s/chat/completions?api-version=%s"
+	newHeaders = []internalapi.Header{{pathHeaderName, fmt.Sprintf(pathTemplate, modelName, o.apiVersion)}}
+	if req.Stream {
+		o.stream = true
+	}
+
+	// On retry, the path might have changed to a different provider, so this ensures that the path is always set for OpenAI.
+	if forceBodyMutation {
+		newBody = raw
+		newHeaders = append(newHeaders, internalapi.Header{contentLengthHeaderName, strconv.Itoa(len(raw))})
+	}
+	return
+}

From 8241f894b783bfb48a62a436ea57678b07870f88 Mon Sep 17 00:00:00 2001
From: yxia216
Date: Fri, 5 Dec 2025 10:33:32 -0500
Subject: [PATCH 2/2] init

Signed-off-by: yxia216
---
 api/v1alpha1/shared_types.go                 |   7 +-
 internal/endpointspec/endpointspec.go        |   2 +
 internal/endpointspec/endpointspec_test.go   |  31 +
 internal/filterapi/filterconfig.go           |   3 +
 internal/translator/openai_awsopenai.go      | 258 +++++++-
 internal/translator/openai_awsopenai_test.go | 637 +++++++++++++++++++
 internal/translator/translator.go            |   5 +-
 7 files changed, 920 insertions(+), 23 deletions(-)
 create mode 100644 internal/translator/openai_awsopenai_test.go

diff --git a/api/v1alpha1/shared_types.go b/api/v1alpha1/shared_types.go
index 628f34f18..2bfe280f9 100644
--- a/api/v1alpha1/shared_types.go
+++ b/api/v1alpha1/shared_types.go
@@ -15,7 +15,7 @@ package v1alpha1
 type VersionedAPISchema struct {
 	// Name is the name of the API schema of the AIGatewayRoute or AIServiceBackend.
 	//
-	// +kubebuilder:validation:Enum=OpenAI;Cohere;AWSBedrock;AzureOpenAI;GCPVertexAI;GCPAnthropic;Anthropic;AWSAnthropic
+	// +kubebuilder:validation:Enum=OpenAI;Cohere;AWSBedrock;AzureOpenAI;GCPVertexAI;GCPAnthropic;Anthropic;AWSAnthropic;AWSOpenAI
 	Name APISchema `json:"name"`
 
 	// Version is the version of the API schema.
@@ -75,6 +75,11 @@ const (
 	// https://aws.amazon.com/bedrock/anthropic/
 	// https://docs.claude.com/en/api/claude-on-amazon-bedrock
 	APISchemaAWSAnthropic APISchema = "AWSAnthropic"
+	// APISchemaAWSOpenAI is the schema for OpenAI models hosted on AWS Bedrock.
+	// It uses the AWS Bedrock InvokeModel API with the OpenAI format for requests and responses.
+	//
+	// https://aws.amazon.com/bedrock/
+	APISchemaAWSOpenAI APISchema = "AWSOpenAI"
 )
 
 const (
diff --git a/internal/endpointspec/endpointspec.go b/internal/endpointspec/endpointspec.go
index 44e5fc150..c0d62bfd2 100644
--- a/internal/endpointspec/endpointspec.go
+++ b/internal/endpointspec/endpointspec.go
@@ -113,6 +113,8 @@ func (ChatCompletionsEndpointSpec) GetTranslator(schema filterapi.VersionedAPISc
 		return translator.NewChatCompletionOpenAIToOpenAITranslator(schema.Version, modelNameOverride), nil
 	case filterapi.APISchemaAWSBedrock:
 		return translator.NewChatCompletionOpenAIToAWSBedrockTranslator(modelNameOverride), nil
+	case filterapi.APISchemaAWSOpenAI:
+		return translator.NewChatCompletionOpenAIToAwsOpenAITranslator(schema.Version, modelNameOverride), nil
 	case filterapi.APISchemaAzureOpenAI:
 		return translator.NewChatCompletionOpenAIToAzureOpenAITranslator(schema.Version, modelNameOverride), nil
 	case filterapi.APISchemaGCPVertexAI:
diff --git a/internal/endpointspec/endpointspec_test.go b/internal/endpointspec/endpointspec_test.go
index f3ecced2c..c9a38087b 100644
--- a/internal/endpointspec/endpointspec_test.go
+++ b/internal/endpointspec/endpointspec_test.go
@@ -79,6 +79,7 @@ func TestChatCompletionsEndpointSpec_GetTranslator(t *testing.T) {
 	supported := []filterapi.VersionedAPISchema{
 		{Name: filterapi.APISchemaOpenAI, Version: "v1"},
 		{Name: filterapi.APISchemaAWSBedrock},
+		{Name: filterapi.APISchemaAWSOpenAI, Version: "v1"},
 		{Name: filterapi.APISchemaAzureOpenAI, Version: "2024-02-01"},
 		{Name: filterapi.APISchemaGCPVertexAI},
 		{Name: filterapi.APISchemaGCPAnthropic, Version: "2024-05-01"},
@@ -273,3 +274,33 @@ func TestRerankEndpointSpec_GetTranslator(t *testing.T) {
 	_, err = spec.GetTranslator(filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI}, "override")
 	require.ErrorContains(t, err, "unsupported API schema")
 }
+
+func TestAWSOpenAIAPISchemaIntegration(t *testing.T) {
+	schema := filterapi.VersionedAPISchema{
+		Name:    filterapi.APISchemaAWSOpenAI,
+		Version: "v1",
+	}
+
+	t.Run("ChatCompletions", func(t *testing.T) {
+		endpointSpec := ChatCompletionsEndpointSpec{}
+		translator, err := endpointSpec.GetTranslator(schema, "")
+		require.NoError(t, err)
+		require.NotNil(t, translator)
+	})
+}
+
+func TestAWSOpenAIAPISchemaWithModelOverride(t *testing.T) {
+	schema := filterapi.VersionedAPISchema{
+		Name:    filterapi.APISchemaAWSOpenAI,
+		Version: "v1",
+	}
+
+	modelOverride := "arn:aws:bedrock:us-east-1:123456789:model/gpt-4"
+
+	t.Run("ChatCompletions", func(t *testing.T) {
+		endpointSpec := ChatCompletionsEndpointSpec{}
+		translator, err := endpointSpec.GetTranslator(schema, modelOverride)
+		require.NoError(t, err)
+		require.NotNil(t, translator)
+	})
+}
diff --git a/internal/filterapi/filterconfig.go b/internal/filterapi/filterconfig.go
index b0721e470..20479009c 100644
--- a/internal/filterapi/filterconfig.go
+++ b/internal/filterapi/filterconfig.go
@@ -114,6 +114,9 @@ const (
 	// APISchemaAWSAnthropic represents the AWS Bedrock Anthropic API schema.
 	// Used for Claude models hosted on AWS Bedrock using the native Anthropic Messages API.
 	APISchemaAWSAnthropic APISchemaName = "AWSAnthropic"
+	// APISchemaAWSOpenAI represents the AWS Bedrock OpenAI API schema.
+	// Used for GPT models hosted on AWS Bedrock using the OpenAI API.
+	APISchemaAWSOpenAI APISchemaName = "AWSOpenAI"
 )
 
 // RouteRuleName is the name of the route rule.
diff --git a/internal/translator/openai_awsopenai.go b/internal/translator/openai_awsopenai.go
index 741a16b5d..9a4dcaf70 100644
--- a/internal/translator/openai_awsopenai.go
+++ b/internal/translator/openai_awsopenai.go
@@ -6,53 +6,271 @@
 package translator
 
 import (
+	"encoding/json"
 	"fmt"
+	"io"
+	"net/url"
 	"strconv"
+	"strings"
 
 	"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
 	"github.com/envoyproxy/ai-gateway/internal/internalapi"
+	"github.com/envoyproxy/ai-gateway/internal/metrics"
+	tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api"
 )
 
 // NewChatCompletionOpenAIToAwsOpenAITranslator implements [Factory] for OpenAI to AWS OpenAI translations.
 func NewChatCompletionOpenAIToAwsOpenAITranslator(apiVersion string, modelNameOverride internalapi.ModelNameOverride) OpenAIChatCompletionTranslator {
 	return &openAIToAwsOpenAITranslatorV1ChatCompletion{
-		apiVersion: apiVersion,
-		openAIToOpenAITranslatorV1ChatCompletion: openAIToOpenAITranslatorV1ChatCompletion{
-			modelNameOverride: modelNameOverride,
-		},
+		modelNameOverride: modelNameOverride,
 	}
 }
 
-// openAIToAwsOpenAITranslatorV1ChatCompletion adapts OpenAI requests for the AWS OpenAI service.
-// This initial version mirrors the Azure OpenAI translator, which ignores the model field in the request body and uses the deployment name from the URI path instead:
-// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/reference#chat-completions
+// openAIToAwsOpenAITranslatorV1ChatCompletion adapts OpenAI requests to the AWS Bedrock InvokeModel API,
+// which accepts model-specific request/response formats. For OpenAI models, this preserves the
+// OpenAI format but uses AWS Bedrock endpoints.
 type openAIToAwsOpenAITranslatorV1ChatCompletion struct {
 	openAIToOpenAITranslatorV1ChatCompletion
-	apiVersion string
+	modelNameOverride internalapi.ModelNameOverride
+	requestModel      internalapi.RequestModel
+	responseID        string
+	stream            bool
 }
 
 func (o *openAIToAwsOpenAITranslatorV1ChatCompletion) RequestBody(raw []byte, req *openai.ChatCompletionRequest, forceBodyMutation bool) (
 	newHeaders []internalapi.Header, newBody []byte, err error,
 ) {
-	modelName := req.Model
+	// Store the request model and streaming state.
+	o.requestModel = req.Model
 	if o.modelNameOverride != "" {
-		// If modelNameOverride is set, we override the model to be used for the request.
-		modelName = o.modelNameOverride
+		o.requestModel = o.modelNameOverride
 	}
-	// Ensure the response includes a model. This is set to accommodate test or
-	// misimplemented backends.
-	o.requestModel = modelName
 
-	// Azure OpenAI uses a {deployment-id} that may match the deployed model's name.
-	// We use the routed model as the deployment, stored in the path.
-	pathTemplate := "/openai/deployments/%s/chat/completions?api-version=%s"
-	newHeaders = []internalapi.Header{{pathHeaderName, fmt.Sprintf(pathTemplate, modelName, o.apiVersion)}}
 	if req.Stream {
 		o.stream = true
 	}
 
-	// On retry, the path might have changed to a different provider, so this ensures that the path is always set for OpenAI.
-	if forceBodyMutation {
-		newBody = raw
-		newHeaders = append(newHeaders, internalapi.Header{contentLengthHeaderName, strconv.Itoa(len(raw))})
-	}
-	return
-}
+	// URL-encode the model name for the path to handle special characters (e.g., ARNs).
+	encodedModelName := url.PathEscape(o.requestModel)
+
+	// Set the path for the AWS Bedrock InvokeModel API.
+	pathTemplate := "/model/%s/invoke"
+	if req.Stream {
+		pathTemplate = "/model/%s/invoke-with-response-stream"
+	}
+
+	// For the InvokeModel API, the request body stays in the OpenAI format,
+	// since we're invoking OpenAI models through Bedrock.
+	if o.modelNameOverride != "" {
+		// We need to override the model in the request body.
+		var openAIReq openai.ChatCompletionRequest
+		if err := json.Unmarshal(raw, &openAIReq); err != nil {
+			return nil, nil, fmt.Errorf("failed to unmarshal request: %w", err)
+		}
+		openAIReq.Model = o.modelNameOverride
+		newBody, err = json.Marshal(openAIReq)
+		if err != nil {
+			return nil, nil, fmt.Errorf("failed to marshal request: %w", err)
+		}
+	} else {
+		newBody = raw
+	}
+
+	newHeaders = []internalapi.Header{
+		{pathHeaderName, fmt.Sprintf(pathTemplate, encodedModelName)},
+		{contentLengthHeaderName, strconv.Itoa(len(newBody))},
+	}
+
+	return
+}
+
+// ResponseHeaders implements [OpenAIChatCompletionTranslator.ResponseHeaders].
+func (o *openAIToAwsOpenAITranslatorV1ChatCompletion) ResponseHeaders(headers map[string]string) (
+	newHeaders []internalapi.Header, err error,
+) {
+	// Store the response ID for tracking.
+	o.responseID = headers["x-amzn-requestid"]
+
+	// For streaming responses, ensure the content-type is correctly set.
+	if o.stream {
+		contentType := headers["content-type"]
+		// AWS Bedrock might return a different content-type for streaming.
+		if contentType == "application/vnd.amazon.eventstream" {
+			// Convert to the expected streaming content-type.
+			newHeaders = []internalapi.Header{{contentTypeHeaderName, "text/event-stream"}}
+		}
+	}
+	return
+}
+
+// ResponseBody implements [OpenAIChatCompletionTranslator.ResponseBody].
+// The AWS Bedrock InvokeModel API with OpenAI models returns responses in the OpenAI format.
+// This function handles both streaming and non-streaming responses.
+func (o *openAIToAwsOpenAITranslatorV1ChatCompletion) ResponseBody(headers map[string]string, body io.Reader, endOfStream bool, span tracing.ChatCompletionSpan) ( + newHeaders []internalapi.Header, newBody []byte, tokenUsage metrics.TokenUsage, responseModel string, err error, +) { + responseModel = o.requestModel + + if o.stream { + // Handle streaming response + var buf []byte + buf, err = io.ReadAll(body) + if err != nil { + return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to read streaming body: %w", err) + } + + // For InvokeModel with OpenAI models, the streaming response should already be in + // Server-Sent Events format with OpenAI chunks + newBody = buf + + // Parse for token usage if available in the stream + for _, line := range strings.Split(string(buf), "\n") { + if dataStr, found := strings.CutPrefix(line, "data: "); found { + if dataStr != "[DONE]" { + var chunk openai.ChatCompletionResponseChunk + if json.Unmarshal([]byte(dataStr), &chunk) == nil { + if chunk.Usage != nil { + tokenUsage.SetInputTokens(uint32(chunk.Usage.PromptTokens)) + tokenUsage.SetOutputTokens(uint32(chunk.Usage.CompletionTokens)) + tokenUsage.SetTotalTokens(uint32(chunk.Usage.TotalTokens)) + } + if span != nil { + span.RecordResponseChunk(&chunk) + } + } + } + } + } + + if endOfStream && !strings.HasSuffix(string(newBody), "data: [DONE]\n") { + newBody = append(newBody, []byte("data: [DONE]\n")...) + } + } else { + // Handle non-streaming response + var buf []byte + buf, err = io.ReadAll(body) + if err != nil { + return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to read body: %w", err) + } + + // For InvokeModel with OpenAI models, response should already be in OpenAI format + var openAIResp openai.ChatCompletionResponse + if err = json.Unmarshal(buf, &openAIResp); err != nil { + return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to unmarshal response: %w", err) + } + + // Use response model if available, otherwise use request model + if openAIResp.Model != "" { + responseModel = openAIResp.Model + } + + // Extract token usage + if openAIResp.Usage.TotalTokens > 0 { + tokenUsage.SetInputTokens(uint32(openAIResp.Usage.PromptTokens)) + tokenUsage.SetOutputTokens(uint32(openAIResp.Usage.CompletionTokens)) + tokenUsage.SetTotalTokens(uint32(openAIResp.Usage.TotalTokens)) + } + + // Override the ID with AWS request ID if available + if o.responseID != "" { + openAIResp.ID = o.responseID + } + + newBody, err = json.Marshal(openAIResp) + if err != nil { + return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to marshal response: %w", err) + } + + if span != nil { + span.RecordResponse(&openAIResp) + } + } + + if len(newBody) > 0 { + newHeaders = []internalapi.Header{{contentLengthHeaderName, strconv.Itoa(len(newBody))}} + } + return +} + +// ResponseError implements [OpenAIChatCompletionTranslator.ResponseError]. +// Translates AWS Bedrock InvokeModel exceptions to OpenAI error format. +// The error type is typically stored in the "x-amzn-errortype" HTTP header for AWS error responses. 
+func (o *openAIToAwsOpenAITranslatorV1ChatCompletion) ResponseError(respHeaders map[string]string, body io.Reader) ( + newHeaders []internalapi.Header, newBody []byte, err error, +) { + statusCode := respHeaders[statusHeaderName] + var openaiError openai.Error + + // Check if we have a JSON error response + if v, ok := respHeaders[contentTypeHeaderName]; ok && strings.Contains(v, jsonContentType) { + // Try to parse as AWS Bedrock error + var buf []byte + buf, err = io.ReadAll(body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read error body: %w", err) + } + + // Check if it's already an OpenAI error format + var existingOpenAIError openai.Error + if json.Unmarshal(buf, &existingOpenAIError) == nil && existingOpenAIError.Error.Message != "" { + // Already in OpenAI format, return as-is + newBody = buf + } else { + // Try to parse as AWS error and convert to OpenAI format + var awsError struct { + Type string `json:"__type,omitempty"` + Message string `json:"message"` + Code string `json:"code,omitempty"` + } + if json.Unmarshal(buf, &awsError) == nil && awsError.Message != "" { + openaiError = openai.Error{ + Type: "error", + Error: openai.ErrorType{ + Type: respHeaders[awsErrorTypeHeaderName], + Message: awsError.Message, + Code: &statusCode, + }, + } + } else { + // Generic AWS error format + openaiError = openai.Error{ + Type: "error", + Error: openai.ErrorType{ + Type: awsInvokeModelBackendError, + Message: string(buf), + Code: &statusCode, + }, + } + } + newBody, err = json.Marshal(openaiError) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal error body: %w", err) + } + } + } else { + // Non-JSON error response + var buf []byte + buf, err = io.ReadAll(body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read error body: %w", err) + } + openaiError = openai.Error{ + Type: "error", + Error: openai.ErrorType{ + Type: awsInvokeModelBackendError, + Message: string(buf), + Code: &statusCode, + }, + } + newBody, err = json.Marshal(openaiError) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal error body: %w", err) + } + } + + newHeaders = []internalapi.Header{ + {contentTypeHeaderName, jsonContentType}, + {contentLengthHeaderName, strconv.Itoa(len(newBody))}, } return } diff --git a/internal/translator/openai_awsopenai_test.go b/internal/translator/openai_awsopenai_test.go new file mode 100644 index 000000000..dba0ac1d2 --- /dev/null +++ b/internal/translator/openai_awsopenai_test.go @@ -0,0 +1,637 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package translator + +import ( + "encoding/json" + "strconv" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" + "k8s.io/utils/ptr" + + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" + "github.com/envoyproxy/ai-gateway/internal/internalapi" +) + +func TestOpenAIToAwsOpenAITranslatorV1ChatCompletion_RequestBody(t *testing.T) { + tests := []struct { + name string + input openai.ChatCompletionRequest + modelNameOverride internalapi.ModelNameOverride + expectedPath string + expectedBody openai.ChatCompletionRequest + }{ + { + name: "basic non-streaming request", + input: openai.ChatCompletionRequest{ + Stream: false, + Model: "gpt-4", + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "Hello, world!", + }, + Role: openai.ChatMessageRoleUser, + }, + }, + }, + }, + expectedPath: "/model/gpt-4/invoke", + expectedBody: openai.ChatCompletionRequest{ + Stream: false, + Model: "gpt-4", + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "Hello, world!", + }, + Role: openai.ChatMessageRoleUser, + }, + }, + }, + }, + }, + { + name: "streaming request", + input: openai.ChatCompletionRequest{ + Stream: true, + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "Tell me a story", + }, + Role: openai.ChatMessageRoleUser, + }, + }, + }, + }, + expectedPath: "/model/gpt-3.5-turbo/invoke-with-response-stream", + expectedBody: openai.ChatCompletionRequest{ + Stream: true, + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "Tell me a story", + }, + Role: openai.ChatMessageRoleUser, + }, + }, + }, + }, + }, + { + name: "model name override", + input: openai.ChatCompletionRequest{ + Stream: false, + Model: "gpt-4", + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "Hello with override", + }, + Role: openai.ChatMessageRoleUser, + }, + }, + }, + }, + modelNameOverride: "arn:aws:bedrock:us-east-1:123456789:model/gpt-4", + expectedPath: "/model/arn:aws:bedrock:us-east-1:123456789:model%2Fgpt-4/invoke", + expectedBody: openai.ChatCompletionRequest{ + Stream: false, + Model: "arn:aws:bedrock:us-east-1:123456789:model/gpt-4", + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "Hello with override", + }, + Role: openai.ChatMessageRoleUser, + }, + }, + }, + }, + }, + { + name: "complex request with tools", + input: openai.ChatCompletionRequest{ + Stream: false, + Model: "gpt-4", + Temperature: ptr.To(0.7), + MaxTokens: ptr.To(int64(1000)), + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfSystem: &openai.ChatCompletionSystemMessageParam{ + Content: openai.ContentUnion{ + Value: "You are a helpful assistant.", + }, + Role: openai.ChatMessageRoleSystem, + }, + }, + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "What's the weather?", + }, + Role: 
openai.ChatMessageRoleUser, + }, + }, + }, + Tools: []openai.Tool{ + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "get_weather", + Description: "Get current weather", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "City name", + }, + }, + }, + }, + }, + }, + }, + expectedPath: "/model/gpt-4/invoke", + expectedBody: openai.ChatCompletionRequest{ + Stream: false, + Model: "gpt-4", + Temperature: ptr.To(0.7), + MaxTokens: ptr.To(int64(1000)), + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfSystem: &openai.ChatCompletionSystemMessageParam{ + Content: openai.ContentUnion{ + Value: "You are a helpful assistant.", + }, + Role: openai.ChatMessageRoleSystem, + }, + }, + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "What's the weather?", + }, + Role: openai.ChatMessageRoleUser, + }, + }, + }, + Tools: []openai.Tool{ + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "get_weather", + Description: "Get current weather", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "City name", + }, + }, + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + translator := NewChatCompletionOpenAIToAwsOpenAITranslator("v1", tt.modelNameOverride) + + // Marshal input to JSON + rawBody, err := json.Marshal(tt.input) + require.NoError(t, err) + + // Call RequestBody + headers, newBody, err := translator.RequestBody(rawBody, &tt.input, false) + require.NoError(t, err) + + // Check headers + require.Len(t, headers, 2) + require.Equal(t, pathHeaderName, headers[0][0]) + require.Equal(t, tt.expectedPath, headers[0][1]) + require.Equal(t, contentLengthHeaderName, headers[1][0]) + require.Equal(t, strconv.Itoa(len(newBody)), headers[1][1]) + + // Check body - compare essential fields instead of full struct comparison + var actualBody openai.ChatCompletionRequest + err = json.Unmarshal(newBody, &actualBody) + require.NoError(t, err) + + // Compare essential fields only + require.Equal(t, tt.expectedBody.Model, actualBody.Model) + require.Equal(t, tt.expectedBody.Stream, actualBody.Stream) + require.Equal(t, len(tt.expectedBody.Messages), len(actualBody.Messages)) + + // For complex requests, check tools and parameters + if tt.expectedBody.Temperature != nil { + require.Equal(t, *tt.expectedBody.Temperature, *actualBody.Temperature) + } + if tt.expectedBody.MaxTokens != nil { + require.Equal(t, *tt.expectedBody.MaxTokens, *actualBody.MaxTokens) + } + if len(tt.expectedBody.Tools) > 0 { + require.Equal(t, len(tt.expectedBody.Tools), len(actualBody.Tools)) + require.Equal(t, tt.expectedBody.Tools[0].Type, actualBody.Tools[0].Type) + if tt.expectedBody.Tools[0].Function != nil { + require.Equal(t, tt.expectedBody.Tools[0].Function.Name, actualBody.Tools[0].Function.Name) + } + } + }) + } +} + +func TestOpenAIToAwsOpenAITranslatorV1ChatCompletion_ResponseHeaders(t *testing.T) { + tests := []struct { + name string + inputHeaders map[string]string + isStreaming bool + expectedHeaders []internalapi.Header + }{ + { + name: "non-streaming response", + inputHeaders: map[string]string{ + "x-amzn-requestid": "test-request-id", + "content-type": "application/json", + }, + isStreaming: false, + expectedHeaders: nil, + }, + { 
+ name: "streaming response with eventstream", + inputHeaders: map[string]string{ + "x-amzn-requestid": "test-request-id", + "content-type": "application/vnd.amazon.eventstream", + }, + isStreaming: true, + expectedHeaders: []internalapi.Header{ + {contentTypeHeaderName, "text/event-stream"}, + }, + }, + { + name: "streaming response with correct content-type", + inputHeaders: map[string]string{ + "x-amzn-requestid": "test-request-id", + "content-type": "text/event-stream", + }, + isStreaming: true, + expectedHeaders: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + translator := NewChatCompletionOpenAIToAwsOpenAITranslator("v1", "").(*openAIToAwsOpenAITranslatorV1ChatCompletion) + translator.stream = tt.isStreaming + + headers, err := translator.ResponseHeaders(tt.inputHeaders) + require.NoError(t, err) + + if diff := cmp.Diff(tt.expectedHeaders, headers); diff != "" { + t.Errorf("ResponseHeaders() mismatch (-expected +actual):\n%s", diff) + } + + // Verify responseID is stored + require.Equal(t, "test-request-id", translator.responseID) + }) + } +} + +func TestOpenAIToAwsOpenAITranslatorV1ChatCompletion_ResponseBody(t *testing.T) { + tests := []struct { + name string + isStreaming bool + inputBody string + endOfStream bool + expectedInputTokens uint32 + expectedOutputTokens uint32 + expectedTotalTokens uint32 + expectedResponseBody string + checkResponseBodyExact bool + }{ + { + name: "non-streaming response", + isStreaming: false, + inputBody: `{ + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1677858242, + "model": "gpt-4", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 13, + "completion_tokens": 9, + "total_tokens": 22 + } + }`, + expectedInputTokens: 13, + expectedOutputTokens: 9, + expectedTotalTokens: 22, + expectedResponseBody: `{ + "id": "test-request-id", + "object": "chat.completion", + "created": 1677858242, + "model": "gpt-4", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I assist you today?" 
+ }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 13, + "completion_tokens": 9, + "total_tokens": 22 + } + }`, + }, + { + name: "streaming response", + isStreaming: true, + inputBody: `data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1677858242,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} + +data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1677858242,"model":"gpt-4","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]} + +data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1677858242,"model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":13,"completion_tokens":9,"total_tokens":22}} + +`, + endOfStream: true, + expectedInputTokens: 13, + expectedOutputTokens: 9, + expectedTotalTokens: 22, + expectedResponseBody: `data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1677858242,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} + +data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1677858242,"model":"gpt-4","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]} + +data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1677858242,"model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":13,"completion_tokens":9,"total_tokens":22}} + +data: [DONE] +`, + checkResponseBodyExact: true, + }, + { + name: "streaming response without DONE", + isStreaming: true, + inputBody: `data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1677858242,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} + +`, + endOfStream: true, + expectedResponseBody: `data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1677858242,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} + +data: [DONE] +`, + checkResponseBodyExact: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + translator := NewChatCompletionOpenAIToAwsOpenAITranslator("v1", "").(*openAIToAwsOpenAITranslatorV1ChatCompletion) + translator.stream = tt.isStreaming + translator.requestModel = "gpt-4" + translator.responseID = "test-request-id" + + headers := map[string]string{} + body := strings.NewReader(tt.inputBody) + + resultHeaders, resultBody, tokenUsage, responseModel, err := translator.ResponseBody(headers, body, tt.endOfStream, nil) + require.NoError(t, err) + + // Check token usage + actualInputTokens, _ := tokenUsage.InputTokens() + actualOutputTokens, _ := tokenUsage.OutputTokens() + actualTotalTokens, _ := tokenUsage.TotalTokens() + require.Equal(t, tt.expectedInputTokens, actualInputTokens) + require.Equal(t, tt.expectedOutputTokens, actualOutputTokens) + require.Equal(t, tt.expectedTotalTokens, actualTotalTokens) + + // Check response model + require.Equal(t, "gpt-4", responseModel) + + // Check headers + if len(resultBody) > 0 { + require.Len(t, resultHeaders, 1) + require.Equal(t, contentLengthHeaderName, resultHeaders[0][0]) + require.Equal(t, strconv.Itoa(len(resultBody)), resultHeaders[0][1]) + } + + // Check response body + if !tt.isStreaming { + // For non-streaming, parse and compare JSON + var expected, actual openai.ChatCompletionResponse + err = 
json.Unmarshal([]byte(tt.expectedResponseBody), &expected) + require.NoError(t, err) + err = json.Unmarshal(resultBody, &actual) + require.NoError(t, err) + + if diff := cmp.Diff(expected, actual); diff != "" { + t.Errorf("ResponseBody() body mismatch (-expected +actual):\n%s", diff) + } + } else if tt.checkResponseBodyExact { + // For streaming, compare the content directly + require.Equal(t, tt.expectedResponseBody, string(resultBody)) + } + }) + } +} + +func TestOpenAIToAwsOpenAITranslatorV1ChatCompletion_ResponseError(t *testing.T) { + tests := []struct { + name string + inputHeaders map[string]string + inputBody string + expectedError openai.Error + expectedHeaders []internalapi.Header + }{ + { + name: "AWS JSON error", + inputHeaders: map[string]string{ + statusHeaderName: "400", + contentTypeHeaderName: "application/json", + awsErrorTypeHeaderName: "ValidationException", + }, + inputBody: `{ + "__type": "ValidationException", + "message": "Invalid model specified" + }`, + expectedError: openai.Error{ + Type: "error", + Error: openai.ErrorType{ + Type: "ValidationException", + Message: "Invalid model specified", + Code: ptr.To("400"), + }, + }, + }, + { + name: "existing OpenAI error format", + inputHeaders: map[string]string{ + statusHeaderName: "429", + contentTypeHeaderName: "application/json", + }, + inputBody: `{ + "error": { + "message": "Rate limit exceeded", + "type": "rate_limit_exceeded", + "code": "rate_limit_exceeded" + } + }`, + expectedError: openai.Error{ + Error: openai.ErrorType{ + Message: "Rate limit exceeded", + Type: "rate_limit_exceeded", + Code: ptr.To("rate_limit_exceeded"), + }, + }, + }, + { + name: "generic AWS error", + inputHeaders: map[string]string{ + statusHeaderName: "500", + contentTypeHeaderName: "application/json", + }, + inputBody: `{ + "message": "Internal server error" + }`, + expectedError: openai.Error{ + Type: "error", + Error: openai.ErrorType{ + Type: "", // No error type header provided, so it will be empty + Message: "Internal server error", + Code: ptr.To("500"), + }, + }, + }, + { + name: "non-JSON error", + inputHeaders: map[string]string{ + statusHeaderName: "503", + contentTypeHeaderName: "text/plain", + }, + inputBody: "Service Unavailable", + expectedError: openai.Error{ + Type: "error", + Error: openai.ErrorType{ + Type: awsInvokeModelBackendError, + Message: "Service Unavailable", + Code: ptr.To("503"), + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + translator := NewChatCompletionOpenAIToAwsOpenAITranslator("v1", "").(*openAIToAwsOpenAITranslatorV1ChatCompletion) + + body := strings.NewReader(tt.inputBody) + + resultHeaders, resultBody, err := translator.ResponseError(tt.inputHeaders, body) + require.NoError(t, err) + + // Check headers + require.Len(t, resultHeaders, 2) + require.Equal(t, contentTypeHeaderName, resultHeaders[0][0]) + require.Equal(t, jsonContentType, resultHeaders[0][1]) + require.Equal(t, contentLengthHeaderName, resultHeaders[1][0]) + require.Equal(t, strconv.Itoa(len(resultBody)), resultHeaders[1][1]) + + // Parse and check error response + var actualError openai.Error + err = json.Unmarshal(resultBody, &actualError) + require.NoError(t, err) + + if diff := cmp.Diff(tt.expectedError, actualError); diff != "" { + t.Errorf("ResponseError() mismatch (-expected +actual):\n%s", diff) + } + }) + } +} + +func TestOpenAIToAwsOpenAITranslatorV1ChatCompletion_ModelNameEncoding(t *testing.T) { + tests := []struct { + name string + modelName string + expectedPath string + }{ + { 
+ name: "simple model name", + modelName: "gpt-4", + expectedPath: "/model/gpt-4/invoke", + }, + { + name: "ARN with special characters", + modelName: "arn:aws:bedrock:us-east-1:123456789:model/gpt-4", + expectedPath: "/model/arn:aws:bedrock:us-east-1:123456789:model%2Fgpt-4/invoke", + }, + { + name: "model name with spaces", + modelName: "my custom model", + expectedPath: "/model/my%20custom%20model/invoke", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + translator := NewChatCompletionOpenAIToAwsOpenAITranslator("v1", tt.modelName) + + input := openai.ChatCompletionRequest{ + Stream: false, + Model: "original-model", + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "test", + }, + Role: openai.ChatMessageRoleUser, + }, + }, + }, + } + + rawBody, err := json.Marshal(input) + require.NoError(t, err) + + headers, _, err := translator.RequestBody(rawBody, &input, false) + require.NoError(t, err) + + require.Len(t, headers, 2) + require.Equal(t, pathHeaderName, headers[0][0]) + require.Equal(t, tt.expectedPath, headers[0][1]) + }) + } +} \ No newline at end of file diff --git a/internal/translator/translator.go b/internal/translator/translator.go index 6aef0857e..e5bd2ce6b 100644 --- a/internal/translator/translator.go +++ b/internal/translator/translator.go @@ -27,8 +27,9 @@ const ( awsErrorTypeHeaderName = "x-amzn-errortype" jsonContentType = "application/json" eventStreamContentType = "text/event-stream" - openAIBackendError = "OpenAIBackendError" - awsBedrockBackendError = "AWSBedrockBackendError" + openAIBackendError = "OpenAIBackendError" + awsBedrockBackendError = "AWSBedrockBackendError" + awsInvokeModelBackendError = "AWSInvokeModelBackendError" ) // Translator translates the request and response messages between the client
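
---
Note (illustrative, not part of the patches): a minimal standalone sketch of the
path/body mapping the new translator performs, handy for sanity-checking the
Bedrock InvokeModel URI encoding. The model ARN, request JSON, and printed
values below are assumptions for illustration; the real gateway wires the
translator up through [Factory] rather than using code like this.

package main

import (
	"encoding/json"
	"fmt"
	"net/url"
)

func main() {
	// Hypothetical Bedrock model identifier; ARNs contain "/" and ":",
	// so the path segment must be escaped (":" stays, "/" becomes %2F).
	model := "arn:aws:bedrock:us-east-1:123456789:model/gpt-4"

	// Mirror the translator's path selection: /model/{id}/invoke, or the
	// streaming variant when the OpenAI request sets "stream": true.
	stream := false
	path := fmt.Sprintf("/model/%s/invoke", url.PathEscape(model))
	if stream {
		path = fmt.Sprintf("/model/%s/invoke-with-response-stream", url.PathEscape(model))
	}

	// The body stays in the OpenAI chat-completions format.
	body, _ := json.Marshal(map[string]any{
		"model":    model,
		"messages": []map[string]string{{"role": "user", "content": "Hello"}},
	})

	fmt.Println(path)      // /model/arn:aws:bedrock:us-east-1:123456789:model%2Fgpt-4/invoke
	fmt.Println(len(body)) // the value the translator writes into Content-Length
}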