diff --git a/.gitignore b/.gitignore index 8616e2e67a..a9470ff28d 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,4 @@ inference-extension-conformance-test-report.yaml .mcp.json .goose +/aigw diff --git a/cmd/extproc/mainlib/main.go b/cmd/extproc/mainlib/main.go index e0fb6ae1bb..2a137621d5 100644 --- a/cmd/extproc/mainlib/main.go +++ b/cmd/extproc/mainlib/main.go @@ -234,6 +234,7 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) { messagesMetrics := metrics.NewMessagesFactory(meter, metricsRequestHeaderAttributes) completionMetrics := metrics.NewCompletionFactory(meter, metricsRequestHeaderAttributes) embeddingsMetrics := metrics.NewEmbeddingsFactory(meter, metricsRequestHeaderAttributes) + imageGenerationMetrics := metrics.NewImageGenerationFactory(meter, metricsRequestHeaderAttributes) mcpMetrics := metrics.NewMCP(meter, metricsRequestHeaderAttributes) tracing, err := tracing.NewTracingFromEnv(ctx, os.Stdout, spanRequestHeaderAttributes) @@ -248,6 +249,7 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) { server.Register(path.Join(flags.rootPrefix, "/v1/chat/completions"), extproc.ChatCompletionProcessorFactory(chatCompletionMetrics)) server.Register(path.Join(flags.rootPrefix, "/v1/completions"), extproc.CompletionsProcessorFactory(completionMetrics)) server.Register(path.Join(flags.rootPrefix, "/v1/embeddings"), extproc.EmbeddingsProcessorFactory(embeddingsMetrics)) + server.Register(path.Join(flags.rootPrefix, "/v1/images/generations"), extproc.ImageGenerationProcessorFactory(imageGenerationMetrics)) server.Register(path.Join(flags.rootPrefix, "/v1/models"), extproc.NewModelsProcessor) server.Register(path.Join(flags.rootPrefix, "/anthropic/v1/messages"), extproc.MessagesProcessorFactory(messagesMetrics)) diff --git a/go.mod b/go.mod index 33faacea13..bcb50aaf2c 100644 @@ -8,6 +8,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.0 
github.com/a8m/envsubst v1.4.3 github.com/alecthomas/kong v1.12.1 + github.com/andybalholm/brotli v1.2.0 github.com/anthropics/anthropic-sdk-go v1.14.0 github.com/aws/aws-sdk-go-v2 v1.39.3 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2 diff --git a/go.sum b/go.sum index 2d763865e7..8d459411b9 100644 --- a/go.sum +++ b/go.sum @@ -41,6 +41,8 @@ github.com/alecthomas/kong v1.12.1 h1:iq6aMJDcFYP9uFrLdsiZQ2ZMmcshduyGv4Pek0MQPW github.com/alecthomas/kong v1.12.1/go.mod h1:p2vqieVMeTAnaC83txKtXe8FLke2X07aruPWXyMPQrU= github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/anthropics/anthropic-sdk-go v1.14.0 h1:EzNQvnZlaDHe2UPkoUySDz3ixRgNbwKdH8KtFpv7pi4= github.com/anthropics/anthropic-sdk-go v1.14.0/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= @@ -448,6 +450,8 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xiang90/probing v0.0.0-20221125231312-a49e3df8f510 h1:S2dVYn90KE98chqDkyE9Z4N61UnQd+KOfgp5Iu53llk= github.com/xiang90/probing v0.0.0-20221125231312-a49e3df8f510/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.27/go.mod 
h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/internal/apischema/openai/openai.go b/internal/apischema/openai/openai.go index fa6116b2b6..66c1d10ae8 100644 --- a/internal/apischema/openai/openai.go +++ b/internal/apischema/openai/openai.go @@ -57,6 +57,10 @@ const ( // ModelTextEmbedding3Small is the cheapest model usable with /embeddings. ModelTextEmbedding3Small = "text-embedding-3-small" + + // ModelGPTImage1Mini is the smallest/cheapest Images model usable with + // /v1/images/generations. Use with size "1024x1024" and quality "low". + ModelGPTImage1Mini = "gpt-image-1-mini" ) // ChatCompletionContentPartRefusalType The type of the content part. diff --git a/internal/extproc/imagegeneration_processor.go b/internal/extproc/imagegeneration_processor.go new file mode 100644 index 0000000000..41bc2a4372 --- /dev/null +++ b/internal/extproc/imagegeneration_processor.go @@ -0,0 +1,440 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package extproc + +import ( + "cmp" + "context" + "encoding/json" + "fmt" + "log/slog" + "strconv" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extprocv3http "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3" + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + openaisdk "github.com/openai/openai-go/v2" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/envoyproxy/ai-gateway/internal/extproc/backendauth" + "github.com/envoyproxy/ai-gateway/internal/extproc/headermutator" + "github.com/envoyproxy/ai-gateway/internal/extproc/translator" + "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/internalapi" + "github.com/envoyproxy/ai-gateway/internal/metrics" + tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" +) + +// ImageGenerationProcessorFactory returns a factory method to instantiate the image generation processor. +func ImageGenerationProcessorFactory(igm metrics.ImageGenerationMetrics) ProcessorFactory { + return func(config *processorConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) { + logger = logger.With("processor", "image-generation", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter)) + if !isUpstreamFilter { + return &imageGenerationProcessorRouterFilter{ + config: config, + tracer: tracing.ImageGenerationTracer(), + requestHeaders: requestHeaders, + logger: logger, + }, nil + } + return &imageGenerationProcessorUpstreamFilter{ + config: config, + requestHeaders: requestHeaders, + logger: logger, + metrics: igm, + }, nil + } +} + +// imageGenerationProcessorRouterFilter implements [Processor] for the `/v1/images/generations` endpoint. +// +// This is primarily used to select the route for the request based on the model name. 
+type imageGenerationProcessorRouterFilter struct { + passThroughProcessor + // upstreamFilter is the upstream filter that is used to process the request at the upstream filter. + // This will be updated when the request is retried. + // + // On the response handling path, we don't need to do any operation until successful, so we use the implementation + // of the upstream filter to handle the response at the router filter. + // + upstreamFilter Processor + logger *slog.Logger + config *processorConfig + requestHeaders map[string]string + // originalRequestBody is the original request body that is passed to the upstream filter. + // This is used to perform the transformation of the request body on the original input + // when the request is retried. + originalRequestBody *openaisdk.ImageGenerateParams + originalRequestBodyRaw []byte + // tracer is the tracer used for requests. + tracer tracing.ImageGenerationTracer + // span is the tracing span for this request, created in ProcessRequestBody. + span tracing.ImageGenerationSpan + // upstreamFilterCount is the number of upstream filters that have been processed. + // This is used to determine if the request is a retry request. + upstreamFilterCount int +} + +// ProcessResponseHeaders implements [Processor.ProcessResponseHeaders]. +func (i *imageGenerationProcessorRouterFilter) ProcessResponseHeaders(ctx context.Context, headerMap *corev3.HeaderMap) (*extprocv3.ProcessingResponse, error) { + // If the request failed to route and/or immediate response was returned before the upstream filter was set, + // i.upstreamFilter can be nil. + if i.upstreamFilter != nil { // See the comment on the "upstreamFilter" field. + return i.upstreamFilter.ProcessResponseHeaders(ctx, headerMap) + } + return i.passThroughProcessor.ProcessResponseHeaders(ctx, headerMap) +} + +// ProcessResponseBody implements [Processor.ProcessResponseBody]. 
+func (i *imageGenerationProcessorRouterFilter) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (resp *extprocv3.ProcessingResponse, err error) { + // If the request failed to route and/or immediate response was returned before the upstream filter was set, + // i.upstreamFilter can be nil. + if i.upstreamFilter != nil { // See the comment on the "upstreamFilter" field. + resp, err = i.upstreamFilter.ProcessResponseBody(ctx, body) + } else { + resp, err = i.passThroughProcessor.ProcessResponseBody(ctx, body) + } + return +} + +// ProcessRequestBody implements [Processor.ProcessRequestBody]. +func (i *imageGenerationProcessorRouterFilter) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) { + model, body, err := parseOpenAIImageGenerationBody(rawBody) + if err != nil { + return nil, fmt.Errorf("failed to parse request body: %w", err) + } + + i.requestHeaders[internalapi.ModelNameHeaderKeyDefault] = model + + var additionalHeaders []*corev3.HeaderValueOption + additionalHeaders = append(additionalHeaders, &corev3.HeaderValueOption{ + // Set the model name to the request header with the key `x-ai-eg-model`. + Header: &corev3.HeaderValue{Key: internalapi.ModelNameHeaderKeyDefault, RawValue: []byte(model)}, + }, &corev3.HeaderValueOption{ + Header: &corev3.HeaderValue{Key: originalPathHeader, RawValue: []byte(i.requestHeaders[":path"])}, + }) + + i.originalRequestBody = body + i.originalRequestBodyRaw = rawBody.Body + + // Tracing may need to inject headers, so create a header mutation here. 
+ headerMutation := &extprocv3.HeaderMutation{ + SetHeaders: additionalHeaders, + } + + i.span = i.tracer.StartSpanAndInjectHeaders( + ctx, + i.requestHeaders, + headerMutation, + body, + rawBody.Body, + ) + + return &extprocv3.ProcessingResponse{ + Response: &extprocv3.ProcessingResponse_RequestBody{ + RequestBody: &extprocv3.BodyResponse{ + Response: &extprocv3.CommonResponse{ + HeaderMutation: headerMutation, + ClearRouteCache: true, + }, + }, + }, + }, nil +} + +// imageGenerationProcessorUpstreamFilter implements [Processor] for the `/v1/images/generations` endpoint at the upstream filter. +// +// This is created per retry and handles the translation as well as the authentication of the request. +type imageGenerationProcessorUpstreamFilter struct { + logger *slog.Logger + config *processorConfig + requestHeaders map[string]string + responseHeaders map[string]string + responseEncoding string + modelNameOverride internalapi.ModelNameOverride + backendName string + handler backendauth.Handler + headerMutator *headermutator.HeaderMutator + originalRequestBodyRaw []byte + originalRequestBody *openaisdk.ImageGenerateParams + translator translator.ImageGenerationTranslator + // onRetry is true if this is a retry request at the upstream filter. + onRetry bool + // stream is set to true if the request is a streaming request (for GPT-Image-1). + stream bool + // cost is the cost of the request that is accumulated during the processing of the response. + costs translator.LLMTokenUsage + // metrics tracking. + metrics metrics.ImageGenerationMetrics + // span is the tracing span for this request, inherited from the router filter. + span tracing.ImageGenerationSpan +} + +// selectTranslator selects the translator based on the output schema. 
+// TODO: Implement proper translator selection once ImageGenerationTranslator is implemented +func (i *imageGenerationProcessorUpstreamFilter) selectTranslator(out filterapi.VersionedAPISchema) error { + switch out.Name { + case filterapi.APISchemaOpenAI: + i.translator = translator.NewImageGenerationOpenAIToOpenAITranslator(out.Version, i.modelNameOverride, i.span) + default: + return fmt.Errorf("unsupported API schema: backend=%s", out) + } + return nil +} + +// ProcessRequestHeaders implements [Processor.ProcessRequestHeaders]. +// +// At the upstream filter, we already have the original request body at request headers phase. +// So, we simply do the translation and upstream auth at this stage, and send them back to Envoy +// with the status CONTINUE_AND_REPLACE. This will allows Envoy to not send the request body again +// to the extproc. +func (i *imageGenerationProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) { + defer func() { + if err != nil { + i.metrics.RecordRequestCompletion(ctx, false, i.requestHeaders) + } + }() + + // Start tracking metrics for this request. + i.metrics.StartRequest(i.requestHeaders) + // Set the original model from the request body before any overrides + i.metrics.SetOriginalModel(i.originalRequestBody.Model) + // Set the request model for metrics from the original model or override if applied. + reqModel := cmp.Or(i.requestHeaders[internalapi.ModelNameHeaderKeyDefault], i.originalRequestBody.Model) + i.metrics.SetRequestModel(reqModel) + + // We force the body mutation in the following cases: + // * The request is a retry request because the body mutation might have happened the previous iteration. 
+ forceBodyMutation := i.onRetry + headerMutation, bodyMutation, err := i.translator.RequestBody(i.originalRequestBodyRaw, i.originalRequestBody, forceBodyMutation) + if err != nil { + return nil, fmt.Errorf("failed to transform request: %w", err) + } + if headerMutation == nil { + headerMutation = &extprocv3.HeaderMutation{} + } + + // Apply header mutations from the route and also restore original headers on retry. + if h := i.headerMutator; h != nil { + if hm := i.headerMutator.Mutate(i.requestHeaders, i.onRetry); hm != nil { + headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, hm.RemoveHeaders...) + headerMutation.SetHeaders = append(headerMutation.SetHeaders, hm.SetHeaders...) + } + } + + for _, h := range headerMutation.SetHeaders { + i.requestHeaders[h.Header.Key] = string(h.Header.RawValue) + } + + if h := i.handler; h != nil { + if err = h.Do(ctx, i.requestHeaders, headerMutation, bodyMutation); err != nil { + return nil, fmt.Errorf("failed to do auth request: %w", err) + } + } + + var dm *structpb.Struct + if bm := bodyMutation.GetBody(); bm != nil { + dm = buildContentLengthDynamicMetadataOnRequest(len(bm)) + } + return &extprocv3.ProcessingResponse{ + Response: &extprocv3.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extprocv3.HeadersResponse{ + Response: &extprocv3.CommonResponse{ + HeaderMutation: headerMutation, BodyMutation: bodyMutation, + Status: extprocv3.CommonResponse_CONTINUE_AND_REPLACE, + }, + }, + }, + DynamicMetadata: dm, + }, nil +} + +// ProcessRequestBody implements [Processor.ProcessRequestBody]. +func (i *imageGenerationProcessorUpstreamFilter) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) { + panic("BUG: ProcessRequestBody should not be called in the upstream filter") +} + +// ProcessResponseHeaders implements [Processor.ProcessResponseHeaders]. 
+func (i *imageGenerationProcessorUpstreamFilter) ProcessResponseHeaders(ctx context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) { + defer func() { + if err != nil { + i.metrics.RecordRequestCompletion(ctx, false, i.requestHeaders) + } + }() + + i.responseHeaders = headersToMap(headers) + if enc := i.responseHeaders["content-encoding"]; enc != "" { + i.responseEncoding = enc + } + + headerMutation, err := i.translator.ResponseHeaders(i.responseHeaders) + if err != nil { + return nil, fmt.Errorf("failed to transform response headers: %w", err) + } + + var mode *extprocv3http.ProcessingMode + if i.stream && i.responseHeaders[":status"] == "200" { + // We only stream the response if the status code is 200 and the response is a stream. + mode = &extprocv3http.ProcessingMode{ResponseBodyMode: extprocv3http.ProcessingMode_STREAMED} + } + + return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseHeaders{ + ResponseHeaders: &extprocv3.HeadersResponse{ + Response: &extprocv3.CommonResponse{HeaderMutation: headerMutation}, + }, + }, ModeOverride: mode}, nil +} + +// ProcessResponseBody implements [Processor.ProcessResponseBody]. +func (i *imageGenerationProcessorUpstreamFilter) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) { + recordRequestCompletionErr := false + defer func() { + if err != nil || recordRequestCompletionErr { + i.metrics.RecordRequestCompletion(ctx, false, i.requestHeaders) + return + } + if body.EndOfStream { + i.metrics.RecordRequestCompletion(ctx, true, i.requestHeaders) + } + }() + + // Decompress the body if needed using common utility. + decodingResult, err := decodeContentIfNeeded(body.Body, i.responseEncoding) + if err != nil { + return nil, err + } + + // Assume all responses have a valid status code header. 
+ if code, _ := strconv.Atoi(i.responseHeaders[":status"]); !isGoodStatusCode(code) { + var headerMutation *extprocv3.HeaderMutation + var bodyMutation *extprocv3.BodyMutation + headerMutation, bodyMutation, err = i.translator.ResponseError(i.responseHeaders, decodingResult.reader) + if err != nil { + return nil, fmt.Errorf("failed to transform response error: %w", err) + } + if i.span != nil { + b := bodyMutation.GetBody() + if b == nil { + b = body.Body + } + i.span.EndSpanOnError(code, b) + } + // Mark so the deferred handler records failure. + recordRequestCompletionErr = true + return &extprocv3.ProcessingResponse{ + Response: &extprocv3.ProcessingResponse_ResponseBody{ + ResponseBody: &extprocv3.BodyResponse{ + Response: &extprocv3.CommonResponse{ + HeaderMutation: headerMutation, + BodyMutation: bodyMutation, + }, + }, + }, + }, nil + } + + // Translator response body transformation (if available) + var headerMutation *extprocv3.HeaderMutation + var bodyMutation *extprocv3.BodyMutation + var tokenUsage translator.LLMTokenUsage + var responseModel internalapi.ResponseModel + headerMutation, bodyMutation, tokenUsage, responseModel, err = i.translator.ResponseBody(i.responseHeaders, decodingResult.reader, body.EndOfStream) + if err != nil { + return nil, fmt.Errorf("failed to transform response: %w", err) + } + + // Remove content-encoding header if original body encoded but was mutated in the processor. 
+ headerMutation = removeContentEncodingIfNeeded(headerMutation, bodyMutation, decodingResult.isEncoded) + + resp := &extprocv3.ProcessingResponse{ + Response: &extprocv3.ProcessingResponse_ResponseBody{ + ResponseBody: &extprocv3.BodyResponse{ + Response: &extprocv3.CommonResponse{ + HeaderMutation: headerMutation, + BodyMutation: bodyMutation, + }, + }, + }, + } + + i.costs.InputTokens += tokenUsage.InputTokens + i.costs.OutputTokens += tokenUsage.OutputTokens + i.costs.TotalTokens += tokenUsage.TotalTokens + + // Ensure response model is set before recording metrics so attributes include it. + i.metrics.SetResponseModel(responseModel) + // Update metrics with token usage (input/output only per OTEL spec). + i.metrics.RecordTokenUsage(ctx, tokenUsage.InputTokens, tokenUsage.OutputTokens, i.requestHeaders) + // Record image generation metrics + i.metrics.RecordImageGeneration(ctx, i.requestHeaders) + + if body.EndOfStream && len(i.config.requestCosts) > 0 { + metadata, err := buildDynamicMetadata(i.config, &i.costs, i.requestHeaders, i.backendName) + if err != nil { + return nil, fmt.Errorf("failed to build dynamic metadata: %w", err) + } + resp.DynamicMetadata = metadata + } + + if body.EndOfStream && i.span != nil { + i.span.EndSpan() + } + return resp, nil +} + +// SetBackend implements [Processor.SetBackend]. 
+func (i *imageGenerationProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler backendauth.Handler, routeProcessor Processor) (err error) { + defer func() { + if err != nil { + i.metrics.RecordRequestCompletion(ctx, false, i.requestHeaders) + } + }() + pickedEndpoint, isEndpointPicker := i.requestHeaders[internalapi.EndpointPickerHeaderKey] + rp, ok := routeProcessor.(*imageGenerationProcessorRouterFilter) + if !ok { + panic("BUG: expected routeProcessor to be of type *imageGenerationProcessorRouterFilter") + } + rp.upstreamFilterCount++ + i.metrics.SetBackend(b) + i.modelNameOverride = b.ModelNameOverride + i.backendName = b.Name + if err = i.selectTranslator(b.Schema); err != nil { + return fmt.Errorf("failed to select translator: %w", err) + } + + i.handler = backendHandler + i.headerMutator = headermutator.NewHeaderMutator(b.HeaderMutation, rp.requestHeaders) + // Sync header with backend model so header-derived labels/CEL use the actual model. + if i.modelNameOverride != "" { + i.requestHeaders[internalapi.ModelNameHeaderKeyDefault] = i.modelNameOverride + // Update metrics with the overridden model + i.metrics.SetRequestModel(i.modelNameOverride) + } + i.originalRequestBody = rp.originalRequestBody + i.originalRequestBodyRaw = rp.originalRequestBodyRaw + i.onRetry = rp.upstreamFilterCount > 1 + + // Set streaming flag for GPT-Image-1 requests + // Image generation streaming not supported in current SDK params; keep false. 
+ i.stream = false + + if isEndpointPicker { + if i.logger.Enabled(ctx, slog.LevelDebug) { + i.logger.Debug("selected backend", slog.String("picked_endpoint", pickedEndpoint), slog.String("backendName", b.Name), slog.String("modelNameOverride", i.modelNameOverride), slog.Bool("stream", i.stream)) + } + } + rp.upstreamFilter = i + i.span = rp.span + return +} + +func parseOpenAIImageGenerationBody(body *extprocv3.HttpBody) (modelName string, rb *openaisdk.ImageGenerateParams, err error) { + var openAIReq openaisdk.ImageGenerateParams + if err := json.Unmarshal(body.Body, &openAIReq); err != nil { + return "", nil, fmt.Errorf("failed to unmarshal body: %w", err) + } + return openAIReq.Model, &openAIReq, nil +} diff --git a/internal/extproc/imagegeneration_processor_test.go b/internal/extproc/imagegeneration_processor_test.go new file mode 100644 index 0000000000..2e9c495115 --- /dev/null +++ b/internal/extproc/imagegeneration_processor_test.go @@ -0,0 +1,520 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package extproc + +import ( + "bytes" + "compress/gzip" + "context" + "encoding/json" + "errors" + "io" + "log/slog" + "testing" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + openaisdk "github.com/openai/openai-go/v2" + "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/internal/extproc/translator" + "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/internalapi" + "github.com/envoyproxy/ai-gateway/internal/llmcostcel" + tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" +) + +func TestImageGeneration_Schema(t *testing.T) { + t.Run("supported openai / on route", func(t *testing.T) { + cfg := &processorConfig{} + routeFilter, err := ImageGenerationProcessorFactory(nil)(cfg, nil, slog.Default(), tracing.NoopTracing{}, false) + require.NoError(t, err) + require.NotNil(t, routeFilter) + require.IsType(t, &imageGenerationProcessorRouterFilter{}, routeFilter) + }) + t.Run("supported openai / on upstream", func(t *testing.T) { + cfg := &processorConfig{} + routeFilter, err := ImageGenerationProcessorFactory(nil)(cfg, nil, slog.Default(), tracing.NoopTracing{}, true) + require.NoError(t, err) + require.NotNil(t, routeFilter) + require.IsType(t, &imageGenerationProcessorUpstreamFilter{}, routeFilter) + }) +} + +func Test_imageGenerationProcessorUpstreamFilter_SelectTranslator(t *testing.T) { + c := &imageGenerationProcessorUpstreamFilter{} + t.Run("unsupported", func(t *testing.T) { + err := c.selectTranslator(filterapi.VersionedAPISchema{Name: "Bar", Version: "v123"}) + require.ErrorContains(t, err, "unsupported API schema: backend={Bar v123}") + }) + t.Run("supported openai", func(t *testing.T) { + err := c.selectTranslator(filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI}) + require.NoError(t, err) + require.NotNil(t, c.translator) + }) +} + +type 
mockImageGenerationTracer struct { + tracing.NoopImageGenerationTracer + startSpanCalled bool + returnedSpan tracing.ImageGenerationSpan +} + +func (m *mockImageGenerationTracer) StartSpanAndInjectHeaders(_ context.Context, _ map[string]string, headerMutation *extprocv3.HeaderMutation, _ *openaisdk.ImageGenerateParams, _ []byte) tracing.ImageGenerationSpan { + m.startSpanCalled = true + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ + Header: &corev3.HeaderValue{ + Key: "tracing-header", + Value: "1", + }, + }) + if m.returnedSpan != nil { + return m.returnedSpan + } + return nil +} + +// Mock span for image generation tests +type mockImageGenerationSpan struct { + endSpanCalled bool + errorStatus int + errBody string +} + +func (m *mockImageGenerationSpan) EndSpan() { + m.endSpanCalled = true +} + +func (m *mockImageGenerationSpan) EndSpanOnError(status int, body []byte) { + m.errorStatus = status + m.errBody = string(body) +} + +func (m *mockImageGenerationSpan) RecordResponse(_ *openaisdk.ImagesResponse) { + // Mock implementation +} + +func Test_imageGenerationProcessorRouterFilter_ProcessRequestBody(t *testing.T) { + t.Run("response pass-through and delegation", func(t *testing.T) { + // When upstreamFilter is nil, router should pass-through using passThroughProcessor. + rf := &imageGenerationProcessorRouterFilter{} + prh, err := rf.ProcessResponseHeaders(t.Context(), &corev3.HeaderMap{}) + require.NoError(t, err) + require.NotNil(t, prh) + + prb, err := rf.ProcessResponseBody(t.Context(), &extprocv3.HttpBody{Body: []byte("abc")}) + require.NoError(t, err) + require.NotNil(t, prb) + + // When upstreamFilter is set, router should delegate to upstream filter. 
+ upstream := &mockProcessor{ + t: t, expHeaderMap: &corev3.HeaderMap{}, expBody: &extprocv3.HttpBody{Body: []byte("abc")}, + retProcessingResponse: &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseHeaders{}}, + } + rf.upstreamFilter = upstream + prh2, err := rf.ProcessResponseHeaders(t.Context(), &corev3.HeaderMap{}) + require.NoError(t, err) + require.Equal(t, upstream.retProcessingResponse, prh2) + + upstream.retProcessingResponse = &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseBody{}} + prb2, err := rf.ProcessResponseBody(t.Context(), &extprocv3.HttpBody{Body: []byte("abc")}) + require.NoError(t, err) + require.Equal(t, upstream.retProcessingResponse, prb2) + }) + t.Run("body parser error", func(t *testing.T) { + p := &imageGenerationProcessorRouterFilter{} + _, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: []byte("not-json")}) + require.ErrorContains(t, err, "failed to parse request body") + }) + + t.Run("ok", func(t *testing.T) { + headers := map[string]string{":path": "/v1/images/generations"} + const modelKey = "x-ai-eg-model" + p := &imageGenerationProcessorRouterFilter{ + config: &processorConfig{}, + requestHeaders: headers, + logger: slog.Default(), + tracer: tracing.NoopTracing{}.ImageGenerationTracer(), + } + body := imageGenerationBodyFromModel(t, "dall-e-3") + resp, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: body}) + require.NoError(t, err) + require.NotNil(t, resp) + re, ok := resp.Response.(*extprocv3.ProcessingResponse_RequestBody) + require.True(t, ok) + require.NotNil(t, re) + require.NotNil(t, re.RequestBody) + setHeaders := re.RequestBody.GetResponse().GetHeaderMutation().SetHeaders + require.Len(t, setHeaders, 2) + require.Equal(t, modelKey, setHeaders[0].Header.Key) + require.Equal(t, "dall-e-3", string(setHeaders[0].Header.RawValue)) + require.Equal(t, "x-ai-eg-original-path", setHeaders[1].Header.Key) + require.Equal(t, 
"/v1/images/generations", string(setHeaders[1].Header.RawValue)) + }) + + t.Run("span creation", func(t *testing.T) { + headers := map[string]string{":path": "/v1/images/generations"} + span := &mockImageGenerationSpan{} + mockTracerInstance := &mockImageGenerationTracer{returnedSpan: span} + + p := &imageGenerationProcessorRouterFilter{ + config: &processorConfig{}, + requestHeaders: headers, + logger: slog.Default(), + tracer: mockTracerInstance, + } + + body := imageGenerationBodyFromModel(t, "dall-e-3") + resp, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: body}) + require.NoError(t, err) + require.NotNil(t, resp) + require.True(t, mockTracerInstance.startSpanCalled) + require.Equal(t, span, p.span) + + // Verify headers are injected. + re, ok := resp.Response.(*extprocv3.ProcessingResponse_RequestBody) + require.True(t, ok) + headerMutation := re.RequestBody.GetResponse().GetHeaderMutation() + require.Contains(t, headerMutation.SetHeaders, &corev3.HeaderValueOption{ + Header: &corev3.HeaderValue{ + Key: "tracing-header", + Value: "1", + }, + }) + }) +} + +func Test_imageGenerationProcessorUpstreamFilter_ProcessResponseHeaders(t *testing.T) { + t.Run("error translation", func(t *testing.T) { + mm := &mockImageGenerationMetrics{} + mt := &mockImageGenerationTranslator{t: t, expHeaders: make(map[string]string)} + p := &imageGenerationProcessorUpstreamFilter{ + translator: mt, + metrics: mm, + logger: slog.Default(), + } + mt.retErr = errors.New("test error") + _, err := p.ProcessResponseHeaders(t.Context(), nil) + require.ErrorContains(t, err, "test error") + mm.RequireRequestFailure(t) + }) + t.Run("ok", func(t *testing.T) { + inHeaders := &corev3.HeaderMap{ + Headers: []*corev3.HeaderValue{{Key: "foo", Value: "bar"}, {Key: "dog", RawValue: []byte("cat")}}, + } + expHeaders := map[string]string{"foo": "bar", "dog": "cat"} + mm := &mockImageGenerationMetrics{} + mt := &mockImageGenerationTranslator{t: t, expHeaders: expHeaders} + p := 
&imageGenerationProcessorUpstreamFilter{ + translator: mt, + metrics: mm, + logger: slog.Default(), + } + res, err := p.ProcessResponseHeaders(t.Context(), inHeaders) + require.NoError(t, err) + commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseHeaders).ResponseHeaders.Response + require.Equal(t, mt.retHeaderMutation, commonRes.HeaderMutation) + mm.RequireRequestNotCompleted(t) + }) +} + +func Test_imageGenerationProcessorUpstreamFilter_ProcessResponseBody(t *testing.T) { + t.Run("error translation", func(t *testing.T) { + mm := &mockImageGenerationMetrics{} + mt := &mockImageGenerationTranslator{t: t} + p := &imageGenerationProcessorUpstreamFilter{ + translator: mt, + metrics: mm, + logger: slog.Default(), + } + mt.retErr = errors.New("test error") + _, err := p.ProcessResponseBody(t.Context(), &extprocv3.HttpBody{}) + require.ErrorContains(t, err, "test error") + mm.RequireRequestFailure(t) + require.Zero(t, mm.tokenUsageCount) + }) + t.Run("ok", func(t *testing.T) { + inBody := &extprocv3.HttpBody{Body: []byte("some-body"), EndOfStream: true} + expBodyMut := &extprocv3.BodyMutation{} + expHeadMut := &extprocv3.HeaderMutation{} + mm := &mockImageGenerationMetrics{} + mt := &mockImageGenerationTranslator{ + t: t, expResponseBody: inBody, + retBodyMutation: expBodyMut, retHeaderMutation: expHeadMut, + retUsedToken: translator.LLMTokenUsage{OutputTokens: 123, InputTokens: 1}, + } + + celProgInt, err := llmcostcel.NewProgram("54321") + require.NoError(t, err) + celProgUint, err := llmcostcel.NewProgram("uint(9999)") + require.NoError(t, err) + p := &imageGenerationProcessorUpstreamFilter{ + translator: mt, + logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), + metrics: mm, + config: &processorConfig{ + requestCosts: []processorConfigRequestCost{ + {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage"}}, + {LLMRequestCost: &filterapi.LLMRequestCost{Type: 
filterapi.LLMRequestCostTypeInputToken, MetadataKey: "input_token_usage"}}, + { + celProg: celProgInt, + LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_int"}, + }, + { + celProg: celProgUint, + LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_uint"}, + }, + }, + }, + requestHeaders: map[string]string{internalapi.ModelNameHeaderKeyDefault: "ai_gateway_llm"}, + responseHeaders: map[string]string{":status": "200"}, + backendName: "some_backend", + modelNameOverride: "ai_gateway_llm", + } + res, err := p.ProcessResponseBody(t.Context(), inBody) + require.NoError(t, err) + commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseBody).ResponseBody.Response + require.Equal(t, expBodyMut, commonRes.BodyMutation) + require.Equal(t, expHeadMut, commonRes.HeaderMutation) + mm.RequireRequestSuccess(t) + require.Equal(t, 124, mm.tokenUsageCount) // 1 input + 123 output + + md := res.DynamicMetadata + require.NotNil(t, md) + require.Equal(t, float64(123), md.Fields[internalapi.AIGatewayFilterMetadataNamespace]. + GetStructValue().Fields["output_token_usage"].GetNumberValue()) + require.Equal(t, float64(1), md.Fields[internalapi.AIGatewayFilterMetadataNamespace]. + GetStructValue().Fields["input_token_usage"].GetNumberValue()) + require.Equal(t, float64(54321), md.Fields[internalapi.AIGatewayFilterMetadataNamespace]. + GetStructValue().Fields["cel_int"].GetNumberValue()) + require.Equal(t, float64(9999), md.Fields[internalapi.AIGatewayFilterMetadataNamespace]. 
+ GetStructValue().Fields["cel_uint"].GetNumberValue()) + require.Equal(t, "ai_gateway_llm", md.Fields[internalapi.AIGatewayFilterMetadataNamespace].GetStructValue().Fields["model_name_override"].GetStringValue()) + require.Equal(t, "some_backend", md.Fields[internalapi.AIGatewayFilterMetadataNamespace].GetStructValue().Fields["backend_name"].GetStringValue()) + }) + + // Verify we record failure for non-2xx responses and do it exactly once (defer suppressed), and span records error. + t.Run("non-2xx status failure once", func(t *testing.T) { + inBody := &extprocv3.HttpBody{Body: []byte("error-body"), EndOfStream: true} + expHeadMut := &extprocv3.HeaderMutation{} + expBodyMut := &extprocv3.BodyMutation{} + mm := &mockImageGenerationMetrics{} + mt := &mockImageGenerationTranslator{t: t, expResponseBody: inBody, retHeaderMutation: expHeadMut, retBodyMutation: expBodyMut} + p := &imageGenerationProcessorUpstreamFilter{ + translator: mt, + metrics: mm, + responseHeaders: map[string]string{":status": "500"}, + logger: slog.Default(), + span: &mockImageGenerationSpan{}, + } + res, err := p.ProcessResponseBody(t.Context(), inBody) + require.NoError(t, err) + commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseBody).ResponseBody.Response + require.Equal(t, expBodyMut, commonRes.BodyMutation) + require.Equal(t, expHeadMut, commonRes.HeaderMutation) + mm.RequireRequestFailure(t) + // assert span error recorded + s := p.span.(*mockImageGenerationSpan) + require.Equal(t, 500, s.errorStatus) + require.Equal(t, "error-body", s.errBody) + }) + + // Verify content-encoding header is removed when encoded body is mutated. 
+ t.Run("gzip encoded body with mutation removes content-encoding", func(t *testing.T) { + var gz bytes.Buffer + zw := gzip.NewWriter(&gz) + _, err := zw.Write([]byte("encoded-body")) + require.NoError(t, err) + require.NoError(t, zw.Close()) + inBody := &extprocv3.HttpBody{Body: gz.Bytes(), EndOfStream: true} + mm := &mockImageGenerationMetrics{} + mt := &mockImageGenerationTranslator{ + // translator returns a non-nil body mutation indicating processor changed body + retBodyMutation: &extprocv3.BodyMutation{Mutation: &extprocv3.BodyMutation_Body{Body: []byte("changed")}}, + retHeaderMutation: &extprocv3.HeaderMutation{}, + } + p := &imageGenerationProcessorUpstreamFilter{ + translator: mt, + metrics: mm, + logger: slog.Default(), + responseHeaders: map[string]string{":status": "200"}, + responseEncoding: "gzip", + config: &processorConfig{}, + } + res, err := p.ProcessResponseBody(t.Context(), inBody) + require.NoError(t, err) + commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseBody).ResponseBody.Response + reqHM := commonRes.HeaderMutation + require.Contains(t, reqHM.RemoveHeaders, "content-encoding") + mm.RequireRequestSuccess(t) + }) +} + +func Test_imageGenerationProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) { + t.Run("ok with auth handler and header mutator", func(t *testing.T) { + headers := map[string]string{":path": "/v1/images/generations", internalapi.ModelNameHeaderKeyDefault: "dall-e-3"} + mm := &mockImageGenerationMetrics{} + body := &openaisdk.ImageGenerateParams{Model: openaisdk.ImageModel("dall-e-3"), Prompt: "a cat"} + mt := &mockImageGenerationTranslator{t: t, expRequestBody: body} + p := &imageGenerationProcessorUpstreamFilter{ + config: &processorConfig{}, + requestHeaders: headers, + logger: slog.Default(), + metrics: mm, + originalRequestBodyRaw: imageGenerationBodyFromModel(t, "dall-e-3"), + originalRequestBody: body, + handler: &mockBackendAuthHandler{}, + translator: mt, + } + resp, err := 
p.ProcessRequestHeaders(t.Context(), nil) + require.NoError(t, err) + require.NotNil(t, resp) + mm.RequireRequestNotCompleted(t) + mm.RequireSelectedModel(t, "dall-e-3") + }) + + t.Run("auth handler error path records failure", func(t *testing.T) { + headers := map[string]string{":path": "/v1/images/generations", internalapi.ModelNameHeaderKeyDefault: "dall-e-3"} + mm := &mockImageGenerationMetrics{} + body := &openaisdk.ImageGenerateParams{Model: openaisdk.ImageModel("dall-e-3"), Prompt: "a cat"} + mt := &mockImageGenerationTranslator{t: t, expRequestBody: body} + p := &imageGenerationProcessorUpstreamFilter{ + config: &processorConfig{}, + requestHeaders: headers, + logger: slog.Default(), + metrics: mm, + originalRequestBodyRaw: imageGenerationBodyFromModel(t, "dall-e-3"), + originalRequestBody: body, + // handler returns error to simulate backend auth failure + handler: &mockBackendAuthHandlerError{}, + translator: mt, + } + _, err := p.ProcessRequestHeaders(t.Context(), nil) + require.Error(t, err) + mm.RequireRequestFailure(t) + }) + + // Ensure upstream ProcessRequestBody panics as documented and streaming flag behavior is off. + t.Run("upstream body panic and streaming off", func(t *testing.T) { + p := &imageGenerationProcessorUpstreamFilter{} + require.False(t, p.stream) + require.Panics(t, func() { + _, _ = p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{}) + }) + }) +} + +func Test_imageGenerationProcessorUpstreamFilter_SetBackend(t *testing.T) { + headers := map[string]string{":path": "/v1/images/generations"} + mm := &mockImageGenerationMetrics{} + p := &imageGenerationProcessorUpstreamFilter{ + config: &processorConfig{}, + requestHeaders: headers, + logger: slog.Default(), + metrics: mm, + } + + // Unsupported schema. 
+ err := p.SetBackend(t.Context(), &filterapi.Backend{ + Name: "some-backend", + Schema: filterapi.VersionedAPISchema{Name: "some-schema", Version: "v10.0"}, + }, nil, &imageGenerationProcessorRouterFilter{}) + require.ErrorContains(t, err, "unsupported API schema: backend={some-schema v10.0}") + mm.RequireRequestFailure(t) + mm.RequireTokensRecorded(t, 0) + mm.RequireSelectedBackend(t, "some-backend") + + // Supported OpenAI schema. + rp := &imageGenerationProcessorRouterFilter{originalRequestBody: &openaisdk.ImageGenerateParams{}} + p2 := &imageGenerationProcessorUpstreamFilter{ + config: &processorConfig{}, + requestHeaders: map[string]string{internalapi.ModelNameHeaderKeyDefault: "gpt-image-1-mini"}, + logger: slog.Default(), + metrics: &mockImageGenerationMetrics{}, + } + err = p2.SetBackend(t.Context(), &filterapi.Backend{ + Name: "openai", + Schema: filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI, Version: "v1"}, + ModelNameOverride: "gpt-image-1", + }, nil, rp) + require.NoError(t, err) + require.Equal(t, "gpt-image-1", p2.requestHeaders[internalapi.ModelNameHeaderKeyDefault]) +} + +func TestImageGeneration_ParseBody(t *testing.T) { + t.Run("ok", func(t *testing.T) { + jsonBody := `{"model":"gpt-image-1-mini","prompt":"a cat","size":"1024x1024","quality":"low"}` + modelName, rb, err := parseOpenAIImageGenerationBody(&extprocv3.HttpBody{Body: []byte(jsonBody)}) + require.NoError(t, err) + require.Equal(t, "gpt-image-1-mini", modelName) + require.NotNil(t, rb) + require.Equal(t, "gpt-image-1-mini", rb.Model) + require.Equal(t, "a cat", rb.Prompt) + }) + t.Run("error", func(t *testing.T) { + modelName, rb, err := parseOpenAIImageGenerationBody(&extprocv3.HttpBody{}) + require.Error(t, err) + require.Empty(t, modelName) + require.Nil(t, rb) + }) +} + +// Mock translator for image generation tests +type mockImageGenerationTranslator struct { + t *testing.T + expRequestBody *openaisdk.ImageGenerateParams + expResponseBody *extprocv3.HttpBody + 
expHeaders map[string]string + expForceRequestBodyMutation bool + retErr error + retHeaderMutation *extprocv3.HeaderMutation + retBodyMutation *extprocv3.BodyMutation + retUsedToken translator.LLMTokenUsage + retResponseModel internalapi.ResponseModel +} + +func (m *mockImageGenerationTranslator) RequestBody(_ []byte, req *openaisdk.ImageGenerateParams, forceBodyMutation bool) (*extprocv3.HeaderMutation, *extprocv3.BodyMutation, error) { + if m.expRequestBody != nil { + require.Equal(m.t, m.expRequestBody, req) + } + if m.expForceRequestBodyMutation { + require.True(m.t, forceBodyMutation) + } + return m.retHeaderMutation, m.retBodyMutation, m.retErr +} + +func (m *mockImageGenerationTranslator) ResponseHeaders(headers map[string]string) (*extprocv3.HeaderMutation, error) { + if m.expHeaders != nil { + for k, v := range m.expHeaders { + require.Equal(m.t, v, headers[k]) + } + } + return m.retHeaderMutation, m.retErr +} + +func (m *mockImageGenerationTranslator) ResponseBody(headers map[string]string, body io.Reader, endOfStream bool) (*extprocv3.HeaderMutation, *extprocv3.BodyMutation, translator.LLMTokenUsage, internalapi.ResponseModel, error) { + _ = headers + if m.expResponseBody != nil { + bodyBytes, _ := io.ReadAll(body) + require.Equal(m.t, m.expResponseBody.Body, bodyBytes) + require.Equal(m.t, m.expResponseBody.EndOfStream, endOfStream) + } + return m.retHeaderMutation, m.retBodyMutation, m.retUsedToken, m.retResponseModel, m.retErr +} + +func (m *mockImageGenerationTranslator) ResponseError(headers map[string]string, body io.Reader) (*extprocv3.HeaderMutation, *extprocv3.BodyMutation, error) { + _ = headers + _ = body + return m.retHeaderMutation, m.retBodyMutation, m.retErr +} + +// imageGenerationBodyFromModel returns a minimal valid image generation request for tests. 
+func imageGenerationBodyFromModel(t *testing.T, model string) []byte { + t.Helper() + b, err := json.Marshal(&openaisdk.ImageGenerateParams{Model: model, Prompt: "a cat"}) + require.NoError(t, err) + return b +} diff --git a/internal/extproc/mocks_test.go b/internal/extproc/mocks_test.go index 154024e73a..1c4dda246a 100644 --- a/internal/extproc/mocks_test.go +++ b/internal/extproc/mocks_test.go @@ -511,3 +511,81 @@ type mockBackendAuthHandler struct{} func (m *mockBackendAuthHandler) Do(context.Context, map[string]string, *extprocv3.HeaderMutation, *extprocv3.BodyMutation) error { return nil } + +// mockBackendAuthHandlerError returns error on Do. +type mockBackendAuthHandlerError struct{} + +func (m *mockBackendAuthHandlerError) Do(context.Context, map[string]string, *extprocv3.HeaderMutation, *extprocv3.BodyMutation) error { + return io.EOF +} + +// mockImageGenerationMetrics implements [metrics.ImageGenerationMetrics] for testing. +type mockImageGenerationMetrics struct { + requestSuccessCount int + requestErrorCount int + model string + backend string + tokenUsageCount int +} + +func (m *mockImageGenerationMetrics) StartRequest(map[string]string) {} +func (m *mockImageGenerationMetrics) SetOriginalModel(originalModel string) { + m.model = originalModel +} + +func (m *mockImageGenerationMetrics) SetRequestModel(requestModel string) { + m.model = requestModel +} + +func (m *mockImageGenerationMetrics) SetResponseModel(responseModel string) { + m.model = responseModel +} + +func (m *mockImageGenerationMetrics) SetModel(_ string, responseModel string) { + m.model = responseModel +} +func (m *mockImageGenerationMetrics) SetBackend(b *filterapi.Backend) { m.backend = b.Name } +func (m *mockImageGenerationMetrics) RecordTokenUsage(_ context.Context, input, output uint32, _ map[string]string) { + m.tokenUsageCount += int(input + output) +} + +func (m *mockImageGenerationMetrics) RecordRequestCompletion(_ context.Context, success bool, _ map[string]string) { + if 
success { + m.requestSuccessCount++ + } else { + m.requestErrorCount++ + } +} + +func (m *mockImageGenerationMetrics) RecordImageGeneration(_ context.Context, _ map[string]string) { +} + +func (m *mockImageGenerationMetrics) RequireRequestFailure(t *testing.T) { + require.Equal(t, 0, m.requestSuccessCount) + require.Equal(t, 1, m.requestErrorCount) +} + +func (m *mockImageGenerationMetrics) RequireRequestNotCompleted(t *testing.T) { + require.Equal(t, 0, m.requestSuccessCount) + require.Equal(t, 0, m.requestErrorCount) +} + +func (m *mockImageGenerationMetrics) RequireRequestSuccess(t *testing.T) { + require.Equal(t, 1, m.requestSuccessCount) + require.Equal(t, 0, m.requestErrorCount) +} + +func (m *mockImageGenerationMetrics) RequireSelectedModel(t *testing.T, model string) { + require.Equal(t, model, m.model) +} + +func (m *mockImageGenerationMetrics) RequireSelectedBackend(t *testing.T, backend string) { + require.Equal(t, backend, m.backend) +} + +func (m *mockImageGenerationMetrics) RequireTokensRecorded(t *testing.T, count int) { + require.Equal(t, count, m.tokenUsageCount) +} + +// Ensure mock implements the interface at compile-time. +var _ metrics.ImageGenerationMetrics = &mockImageGenerationMetrics{} diff --git a/internal/extproc/translator/imagegeneration_openai_openai.go b/internal/extproc/translator/imagegeneration_openai_openai.go new file mode 100644 index 0000000000..b47d1eabb2 --- /dev/null +++ b/internal/extproc/translator/imagegeneration_openai_openai.go @@ -0,0 +1,152 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package translator + +import ( + "cmp" + "encoding/json" + "fmt" + "io" + "path" + "strconv" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + openaisdk "github.com/openai/openai-go/v2" + "github.com/tidwall/sjson" + + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" + "github.com/envoyproxy/ai-gateway/internal/internalapi" + tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" +) + +// NewImageGenerationOpenAIToOpenAITranslator implements [Factory] for OpenAI to OpenAI image generation translation. +func NewImageGenerationOpenAIToOpenAITranslator(apiVersion string, modelNameOverride internalapi.ModelNameOverride, span tracing.ImageGenerationSpan) ImageGenerationTranslator { + return &openAIToOpenAIImageGenerationTranslator{modelNameOverride: modelNameOverride, path: path.Join("/", apiVersion, "images/generations"), span: span} +} + +// openAIToOpenAIImageGenerationTranslator implements [ImageGenerationTranslator] for /v1/images/generations. +type openAIToOpenAIImageGenerationTranslator struct { + modelNameOverride internalapi.ModelNameOverride + // The path of the images generations endpoint to be used for the request. It is prefixed with the OpenAI path prefix. + path string + // span is the tracing span for this request, inherited from the router filter. + span tracing.ImageGenerationSpan + // requestModel stores the effective model for this request (override or provided) + // so we can attribute metrics later; the OpenAI Images response omits a model field. + requestModel internalapi.RequestModel +} + +// RequestBody implements [ImageGenerationTranslator.RequestBody]. 
+func (o *openAIToOpenAIImageGenerationTranslator) RequestBody(original []byte, p *openaisdk.ImageGenerateParams, forceBodyMutation bool) ( + headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, err error, +) { + var newBody []byte + if o.modelNameOverride != "" { + // If modelName is set we override the model to be used for the request. + newBody, err = sjson.SetBytesOptions(original, "model", o.modelNameOverride, sjsonOptions) + if err != nil { + return nil, nil, fmt.Errorf("failed to set model name: %w", err) + } + } + // Persist the effective model used. The Images endpoint omits model in responses, + // so we derive it from the request (or override) for downstream metrics. + o.requestModel = cmp.Or(o.modelNameOverride, p.Model) + + // Always set the path header to the images generations endpoint so that the request is routed correctly. + headerMutation = &extprocv3.HeaderMutation{ + SetHeaders: []*corev3.HeaderValueOption{ + {Header: &corev3.HeaderValue{ + Key: ":path", + RawValue: []byte(o.path), + }}, + }, + } + + if forceBodyMutation && len(newBody) == 0 { + newBody = original + } + + if len(newBody) > 0 { + bodyMutation = &extprocv3.BodyMutation{ + Mutation: &extprocv3.BodyMutation_Body{Body: newBody}, + } + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{Header: &corev3.HeaderValue{ + Key: "content-length", + RawValue: []byte(strconv.Itoa(len(newBody))), + }}) + } + return +} + +// ResponseError implements [ImageGenerationTranslator.ResponseError] +// For OpenAI based backend we return the OpenAI error type as is. +// If connection fails the error body is translated to OpenAI error type for events such as HTTP 503 or 504. 
+func (o *openAIToOpenAIImageGenerationTranslator) ResponseError(respHeaders map[string]string, body io.Reader) ( + headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, err error, +) { + statusCode := respHeaders[statusHeaderName] + // Read the upstream error body regardless of content-type. Some backends may mislabel it. + buf, err := io.ReadAll(body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read error body: %w", err) + } + // If upstream already returned JSON, preserve it as-is. + if json.Valid(buf) { + return nil, nil, nil + } + // Otherwise, wrap the plain-text (or non-JSON) error into OpenAI REST error schema. + openaiError := struct { + Error openai.ErrorType `json:"error"` + }{ + Error: openai.ErrorType{ + Type: openAIBackendError, + Message: string(buf), + Code: &statusCode, + }, + } + mut := &extprocv3.BodyMutation_Body{} + mut.Body, err = json.Marshal(openaiError) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal error body: %w", err) + } + headerMutation = &extprocv3.HeaderMutation{} + // Ensure downstream sees a JSON error payload + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{Header: &corev3.HeaderValue{ + Key: contentTypeHeaderName, + RawValue: []byte(jsonContentType), + }}) + setContentLength(headerMutation, mut.Body) + return headerMutation, &extprocv3.BodyMutation{Mutation: mut}, nil +} + +// ResponseHeaders implements [ImageGenerationTranslator.ResponseHeaders]. +func (o *openAIToOpenAIImageGenerationTranslator) ResponseHeaders(map[string]string) (headerMutation *extprocv3.HeaderMutation, err error) { + return nil, nil +} + +// ResponseBody implements [ImageGenerationTranslator.ResponseBody]. 
+func (o *openAIToOpenAIImageGenerationTranslator) ResponseBody(_ map[string]string, body io.Reader, _ bool) ( + headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, tokenUsage LLMTokenUsage, responseModel internalapi.ResponseModel, err error, +) { + // Decode using OpenAI SDK v2 schema to avoid drift. + resp := &openaisdk.ImagesResponse{} + if err := json.NewDecoder(body).Decode(&resp); err != nil { + return nil, nil, tokenUsage, responseModel, fmt.Errorf("failed to decode response body: %w", err) + } + + // Populate token usage if provided (GPT-Image-1); otherwise remain zero. + if resp.JSON.Usage.Valid() { + tokenUsage.InputTokens = uint32(resp.Usage.InputTokens) //nolint:gosec + tokenUsage.OutputTokens = uint32(resp.Usage.OutputTokens) //nolint:gosec + tokenUsage.TotalTokens = uint32(resp.Usage.TotalTokens) //nolint:gosec + } + + // Provide response model for metrics + responseModel = o.requestModel + + return +} diff --git a/internal/extproc/translator/imagegeneration_openai_openai_test.go b/internal/extproc/translator/imagegeneration_openai_openai_test.go new file mode 100644 index 0000000000..b7c1612f9c --- /dev/null +++ b/internal/extproc/translator/imagegeneration_openai_openai_test.go @@ -0,0 +1,160 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package translator + +import ( + "bytes" + "encoding/json" + "io" + "testing" + + openaisdk "github.com/openai/openai-go/v2" + "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" +) + +func TestOpenAIToOpenAIImageTranslator_RequestBody_ModelOverrideAndPath(t *testing.T) { + tr := NewImageGenerationOpenAIToOpenAITranslator("v1", "gpt-image-1", nil) + req := &openaisdk.ImageGenerateParams{Model: openaisdk.ImageModelDallE3, Prompt: "a cat"} + original, _ := json.Marshal(req) + + hm, bm, err := tr.RequestBody(original, req, false) + require.NoError(t, err) + require.NotNil(t, hm) + require.Len(t, hm.SetHeaders, 2) // path and content-length headers + require.Equal(t, ":path", hm.SetHeaders[0].Header.Key) + require.Equal(t, "/v1/images/generations", string(hm.SetHeaders[0].Header.RawValue)) + require.Equal(t, "content-length", hm.SetHeaders[1].Header.Key) + + require.NotNil(t, bm) + mutated := bm.GetBody() + var got openaisdk.ImageGenerateParams + require.NoError(t, json.Unmarshal(mutated, &got)) + require.Equal(t, "gpt-image-1", got.Model) +} + +func TestOpenAIToOpenAIImageTranslator_RequestBody_ForceMutation(t *testing.T) { + tr := NewImageGenerationOpenAIToOpenAITranslator("v1", "", nil) + req := &openaisdk.ImageGenerateParams{Model: openaisdk.ImageModelDallE2, Prompt: "a cat"} + original, _ := json.Marshal(req) + + hm, bm, err := tr.RequestBody(original, req, true) + require.NoError(t, err) + require.NotNil(t, hm) + // Content-Length is set only when body mutated; with force it should be mutated to original. 
+ foundCL := false + for _, h := range hm.SetHeaders { + if h.Header.Key == "content-length" { + foundCL = true + break + } + } + require.True(t, foundCL) + require.NotNil(t, bm) + require.Equal(t, original, bm.GetBody()) +} + +func TestOpenAIToOpenAIImageTranslator_ResponseError_NonJSON(t *testing.T) { + tr := NewImageGenerationOpenAIToOpenAITranslator("v1", "", nil) + headers := map[string]string{contentTypeHeaderName: "text/plain", statusHeaderName: "503"} + hm, bm, err := tr.ResponseError(headers, bytes.NewReader([]byte("backend error"))) + require.NoError(t, err) + require.NotNil(t, hm) + require.NotNil(t, bm) + + // Body should be OpenAI error JSON + var got struct { + Error openai.ErrorType `json:"error"` + } + require.NoError(t, json.Unmarshal(bm.GetBody(), &got)) + require.Equal(t, openAIBackendError, got.Error.Type) +} + +func TestOpenAIToOpenAIImageTranslator_ResponseBody_OK(t *testing.T) { + tr := NewImageGenerationOpenAIToOpenAITranslator("v1", "", nil) + resp := &openaisdk.ImagesResponse{Size: openaisdk.ImagesResponseSize1024x1024} + buf, _ := json.Marshal(resp) + hm, bm, usage, responseModel, err := tr.ResponseBody(map[string]string{}, bytes.NewReader(buf), true) + require.NoError(t, err) + require.Nil(t, hm) + require.Nil(t, bm) + require.Equal(t, uint32(0), usage.InputTokens) + require.Equal(t, uint32(0), usage.TotalTokens) + require.Empty(t, responseModel) +} + +func TestOpenAIToOpenAIImageTranslator_RequestBody_NoOverrideNoForce(t *testing.T) { + tr := NewImageGenerationOpenAIToOpenAITranslator("v1", "", nil) + req := &openaisdk.ImageGenerateParams{Model: openaisdk.ImageModelDallE2, Prompt: "a cat"} + original, _ := json.Marshal(req) + + hm, bm, err := tr.RequestBody(original, req, false) + require.NoError(t, err) + require.NotNil(t, hm) + // Only path header present; content-length should not be set when no mutation + require.Len(t, hm.SetHeaders, 1) + require.Equal(t, ":path", hm.SetHeaders[0].Header.Key) + require.Nil(t, bm) +} + +func 
TestOpenAIToOpenAIImageTranslator_ResponseError_JSONPassthrough(t *testing.T) { + tr := NewImageGenerationOpenAIToOpenAITranslator("v1", "", nil) + headers := map[string]string{contentTypeHeaderName: jsonContentType, statusHeaderName: "500"} + // Already JSON — should be passed through (no mutation) + hm, bm, err := tr.ResponseError(headers, bytes.NewReader([]byte(`{"error":"msg"}`))) + require.NoError(t, err) + require.Nil(t, hm) + require.Nil(t, bm) +} + +type errReader struct{} + +func (errReader) Read([]byte) (int, error) { return 0, io.ErrUnexpectedEOF } + +func TestOpenAIToOpenAIImageTranslator_ResponseError_ReadError(t *testing.T) { + tr := NewImageGenerationOpenAIToOpenAITranslator("v1", "", nil) + headers := map[string]string{statusHeaderName: "503"} + hm, bm, err := tr.ResponseError(headers, errReader{}) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to read error body") + require.Nil(t, hm) + require.Nil(t, bm) +} + +func TestOpenAIToOpenAIImageTranslator_ResponseBody_ModelPropagatesFromRequest(t *testing.T) { + // Use override so effective model differs from original + tr := NewImageGenerationOpenAIToOpenAITranslator("v1", "gpt-image-1", nil) + req := &openaisdk.ImageGenerateParams{Model: openaisdk.ImageModelDallE3, Prompt: "a cat"} + original, _ := json.Marshal(req) + // Call RequestBody first to set requestModel inside translator + _, _, err := tr.RequestBody(original, req, false) + require.NoError(t, err) + + resp := &openaisdk.ImagesResponse{ + // Two images returned + Data: make([]openaisdk.Image, 2), + Size: openaisdk.ImagesResponseSize1024x1024, + } + buf, _ := json.Marshal(resp) + _, _, _, respModel, err := tr.ResponseBody(map[string]string{}, bytes.NewReader(buf), true) + require.NoError(t, err) + require.Equal(t, "gpt-image-1", respModel) +} + +func TestOpenAIToOpenAIImageTranslator_ResponseHeaders_NoOp(t *testing.T) { + tr := NewImageGenerationOpenAIToOpenAITranslator("v1", "", nil) + hm, err := 
tr.ResponseHeaders(map[string]string{"foo": "bar"}) + require.NoError(t, err) + require.Nil(t, hm) +} + +func TestOpenAIToOpenAIImageTranslator_ResponseBody_DecodeError(t *testing.T) { + tr := NewImageGenerationOpenAIToOpenAITranslator("v1", "", nil) + _, _, _, _, err := tr.ResponseBody(map[string]string{}, bytes.NewReader([]byte("not-json")), true) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to decode response body") +} diff --git a/internal/extproc/translator/translator.go b/internal/extproc/translator/translator.go index f1eab6f6f3..22cae92fa2 100644 --- a/internal/extproc/translator/translator.go +++ b/internal/extproc/translator/translator.go @@ -11,6 +11,7 @@ import ( corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + openaisdk "github.com/openai/openai-go/v2" "github.com/tidwall/sjson" anthropicschema "github.com/envoyproxy/ai-gateway/internal/apischema/anthropic" @@ -224,3 +225,45 @@ var sjsonOptions = &sjson.Options{ // it must be ensured that the original body is not modified, i.e. the operation must be idempotent. ReplaceInPlace: false, } + +// ImageGenerationTranslator translates the request and response messages between the client and the backend API schemas +// for /v1/images/generations endpoint of OpenAI. +// +// This is created per request and is not thread-safe. +type ImageGenerationTranslator interface { + // RequestBody translates the request body. + // - raw is the raw request body. + // - body is the request body parsed into the OpenAI SDK [openaisdk.ImageGenerateParams]. + // - forceBodyMutation is true if the translator should always mutate the body, even if no changes are made. + // - This returns headerMutation and bodyMutation that can be nil to indicate no mutation. 
+ RequestBody(raw []byte, body *openaisdk.ImageGenerateParams, forceBodyMutation bool) ( + headerMutation *extprocv3.HeaderMutation, + bodyMutation *extprocv3.BodyMutation, + err error, + ) + + // ResponseHeaders translates the response headers. + // - headers is the response headers. + // - This returns headerMutation that can be nil to indicate no mutation. + ResponseHeaders(headers map[string]string) ( + headerMutation *extprocv3.HeaderMutation, + err error, + ) + + // ResponseBody translates the response body. + // - body is the response body. + // - This returns headerMutation and bodyMutation that can be nil to indicate no mutation. + // - This returns responseModel that is the model name from the response (may differ from request model). + ResponseBody(respHeaders map[string]string, body io.Reader, endOfStream bool) ( + headerMutation *extprocv3.HeaderMutation, + bodyMutation *extprocv3.BodyMutation, + tokenUsage LLMTokenUsage, + responseModel internalapi.ResponseModel, + err error, + ) + + // ResponseError translates the response error. This is called when the upstream response status code is not successful (2xx). + // - respHeaders is the response headers. + // - body is the response body that contains the error message. + ResponseError(respHeaders map[string]string, body io.Reader) (headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, err error) +} diff --git a/internal/extproc/util.go b/internal/extproc/util.go index 651437c0bf..c00c1cac8c 100644 --- a/internal/extproc/util.go +++ b/internal/extproc/util.go @@ -11,6 +11,7 @@ import ( "fmt" "io" + "github.com/andybalholm/brotli" extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" ) @@ -21,7 +22,7 @@ type contentDecodingResult struct { } // decodeContentIfNeeded decompresses the response body based on the content-encoding header. -// Currently, supports gzip encoding, but can be extended to support other encodings in the future. 
+// Currently, supports gzip and brotli encoding, but can be extended to support other encodings in the future. // Returns a reader for the (potentially decompressed) body and metadata about the encoding. func decodeContentIfNeeded(body []byte, contentEncoding string) (contentDecodingResult, error) { switch contentEncoding { @@ -34,6 +35,12 @@ func decodeContentIfNeeded(body []byte, contentEncoding string) (contentDecoding reader: reader, isEncoded: true, }, nil + case "br": + reader := brotli.NewReader(bytes.NewReader(body)) + return contentDecodingResult{ + reader: reader, + isEncoded: true, + }, nil default: return contentDecodingResult{ reader: bytes.NewReader(body), diff --git a/internal/metrics/genai.go b/internal/metrics/genai.go index b3d659aa28..8c04c54cd2 100644 --- a/internal/metrics/genai.go +++ b/internal/metrics/genai.go @@ -24,14 +24,15 @@ const ( genaiAttributeTokenType = "gen_ai.token.type" //nolint:gosec // metric name, not credential genaiAttributeErrorType = "error.type" - genaiOperationChat = "chat" - genaiOperationCompletion = "completion" - genaiOperationEmbedding = "embeddings" - genaiOperationMessages = "messages" - genaiProviderOpenAI = "openai" - genaiProviderAWSBedrock = "aws.bedrock" - genaiTokenTypeInput = "input" - genaiTokenTypeOutput = "output" + genaiOperationChat = "chat" + genaiOperationCompletion = "completion" + genaiOperationEmbedding = "embeddings" + genaiOperationMessages = "messages" + genaiOperationImageGeneration = "image_generation" + genaiProviderOpenAI = "openai" + genaiProviderAWSBedrock = "aws.bedrock" + genaiTokenTypeInput = "input" + genaiTokenTypeOutput = "output" // "cached_input" is not yet part of the spec but has been proposed: // https://github.com/open-telemetry/semantic-conventions/issues/1959 // diff --git a/internal/metrics/image_generation_metrics.go b/internal/metrics/image_generation_metrics.go new file mode 100644 index 0000000000..8106000f61 --- /dev/null +++ 
b/internal/metrics/image_generation_metrics.go @@ -0,0 +1,107 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package metrics + +import ( + "context" + "time" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + + "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/internalapi" +) + +// imageGeneration is the implementation for the image generation AI Gateway metrics. +type imageGeneration struct { + baseMetrics +} + +// ImageGenerationMetrics is the interface for the image generation AI Gateway metrics. +type ImageGenerationMetrics interface { + // StartRequest initializes timing for a new request. + StartRequest(headers map[string]string) + // SetOriginalModel sets the original model from the incoming request body before any virtualization applies. + // This is usually called after parsing the request body. Example: dall-e-3 + SetOriginalModel(originalModel internalapi.OriginalModel) + // SetRequestModel sets the request model name. + SetRequestModel(requestModel internalapi.RequestModel) + // SetResponseModel sets the response model name. + SetResponseModel(responseModel internalapi.ResponseModel) + // SetBackend sets the selected backend when the routing decision has been made. This is usually called + // after parsing the request body to determine the model and invoke the routing logic. + SetBackend(backend *filterapi.Backend) + + // RecordTokenUsage records token usage metrics (image gen typically 0, but supported). + RecordTokenUsage(ctx context.Context, inputTokens, outputTokens uint32, requestHeaderLabelMapping map[string]string) + // RecordRequestCompletion records latency metrics for the entire request. 
+ RecordRequestCompletion(ctx context.Context, success bool, requestHeaderLabelMapping map[string]string) + // RecordImageGeneration records metrics specific to image generation (request duration only). + RecordImageGeneration(ctx context.Context, requestHeaderLabelMapping map[string]string) +} + +// ImageGenerationMetricsFactory is a closure that creates a new ImageGenerationMetrics instance. +type ImageGenerationMetricsFactory func() ImageGenerationMetrics + +// NewImageGenerationFactory returns a closure to create a new ImageGenerationMetrics instance. +func NewImageGenerationFactory(meter metric.Meter, requestHeaderLabelMapping map[string]string) ImageGenerationMetricsFactory { + b := baseMetricsFactory{metrics: newGenAI(meter), requestHeaderAttributeMapping: requestHeaderLabelMapping} + return func() ImageGenerationMetrics { + return &imageGeneration{ + baseMetrics: b.newBaseMetrics(genaiOperationImageGeneration), + } + } +} + +// StartRequest initializes timing for a new request. +func (i *imageGeneration) StartRequest(headers map[string]string) { + i.baseMetrics.StartRequest(headers) +} + +// SetOriginalModel sets the original model from the incoming request body before any virtualization applies. +func (i *imageGeneration) SetOriginalModel(originalModel internalapi.OriginalModel) { + i.baseMetrics.SetOriginalModel(originalModel) +} + +// SetRequestModel sets the request model for the request. +func (i *imageGeneration) SetRequestModel(requestModel internalapi.RequestModel) { + i.baseMetrics.SetRequestModel(requestModel) +} + +// SetResponseModel sets the response model for the request. +func (i *imageGeneration) SetResponseModel(responseModel internalapi.ResponseModel) { + i.baseMetrics.SetResponseModel(responseModel) +} + +// RecordTokenUsage implements [ImageGenerationMetrics.RecordTokenUsage].
+func (i *imageGeneration) RecordTokenUsage(ctx context.Context, inputTokens, outputTokens uint32, requestHeaders map[string]string) { + attrs := i.buildBaseAttributes(requestHeaders) + + // For image generation, token usage is typically 0, but we still record it for consistency + i.metrics.tokenUsage.Record(ctx, float64(inputTokens), + metric.WithAttributeSet(attrs), + metric.WithAttributes(attribute.Key(genaiAttributeTokenType).String(genaiTokenTypeInput)), + ) + i.metrics.tokenUsage.Record(ctx, float64(outputTokens), + metric.WithAttributeSet(attrs), + metric.WithAttributes(attribute.Key(genaiAttributeTokenType).String(genaiTokenTypeOutput)), + ) + // Note: We don't record totalTokens separately as it causes double counting. + // The OTEL spec only defines "input" and "output" token types. +} + +// RecordImageGeneration implements [ImageGenerationMetrics.RecordImageGeneration]. +func (i *imageGeneration) RecordImageGeneration(ctx context.Context, requestHeaders map[string]string) { + attrs := i.buildBaseAttributes(requestHeaders) + // Record request duration with base attributes only for consistency with other operations/tests. + i.metrics.requestLatency.Record(ctx, time.Since(i.requestStart).Seconds(), metric.WithAttributeSet(attrs)) +} + +// GetTimeToGenerate returns the time taken to generate images. +func (i *imageGeneration) GetTimeToGenerate() time.Duration { + return time.Since(i.requestStart) +} diff --git a/internal/metrics/image_generation_metrics_test.go b/internal/metrics/image_generation_metrics_test.go new file mode 100644 index 0000000000..a4bc509fd8 --- /dev/null +++ b/internal/metrics/image_generation_metrics_test.go @@ -0,0 +1,119 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo.
+ +package metrics + +import ( + "testing" + "testing/synctest" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/sdk/metric" + + "github.com/envoyproxy/ai-gateway/internal/filterapi" +) + +func TestImageGeneration_RecordTokenUsage(t *testing.T) { + // Mirrors chat/embeddings token usage tests, but for image_generation. + var ( + mr = metric.NewManualReader() + meter = metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") + im = NewImageGenerationFactory(meter, nil)().(*imageGeneration) + + attrsBase = []attribute.KeyValue{ + attribute.Key(genaiAttributeOperationName).String(genaiOperationImageGeneration), + attribute.Key(genaiAttributeProviderName).String(genaiProviderOpenAI), + attribute.Key(genaiAttributeOriginalModel).String("test-model"), + attribute.Key(genaiAttributeRequestModel).String("test-model"), + attribute.Key(genaiAttributeResponseModel).String("test-model"), + } + inputAttrs = attribute.NewSet(append(attrsBase, attribute.Key(genaiAttributeTokenType).String(genaiTokenTypeInput))...) + outputAttrs = attribute.NewSet(append(attrsBase, attribute.Key(genaiAttributeTokenType).String(genaiTokenTypeOutput))...) + ) + + // Set labels and record usage. + im.SetOriginalModel("test-model") + im.SetRequestModel("test-model") + im.SetResponseModel("test-model") + im.SetBackend(&filterapi.Backend{Schema: filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI}}) + im.RecordTokenUsage(t.Context(), 3, 7, nil) + + count, sum := getHistogramValues(t, mr, genaiMetricClientTokenUsage, inputAttrs) + assert.Equal(t, uint64(1), count) + assert.Equal(t, 3.0, sum) + + count, sum = getHistogramValues(t, mr, genaiMetricClientTokenUsage, outputAttrs) + assert.Equal(t, uint64(1), count) + assert.Equal(t, 7.0, sum) +} + +func TestImageGeneration_RecordImageGeneration(t *testing.T) { + // Use synctest to keep time-based assertions deterministic. 
+ synctest.Test(t, func(t *testing.T) { + mr := metric.NewManualReader() + meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") + im := NewImageGenerationFactory(meter, nil)().(*imageGeneration) + + // Base attributes for request duration metric + baseAttrs := attribute.NewSet( + attribute.Key(genaiAttributeOperationName).String(genaiOperationImageGeneration), + attribute.Key(genaiAttributeProviderName).String(genaiProviderOpenAI), + attribute.Key(genaiAttributeOriginalModel).String("img-model"), + attribute.Key(genaiAttributeRequestModel).String("img-model"), + attribute.Key(genaiAttributeResponseModel).String("img-model"), + ) + + im.StartRequest(nil) + im.SetOriginalModel("img-model") + im.SetRequestModel("img-model") + im.SetResponseModel("img-model") + im.SetBackend(&filterapi.Backend{Schema: filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI}}) + + time.Sleep(10 * time.Millisecond) + im.RecordImageGeneration(t.Context(), nil) + + count, sum := getHistogramValues(t, mr, genaiMetricServerRequestDuration, baseAttrs) + assert.Equal(t, uint64(1), count) + assert.Equal(t, 10*time.Millisecond.Seconds(), sum) + }) +} + +func TestImageGeneration_HeaderLabelMapping(t *testing.T) { + // Verify header mapping is honored for token usage metrics. 
+ var ( + mr = metric.NewManualReader() + meter = metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") + headerMapping = map[string]string{"x-user-id": "user_id", "x-org-id": "org_id"} + im = NewImageGenerationFactory(meter, headerMapping)().(*imageGeneration) + ) + + requestHeaders := map[string]string{ + "x-user-id": "user123", + "x-org-id": "org456", + } + + im.SetOriginalModel("test-model") + im.SetRequestModel("test-model") + im.SetResponseModel("test-model") + im.SetBackend(&filterapi.Backend{Schema: filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI}}) + im.RecordTokenUsage(t.Context(), 5, 0, requestHeaders) + + attrs := attribute.NewSet( + attribute.Key(genaiAttributeOperationName).String(genaiOperationImageGeneration), + attribute.Key(genaiAttributeProviderName).String(genaiProviderOpenAI), + attribute.Key(genaiAttributeOriginalModel).String("test-model"), + attribute.Key(genaiAttributeRequestModel).String("test-model"), + attribute.Key(genaiAttributeResponseModel).String("test-model"), + attribute.Key(genaiAttributeTokenType).String(genaiTokenTypeInput), + attribute.Key("user_id").String("user123"), + attribute.Key("org_id").String("org456"), + ) + + count, _ := getHistogramValues(t, mr, genaiMetricClientTokenUsage, attrs) + require.Equal(t, uint64(1), count) +} diff --git a/internal/tracing/api/api.go b/internal/tracing/api/api.go index d623c33f62..ad4e3caa20 100644 --- a/internal/tracing/api/api.go +++ b/internal/tracing/api/api.go @@ -11,6 +11,7 @@ import ( "context" extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + openaisdk "github.com/openai/openai-go/v2" "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/trace" @@ -20,10 +21,12 @@ import ( var _ Tracing = NoopTracing{} // Tracing gives access to tracer types needed for endpoints such as OpenAI -// chat completions. +// chat completions, image generation, embeddings, and MCP requests. 
type Tracing interface { // ChatCompletionTracer creates spans for OpenAI chat completion requests on /chat/completions endpoint. ChatCompletionTracer() ChatCompletionTracer + // ImageGenerationTracer creates spans for OpenAI image generation requests. + ImageGenerationTracer() ImageGenerationTracer // CompletionTracer creates spans for OpenAI completion requests on /completions endpoint. CompletionTracer() CompletionTracer // EmbeddingsTracer creates spans for OpenAI embeddings requests on /embeddings endpoint. @@ -38,11 +41,12 @@ type Tracing interface { // // Implementations of the Tracing interface. type TracingConfig struct { - Tracer trace.Tracer - Propagator propagation.TextMapPropagator - ChatCompletionRecorder ChatCompletionRecorder - CompletionRecorder CompletionRecorder - EmbeddingsRecorder EmbeddingsRecorder + Tracer trace.Tracer + Propagator propagation.TextMapPropagator + ChatCompletionRecorder ChatCompletionRecorder + CompletionRecorder CompletionRecorder + ImageGenerationRecorder ImageGenerationRecorder + EmbeddingsRecorder EmbeddingsRecorder } // NoopTracing is a Tracing that doesn't do anything. @@ -67,6 +71,11 @@ func (NoopTracing) EmbeddingsTracer() EmbeddingsTracer { return NoopEmbeddingsTracer{} } +// ImageGenerationTracer implements Tracing.ImageGenerationTracer. +func (NoopTracing) ImageGenerationTracer() ImageGenerationTracer { + return NoopImageGenerationTracer{} +} + // Shutdown implements Tracing.Shutdown. func (NoopTracing) Shutdown(context.Context) error { return nil @@ -244,6 +253,61 @@ type EmbeddingsSpan interface { EndSpan() } +// ImageGenerationTracer creates spans for OpenAI image generation requests. +type ImageGenerationTracer interface { + // StartSpanAndInjectHeaders starts a span and injects trace context into + // the header mutation. + // + // Parameters: + // - ctx: might include a parent span context. + // - headers: Incoming HTTP headers used to extract parent trace context. 
+ // - headerMutation: The new LLM Span will have its context written to + // these headers unless NoopTracer is used. + // - req: The OpenAI image generation request. Used to record request attributes. + // + // Returns nil unless the span is sampled. + StartSpanAndInjectHeaders(ctx context.Context, headers map[string]string, headerMutation *extprocv3.HeaderMutation, req *openaisdk.ImageGenerateParams, body []byte) ImageGenerationSpan +} + +// ImageGenerationSpan represents an OpenAI image generation. +type ImageGenerationSpan interface { + // RecordResponse records the response attributes to the span. + RecordResponse(resp *openaisdk.ImagesResponse) + + // EndSpanOnError finalizes and ends the span with an error status. + EndSpanOnError(statusCode int, body []byte) + + // EndSpan finalizes and ends the span. + EndSpan() +} + +// ImageGenerationRecorder records attributes to a span according to a semantic +// convention. +type ImageGenerationRecorder interface { + // StartParams returns the name and options to start the span with. + // + // Parameters: + // - req: contains the image generation request + // - body: contains the complete request body. + // + // Note: Do not do any expensive data conversions as the span might not be + // sampled. + StartParams(req *openaisdk.ImageGenerateParams, body []byte) (spanName string, opts []trace.SpanStartOption) + + // RecordRequest records request attributes to the span. + // + // Parameters: + // - req: contains the image generation request + // - body: contains the complete request body. + RecordRequest(span trace.Span, req *openaisdk.ImageGenerateParams, body []byte) + + // RecordResponse records response attributes to the span. + RecordResponse(span trace.Span, resp *openaisdk.ImagesResponse) + + // RecordResponseOnError ends recording the span with an error status. 
+ RecordResponseOnError(span trace.Span, statusCode int, body []byte) +} + // EmbeddingsRecorder records attributes to a span according to a semantic // convention. type EmbeddingsRecorder interface { @@ -271,6 +335,14 @@ type EmbeddingsRecorder interface { RecordResponseOnError(span trace.Span, statusCode int, body []byte) } +// NoopImageGenerationTracer is a ImageGenerationTracer that doesn't do anything. +type NoopImageGenerationTracer struct{} + +// StartSpanAndInjectHeaders implements ImageGenerationTracer.StartSpanAndInjectHeaders. +func (NoopImageGenerationTracer) StartSpanAndInjectHeaders(context.Context, map[string]string, *extprocv3.HeaderMutation, *openaisdk.ImageGenerateParams, []byte) ImageGenerationSpan { + return nil +} + // NoopEmbeddingsTracer is an EmbeddingsTracer that doesn't do anything. type NoopEmbeddingsTracer struct{} diff --git a/internal/tracing/image_generation_span.go b/internal/tracing/image_generation_span.go new file mode 100644 index 0000000000..c0a45239b6 --- /dev/null +++ b/internal/tracing/image_generation_span.go @@ -0,0 +1,37 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package tracing + +import ( + openaisdk "github.com/openai/openai-go/v2" + "go.opentelemetry.io/otel/trace" + + tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" +) + +// Ensure imageGenerationSpan implements ImageGenerationSpan. +var _ tracing.ImageGenerationSpan = (*imageGenerationSpan)(nil) + +type imageGenerationSpan struct { + span trace.Span + recorder tracing.ImageGenerationRecorder +} + +// RecordResponse invokes [tracing.ImageGenerationRecorder.RecordResponse]. +func (s *imageGenerationSpan) RecordResponse(resp *openaisdk.ImagesResponse) { + s.recorder.RecordResponse(s.span, resp) +} + +// EndSpan invokes [tracing.ImageGenerationRecorder.RecordResponse] and ends the span. 
+func (s *imageGenerationSpan) EndSpan() { + s.span.End() +} + +// EndSpanOnError invokes [tracing.ImageGenerationRecorder.RecordResponseOnError] and ends the span. +func (s *imageGenerationSpan) EndSpanOnError(statusCode int, body []byte) { + s.recorder.RecordResponseOnError(s.span, statusCode, body) + s.span.End() +} diff --git a/internal/tracing/image_generation_span_test.go b/internal/tracing/image_generation_span_test.go new file mode 100644 index 0000000000..6ac77fc3f9 --- /dev/null +++ b/internal/tracing/image_generation_span_test.go @@ -0,0 +1,171 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package tracing + +import ( + "encoding/json" + "testing" + + openaisdk "github.com/openai/openai-go/v2" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/attribute" + oteltrace "go.opentelemetry.io/otel/trace" + + "github.com/envoyproxy/ai-gateway/internal/testing/testotel" +) + +// Test data for image generation span tests + +// Mock recorder for testing image generation span +type testImageGenerationRecorder struct{} + +func (r testImageGenerationRecorder) StartParams(_ *openaisdk.ImageGenerateParams, _ []byte) (string, []oteltrace.SpanStartOption) { + return "ImageGeneration", nil +} + +func (r testImageGenerationRecorder) RecordRequest(span oteltrace.Span, req *openaisdk.ImageGenerateParams, _ []byte) { + span.SetAttributes( + attribute.String("model", req.Model), + attribute.String("prompt", req.Prompt), + attribute.String("size", string(req.Size)), + ) +} + +func (r testImageGenerationRecorder) RecordResponse(span oteltrace.Span, resp *openaisdk.ImagesResponse) { + respBytes, _ := json.Marshal(resp) + span.SetAttributes( + attribute.Int("statusCode", 200), + attribute.Int("respBodyLen", len(respBytes)), + ) +} + +func (r testImageGenerationRecorder) RecordResponseOnError(span oteltrace.Span, statusCode 
int, body []byte) { + span.SetAttributes( + attribute.Int("statusCode", statusCode), + attribute.String("errorBody", string(body)), + ) +} + +func TestImageGenerationSpan_RecordResponse(t *testing.T) { + resp := &openaisdk.ImagesResponse{ + Data: []openaisdk.Image{{URL: "https://example.com/test.png"}}, + Size: openaisdk.ImagesResponseSize1024x1024, + Usage: openaisdk.ImagesResponseUsage{ + InputTokens: 5, + OutputTokens: 100, + TotalTokens: 105, + }, + } + respBytes, err := json.Marshal(resp) + require.NoError(t, err) + + s := &imageGenerationSpan{recorder: testImageGenerationRecorder{}} + actualSpan := testotel.RecordWithSpan(t, func(span oteltrace.Span) bool { + s.span = span + s.RecordResponse(resp) + return false // Recording response shouldn't end the span. + }) + + require.Equal(t, []attribute.KeyValue{ + attribute.Int("statusCode", 200), + attribute.Int("respBodyLen", len(respBytes)), + }, actualSpan.Attributes) +} + +func TestImageGenerationSpan_EndSpan(t *testing.T) { + s := &imageGenerationSpan{recorder: testImageGenerationRecorder{}} + actualSpan := testotel.RecordWithSpan(t, func(span oteltrace.Span) bool { + s.span = span + s.EndSpan() + return true // EndSpan ends the underlying span. + }) + + // EndSpan should not add any attributes, just end the span + require.Empty(t, actualSpan.Attributes) +} + +func TestImageGenerationSpan_EndSpanOnError(t *testing.T) { + errorMsg := "image generation failed" + actualSpan := testotel.RecordWithSpan(t, func(span oteltrace.Span) bool { + s := &imageGenerationSpan{span: span, recorder: testImageGenerationRecorder{}} + s.EndSpanOnError(500, []byte(errorMsg)) + return true // EndSpanOnError ends the underlying span. 
+ }) + + require.Equal(t, []attribute.KeyValue{ + attribute.Int("statusCode", 500), + attribute.String("errorBody", errorMsg), + }, actualSpan.Attributes) +} + +func TestImageGenerationSpan_RecordResponse_WithMultipleImages(t *testing.T) { + resp := &openaisdk.ImagesResponse{ + Data: []openaisdk.Image{ + {URL: "https://example.com/img1.png"}, + {URL: "https://example.com/img2.png"}, + {URL: "https://example.com/img3.png"}, + }, + Size: openaisdk.ImagesResponseSize1024x1024, + Usage: openaisdk.ImagesResponseUsage{ + InputTokens: 10, + OutputTokens: 200, + TotalTokens: 210, + }, + } + respBytes, err := json.Marshal(resp) + require.NoError(t, err) + + s := &imageGenerationSpan{recorder: testImageGenerationRecorder{}} + actualSpan := testotel.RecordWithSpan(t, func(span oteltrace.Span) bool { + s.span = span + s.RecordResponse(resp) + return false + }) + + require.Equal(t, []attribute.KeyValue{ + attribute.Int("statusCode", 200), + attribute.Int("respBodyLen", len(respBytes)), + }, actualSpan.Attributes) +} + +func TestImageGenerationSpan_EndSpanOnError_WithDifferentStatusCodes(t *testing.T) { + tests := []struct { + name string + statusCode int + errorBody string + }{ + { + name: "bad request", + statusCode: 400, + errorBody: `{"error":{"message":"Invalid prompt","type":"invalid_request_error"}}`, + }, + { + name: "rate limit", + statusCode: 429, + errorBody: `{"error":{"message":"Rate limit exceeded","type":"rate_limit_error"}}`, + }, + { + name: "server error", + statusCode: 500, + errorBody: `{"error":{"message":"Internal server error","type":"server_error"}}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actualSpan := testotel.RecordWithSpan(t, func(span oteltrace.Span) bool { + s := &imageGenerationSpan{span: span, recorder: testImageGenerationRecorder{}} + s.EndSpanOnError(tt.statusCode, []byte(tt.errorBody)) + return true + }) + + require.Equal(t, []attribute.KeyValue{ + attribute.Int("statusCode", tt.statusCode), + 
attribute.String("errorBody", tt.errorBody), + }, actualSpan.Attributes) + }) + } +} diff --git a/internal/tracing/image_generation_tracer.go b/internal/tracing/image_generation_tracer.go new file mode 100644 index 0000000000..ddb8d6c7de --- /dev/null +++ b/internal/tracing/image_generation_tracer.go @@ -0,0 +1,62 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package tracing + +import ( + "context" + + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + openaisdk "github.com/openai/openai-go/v2" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" + + tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" +) + +// Ensure imageGenerationTracer implements ImageGenerationTracer. +var _ tracing.ImageGenerationTracer = (*imageGenerationTracer)(nil) + +func newImageGenerationTracer(tracer trace.Tracer, propagator propagation.TextMapPropagator, recorder tracing.ImageGenerationRecorder) tracing.ImageGenerationTracer { + // Check if the tracer is a no-op by checking its type. + if _, ok := tracer.(noop.Tracer); ok { + return tracing.NoopImageGenerationTracer{} + } + return &imageGenerationTracer{ + tracer: tracer, + propagator: propagator, + recorder: recorder, + } +} + +type imageGenerationTracer struct { + tracer trace.Tracer + recorder tracing.ImageGenerationRecorder + propagator propagation.TextMapPropagator +} + +// StartSpanAndInjectHeaders implements ImageGenerationTracer.StartSpanAndInjectHeaders. +func (t *imageGenerationTracer) StartSpanAndInjectHeaders(ctx context.Context, headers map[string]string, mutableHeaders *extprocv3.HeaderMutation, req *openaisdk.ImageGenerateParams, body []byte) tracing.ImageGenerationSpan { + // Extract trace context from incoming headers. 
+ parentCtx := t.propagator.Extract(ctx, propagation.MapCarrier(headers)) + + // Start the span with options appropriate for the semantic convention. + spanName, opts := t.recorder.StartParams(req, body) + newCtx, span := t.tracer.Start(parentCtx, spanName, opts...) + + // Always inject trace context into the header mutation if provided. + // This ensures trace propagation works even for unsampled spans. + t.propagator.Inject(newCtx, &headerMutationCarrier{m: mutableHeaders}) + + // Only record request attributes if span is recording (sampled). + // This avoids expensive body processing for unsampled spans. + if span.IsRecording() { + t.recorder.RecordRequest(span, req, body) + return &imageGenerationSpan{span: span, recorder: t.recorder} + } + + return nil +} diff --git a/internal/tracing/image_generation_tracer_test.go b/internal/tracing/image_generation_tracer_test.go new file mode 100644 index 0000000000..1a5607d9d2 --- /dev/null +++ b/internal/tracing/image_generation_tracer_test.go @@ -0,0 +1,400 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package tracing + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + openaisdk "github.com/openai/openai-go/v2" + openaiparam "github.com/openai/openai-go/v2/packages/param" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/contrib/propagators/autoprop" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + oteltrace "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" + + tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" +) + +var ( + imageGenStartOpts = []oteltrace.SpanStartOption{oteltrace.WithSpanKind(oteltrace.SpanKindServer)} + + imageGenReq = &openaisdk.ImageGenerateParams{ + Model: openaisdk.ImageModelGPTImage1, + Prompt: "a beautiful sunset over mountains", + Size: openaisdk.ImageGenerateParamsSize1024x1024, + Quality: openaisdk.ImageGenerateParamsQualityHigh, + ResponseFormat: openaisdk.ImageGenerateParamsResponseFormatURL, + N: openaiparam.NewOpt[int64](1), + } +) + +func TestImageGenerationTracer_StartSpanAndInjectHeaders(t *testing.T) { + respBody := &openaisdk.ImagesResponse{ + Data: []openaisdk.Image{ + {URL: "https://example.com/generated-image.png"}, + }, + Size: openaisdk.ImagesResponseSize1024x1024, + Usage: openaisdk.ImagesResponseUsage{ + InputTokens: 8, + OutputTokens: 1056, + TotalTokens: 1064, + }, + } + respBodyBytes, err := json.Marshal(respBody) + require.NoError(t, err) + bodyLen := len(respBodyBytes) + + reqBody, err := json.Marshal(imageGenReq) + require.NoError(t, err) + reqBodyLen := len(reqBody) + + tests := []struct { + name string + req *openaisdk.ImageGenerateParams + existingHeaders map[string]string + expectedSpanName string + expectedAttrs []attribute.KeyValue + expectedTraceID string + }{ + { + name: "basic image generation request", + req:
imageGenReq, + existingHeaders: map[string]string{}, + expectedSpanName: "ImageGeneration", + expectedAttrs: []attribute.KeyValue{ + attribute.String("model", imageGenReq.Model), + attribute.String("prompt", imageGenReq.Prompt), + attribute.String("size", string(imageGenReq.Size)), + attribute.String("quality", string(imageGenReq.Quality)), + attribute.String("response_format", string(imageGenReq.ResponseFormat)), + attribute.String("n", "1"), + attribute.Int("reqBodyLen", reqBodyLen), + attribute.Int("statusCode", 200), + attribute.Int("respBodyLen", bodyLen), + }, + }, + { + name: "with existing trace context", + req: imageGenReq, + existingHeaders: map[string]string{ + "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", + }, + expectedSpanName: "ImageGeneration", + expectedAttrs: []attribute.KeyValue{ + attribute.String("model", imageGenReq.Model), + attribute.String("prompt", imageGenReq.Prompt), + attribute.String("size", string(imageGenReq.Size)), + attribute.String("quality", string(imageGenReq.Quality)), + attribute.String("response_format", string(imageGenReq.ResponseFormat)), + attribute.String("n", "1"), + attribute.Int("reqBodyLen", reqBodyLen), + attribute.Int("statusCode", 200), + attribute.Int("respBodyLen", bodyLen), + }, + expectedTraceID: "4bf92f3577b34da6a3ce929d0e0e4736", + }, + { + name: "multiple images request", + req: &openaisdk.ImageGenerateParams{ + Model: openaisdk.ImageModelGPTImage1, + Prompt: "a cat and a dog", + Size: openaisdk.ImageGenerateParamsSize512x512, + Quality: openaisdk.ImageGenerateParamsQualityStandard, + ResponseFormat: openaisdk.ImageGenerateParamsResponseFormatB64JSON, + N: openaiparam.NewOpt[int64](2), + }, + existingHeaders: map[string]string{}, + expectedSpanName: "ImageGeneration", + expectedAttrs: []attribute.KeyValue{ + attribute.String("model", openaisdk.ImageModelGPTImage1), + attribute.String("prompt", "a cat and a dog"), + attribute.String("size", 
string(openaisdk.ImageGenerateParamsSize512x512)), + attribute.String("quality", string(openaisdk.ImageGenerateParamsQualityStandard)), + attribute.String("response_format", string(openaisdk.ImageGenerateParamsResponseFormatB64JSON)), + attribute.String("n", "2"), + attribute.Int("reqBodyLen", 0), // Will be calculated in test + attribute.Int("statusCode", 200), + attribute.Int("respBodyLen", bodyLen), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + exporter := tracetest.NewInMemoryExporter() + tp := trace.NewTracerProvider(trace.WithSyncer(exporter)) + + tracer := newImageGenerationTracer(tp.Tracer("test"), autoprop.NewTextMapPropagator(), testImageGenTracerRecorder{}) + + headerMutation := &extprocv3.HeaderMutation{} + reqBody, err := json.Marshal(tt.req) + require.NoError(t, err) + + // Update expected attributes with actual request body length + expectedAttrs := make([]attribute.KeyValue, len(tt.expectedAttrs)) + copy(expectedAttrs, tt.expectedAttrs) + for i, attr := range expectedAttrs { + if attr.Key == "reqBodyLen" { + expectedAttrs[i] = attribute.Int("reqBodyLen", len(reqBody)) + break + } + } + + span := tracer.StartSpanAndInjectHeaders(t.Context(), + tt.existingHeaders, + headerMutation, + tt.req, + reqBody, + ) + require.IsType(t, &imageGenerationSpan{}, span) + + // End the span to export it. + span.RecordResponse(respBody) + span.EndSpan() + + spans := exporter.GetSpans() + require.Len(t, spans, 1) + actualSpan := spans[0] + + // Check span state. + require.Equal(t, tt.expectedSpanName, actualSpan.Name) + require.Equal(t, expectedAttrs, actualSpan.Attributes) + require.Empty(t, actualSpan.Events) + + // Check header mutation. 
+ traceID := actualSpan.SpanContext.TraceID().String() + if tt.expectedTraceID != "" { + require.Equal(t, tt.expectedTraceID, actualSpan.SpanContext.TraceID().String()) + } + spanID := actualSpan.SpanContext.SpanID().String() + require.Equal(t, &extprocv3.HeaderMutation{ + SetHeaders: []*corev3.HeaderValueOption{ + { + Header: &corev3.HeaderValue{ + Key: "traceparent", + RawValue: []byte("00-" + traceID + "-" + spanID + "-01"), + }, + }, + }, + }, headerMutation) + }) + } +} + +func TestNewImageGenerationTracer_Noop(t *testing.T) { + // Use noop tracer. + noopTracer := noop.Tracer{} + + tracer := newImageGenerationTracer(noopTracer, autoprop.NewTextMapPropagator(), testImageGenTracerRecorder{}) + + // Verify it returns NoopTracer. + require.IsType(t, tracing.NoopImageGenerationTracer{}, tracer) + + // Test that noop tracer doesn't create spans. + headers := map[string]string{} + headerMutation := &extprocv3.HeaderMutation{} + req := &openaisdk.ImageGenerateParams{ + Model: openaisdk.ImageModelGPTImage1, + Prompt: "test prompt", + } + + span := tracer.StartSpanAndInjectHeaders(t.Context(), + headers, + headerMutation, + req, + []byte("{}"), + ) + + require.Nil(t, span) + + // Verify no headers were injected. + require.Empty(t, headerMutation.SetHeaders) +} + +func TestImageGenerationTracer_UnsampledSpan(t *testing.T) { + // Use always_off sampler to ensure spans are not sampled. + tracerProvider := trace.NewTracerProvider( + trace.WithSampler(trace.NeverSample()), + ) + t.Cleanup(func() { _ = tracerProvider.Shutdown(context.Background()) }) + + tracer := newImageGenerationTracer(tracerProvider.Tracer("test"), autoprop.NewTextMapPropagator(), testImageGenTracerRecorder{}) + + // Start a span that won't be sampled. 
+ headers := map[string]string{} + headerMutation := &extprocv3.HeaderMutation{} + req := &openaisdk.ImageGenerateParams{ + Model: openaisdk.ImageModelGPTImage1, + Prompt: "test prompt", + } + + span := tracer.StartSpanAndInjectHeaders(t.Context(), + headers, + headerMutation, + req, + []byte("{}"), + ) + + // Span should be nil when not sampled. + require.Nil(t, span) + + // Headers should still be injected for trace propagation. + require.NotEmpty(t, headerMutation.SetHeaders) +} + +func TestImageGenerationTracer_ErrorHandling(t *testing.T) { + exporter := tracetest.NewInMemoryExporter() + tp := trace.NewTracerProvider(trace.WithSyncer(exporter)) + + tracer := newImageGenerationTracer(tp.Tracer("test"), autoprop.NewTextMapPropagator(), testImageGenTracerRecorder{}) + + headerMutation := &extprocv3.HeaderMutation{} + reqBody, err := json.Marshal(imageGenReq) + require.NoError(t, err) + + span := tracer.StartSpanAndInjectHeaders(t.Context(), + map[string]string{}, + headerMutation, + imageGenReq, + reqBody, + ) + require.IsType(t, &imageGenerationSpan{}, span) + + // Test error handling + errorBody := []byte(`{"error":{"message":"Invalid request","type":"invalid_request_error"}}`) + span.EndSpanOnError(400, errorBody) + + spans := exporter.GetSpans() + require.Len(t, spans, 1) + actualSpan := spans[0] + + // Check that error attributes are recorded + expectedAttrs := []attribute.KeyValue{ + attribute.String("model", imageGenReq.Model), + attribute.String("prompt", imageGenReq.Prompt), + attribute.String("size", string(imageGenReq.Size)), + attribute.String("quality", string(imageGenReq.Quality)), + attribute.String("response_format", string(imageGenReq.ResponseFormat)), + attribute.String("n", "1"), + attribute.Int("reqBodyLen", len(reqBody)), + attribute.Int("statusCode", 400), + attribute.String("errorBody", string(errorBody)), + } + + require.Equal(t, expectedAttrs, actualSpan.Attributes) +} + +func TestImageGenerationTracer_MultipleImagesResponse(t *testing.T) 
{ + exporter := tracetest.NewInMemoryExporter() + tp := trace.NewTracerProvider(trace.WithSyncer(exporter)) + + tracer := newImageGenerationTracer(tp.Tracer("test"), autoprop.NewTextMapPropagator(), testImageGenTracerRecorder{}) + + headerMutation := &extprocv3.HeaderMutation{} + reqBody, err := json.Marshal(imageGenReq) + require.NoError(t, err) + + span := tracer.StartSpanAndInjectHeaders(t.Context(), + map[string]string{}, + headerMutation, + imageGenReq, + reqBody, + ) + require.IsType(t, &imageGenerationSpan{}, span) + + // Test with multiple images response + multiImageResp := &openaisdk.ImagesResponse{ + Data: []openaisdk.Image{ + {URL: "https://example.com/img1.png"}, + {URL: "https://example.com/img2.png"}, + {URL: "https://example.com/img3.png"}, + }, + Size: openaisdk.ImagesResponseSize1024x1024, + Usage: openaisdk.ImagesResponseUsage{ + InputTokens: 10, + OutputTokens: 200, + TotalTokens: 210, + }, + } + + span.RecordResponse(multiImageResp) + span.EndSpan() + + spans := exporter.GetSpans() + require.Len(t, spans, 1) + actualSpan := spans[0] + + // Check that image count is recorded correctly + expectedAttrs := []attribute.KeyValue{ + attribute.String("model", imageGenReq.Model), + attribute.String("prompt", imageGenReq.Prompt), + attribute.String("size", string(imageGenReq.Size)), + attribute.String("quality", string(imageGenReq.Quality)), + attribute.String("response_format", string(imageGenReq.ResponseFormat)), + attribute.String("n", "1"), + attribute.Int("reqBodyLen", len(reqBody)), + attribute.Int("statusCode", 200), + attribute.Int("respBodyLen", 0), // Will be calculated + + } + + // Update respBodyLen with actual value + for i, attr := range expectedAttrs { + if attr.Key == "respBodyLen" { + respBytes, _ := json.Marshal(multiImageResp) + expectedAttrs[i] = attribute.Int("respBodyLen", len(respBytes)) + break + } + } + + require.Equal(t, expectedAttrs, actualSpan.Attributes) +} + +var _ tracing.ImageGenerationRecorder = 
testImageGenTracerRecorder{} + +type testImageGenTracerRecorder struct{} + +func (r testImageGenTracerRecorder) StartParams(_ *openaisdk.ImageGenerateParams, _ []byte) (spanName string, opts []oteltrace.SpanStartOption) { + return "ImageGeneration", imageGenStartOpts +} + +func (r testImageGenTracerRecorder) RecordRequest(span oteltrace.Span, req *openaisdk.ImageGenerateParams, body []byte) { + n := int64(1) + if req.N.Valid() { + n = req.N.Value + } + span.SetAttributes( + attribute.String("model", req.Model), + attribute.String("prompt", req.Prompt), + attribute.String("size", string(req.Size)), + attribute.String("quality", string(req.Quality)), + attribute.String("response_format", string(req.ResponseFormat)), + attribute.String("n", fmt.Sprintf("%d", n)), + attribute.Int("reqBodyLen", len(body)), + ) +} + +func (r testImageGenTracerRecorder) RecordResponse(span oteltrace.Span, resp *openaisdk.ImagesResponse) { + span.SetAttributes(attribute.Int("statusCode", 200)) + body, err := json.Marshal(resp) + if err != nil { + panic(err) + } + span.SetAttributes(attribute.Int("respBodyLen", len(body))) +} + +func (r testImageGenTracerRecorder) RecordResponseOnError(span oteltrace.Span, statusCode int, body []byte) { + span.SetAttributes(attribute.Int("statusCode", statusCode)) + span.SetAttributes(attribute.String("errorBody", string(body))) +} diff --git a/internal/tracing/openinference/openai/image_generation.go b/internal/tracing/openinference/openai/image_generation.go new file mode 100644 index 0000000000..92b03582e2 --- /dev/null +++ b/internal/tracing/openinference/openai/image_generation.go @@ -0,0 +1,113 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +// Package openai provides OpenInference semantic conventions hooks for +// OpenAI instrumentation used by the ExtProc router filter. 
+package openai
+
+import (
+	"encoding/json"
+
+	openaisdk "github.com/openai/openai-go/v2"
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/codes"
+	"go.opentelemetry.io/otel/trace"
+
+	tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api"
+	"github.com/envoyproxy/ai-gateway/internal/tracing/openinference"
+)
+
+// ImageGenerationRecorder implements recorders for OpenInference image generation spans.
+type ImageGenerationRecorder struct {
+	// traceConfig controls redaction of request and response payloads.
+	traceConfig *openinference.TraceConfig
+}
+
+// NewImageGenerationRecorderFromEnv creates an api.ImageGenerationRecorder
+// from environment variables using the OpenInference configuration specification.
+//
+// See: https://github.com/Arize-ai/openinference/blob/main/spec/configuration.md
+func NewImageGenerationRecorderFromEnv() tracing.ImageGenerationRecorder {
+	return NewImageGenerationRecorder(nil)
+}
+
+// NewImageGenerationRecorder creates a tracing.ImageGenerationRecorder with the
+// given config using the OpenInference configuration specification.
+//
+// Parameters:
+//   - config: configuration for redaction. Defaults to NewTraceConfigFromEnv().
+//
+// See: https://github.com/Arize-ai/openinference/blob/main/spec/configuration.md
+func NewImageGenerationRecorder(config *openinference.TraceConfig) tracing.ImageGenerationRecorder {
+	if config == nil {
+		config = openinference.NewTraceConfigFromEnv()
+	}
+	return &ImageGenerationRecorder{traceConfig: config}
+}
+
+// imageGenStartOpts sets trace.SpanKindInternal as that's the span kind used
+// in OpenInference.
+var imageGenStartOpts = []trace.SpanStartOption{trace.WithSpanKind(trace.SpanKindInternal)}
+
+// StartParams implements the same method as defined in tracing.ImageGenerationRecorder.
+func (r *ImageGenerationRecorder) StartParams(*openaisdk.ImageGenerateParams, []byte) (spanName string, opts []trace.SpanStartOption) {
+	return "ImageGeneration", imageGenStartOpts
+}
+
+// RecordRequest implements the same method as defined in tracing.ImageGenerationRecorder.
+func (r *ImageGenerationRecorder) RecordRequest(span trace.Span, req *openaisdk.ImageGenerateParams, body []byte) {
+	span.SetAttributes(buildImageGenerationRequestAttributes(req, string(body), r.traceConfig)...)
+}
+
+// RecordResponse implements the same method as defined in tracing.ImageGenerationRecorder.
+func (r *ImageGenerationRecorder) RecordResponse(span trace.Span, resp *openaisdk.ImagesResponse) {
+	attrs := buildImageGenerationResponseAttributes(resp, r.traceConfig)
+
+	// Redact the output unless configured otherwise. On a marshal failure the
+	// redacted placeholder is kept rather than emitting an empty value.
+	bodyString := openinference.RedactedValue
+	if !r.traceConfig.HideOutputs {
+		if marshaled, err := json.Marshal(resp); err == nil {
+			bodyString = string(marshaled)
+		}
+	}
+	// Match the ChatCompletion recorder: always include the output MIME type
+	// and value attributes.
+	attrs = append(attrs,
+		attribute.String(openinference.OutputMimeType, openinference.MimeTypeJSON),
+		attribute.String(openinference.OutputValue, bodyString),
+	)
+	span.SetAttributes(attrs...)
+	span.SetStatus(codes.Ok, "")
+}
+
+// RecordResponseOnError implements the same method as defined in tracing.ImageGenerationRecorder.
+func (r *ImageGenerationRecorder) RecordResponseOnError(span trace.Span, statusCode int, body []byte) {
+	recordResponseError(span, statusCode, string(body))
+}
+
+// buildImageGenerationRequestAttributes builds OpenInference attributes from the image generation request.
+func buildImageGenerationRequestAttributes(req *openaisdk.ImageGenerateParams, body string, config *openinference.TraceConfig) []attribute.KeyValue { + attrs := []attribute.KeyValue{ + attribute.String(openinference.SpanKind, openinference.SpanKindLLM), + attribute.String(openinference.LLMSystem, openinference.LLMSystemOpenAI), + attribute.String(openinference.LLMModelName, req.Model), + } + + if config.HideInputs { + attrs = append(attrs, attribute.String(openinference.InputValue, openinference.RedactedValue)) + } else { + attrs = append(attrs, attribute.String(openinference.InputValue, body)) + attrs = append(attrs, attribute.String(openinference.InputMimeType, openinference.MimeTypeJSON)) + } + + return attrs +} + +// buildImageGenerationResponseAttributes builds OpenInference attributes from the image generation response. +func buildImageGenerationResponseAttributes(_ *openaisdk.ImagesResponse, _ *openinference.TraceConfig) []attribute.KeyValue { + attrs := []attribute.KeyValue{} + + // No image-specific response attributes + + return attrs +} diff --git a/internal/tracing/openinference/openai/image_generation_config_test.go b/internal/tracing/openinference/openai/image_generation_config_test.go new file mode 100644 index 0000000000..b3e7f5ab73 --- /dev/null +++ b/internal/tracing/openinference/openai/image_generation_config_test.go @@ -0,0 +1,132 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+
+package openai
+
+import (
+	"encoding/json"
+	"testing"
+
+	openaisdk "github.com/openai/openai-go/v2"
+	"github.com/stretchr/testify/require"
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/codes"
+	"go.opentelemetry.io/otel/sdk/trace"
+	oteltrace "go.opentelemetry.io/otel/trace"
+
+	"github.com/envoyproxy/ai-gateway/internal/testing/testotel"
+	"github.com/envoyproxy/ai-gateway/internal/tracing/openinference"
+)
+
+func TestImageGenerationRecorder_WithConfig_HideInputs(t *testing.T) {
+	req := basicImageReq
+	reqBody := basicImageReqBody
+
+	tests := []struct {
+		name          string
+		config        *openinference.TraceConfig
+		expectedAttrs []attribute.KeyValue
+	}{
+		{
+			name: "hide input value",
+			config: &openinference.TraceConfig{
+				HideInputs: true,
+			},
+			expectedAttrs: []attribute.KeyValue{
+				attribute.String(openinference.SpanKind, openinference.SpanKindLLM),
+				attribute.String(openinference.LLMSystem, openinference.LLMSystemOpenAI),
+				attribute.String(openinference.LLMModelName, req.Model),
+				attribute.String(openinference.InputValue, openinference.RedactedValue),
+				// No InputMimeType when input is hidden.
+			},
+		},
+		{
+			name:   "show input value by default",
+			config: &openinference.TraceConfig{},
+			expectedAttrs: []attribute.KeyValue{
+				attribute.String(openinference.SpanKind, openinference.SpanKindLLM),
+				attribute.String(openinference.LLMSystem, openinference.LLMSystemOpenAI),
+				attribute.String(openinference.LLMModelName, req.Model),
+				attribute.String(openinference.InputValue, string(reqBody)),
+				attribute.String(openinference.InputMimeType, openinference.MimeTypeJSON),
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			recorder := NewImageGenerationRecorder(tt.config)
+
+			actualSpan := testotel.RecordWithSpan(t, func(span oteltrace.Span) bool {
+				recorder.RecordRequest(span, req, reqBody)
+				return false
+			})
+
+			// Assert the exact attribute set (and order) per config so the
+			// expected attributes in the table stay authoritative.
+			require.Equal(t, tt.expectedAttrs, actualSpan.Attributes)
+		})
+	}
+}
+
+func TestImageGenerationRecorder_WithConfig_HideOutputs(t *testing.T) {
+	resp := &openaisdk.ImagesResponse{Data: []openaisdk.Image{{URL: "https://example.com/img.png"}}}
+	respBody, err := json.Marshal(resp)
+	require.NoError(t, err)
+
+	tests := []struct {
+		name           string
+		config         *openinference.TraceConfig
+		expectedStatus trace.Status
+	}{
+		{
+			name: "hide output value",
+			config: &openinference.TraceConfig{
+				HideOutputs: true,
+			},
+			expectedStatus: trace.Status{Code: codes.Ok, Description: ""},
+		},
+		{
+			name:           "show output value",
+			config:         &openinference.TraceConfig{},
+			expectedStatus: trace.Status{Code: codes.Ok, Description: ""},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			recorder := NewImageGenerationRecorder(tt.config)
+
+			actualSpan := testotel.RecordWithSpan(t, func(span oteltrace.Span) bool {
+				recorder.RecordResponse(span, resp)
+				return false
+			})
+
+			attrs := attributesToMap(actualSpan.Attributes)
+			// The output MIME type is recorded regardless of redaction.
+			require.Equal(t, openinference.MimeTypeJSON, attrs[openinference.OutputMimeType])
+			if tt.config.HideOutputs {
+				require.Equal(t, openinference.RedactedValue, attrs[openinference.OutputValue])
+			} else {
+				require.Equal(t, string(respBody), attrs[openinference.OutputValue])
+			}
+			require.Equal(t, tt.expectedStatus, actualSpan.Status)
+		})
+	}
+}
diff --git a/internal/tracing/openinference/openai/image_generation_test.go b/internal/tracing/openinference/openai/image_generation_test.go
new file mode 100644
index 0000000000..5b5b4331ad
--- /dev/null
+++ b/internal/tracing/openinference/openai/image_generation_test.go
@@ -0,0 +1,188 @@
+// Copyright Envoy AI Gateway Authors
+// SPDX-License-Identifier: Apache-2.0
+// The full text of the Apache license is available in the LICENSE file at
+// the root of the repo.
+
+package openai
+
+import (
+	"encoding/json"
+	"strconv"
+	"testing"
+
+	openaisdk "github.com/openai/openai-go/v2"
+	openaiparam "github.com/openai/openai-go/v2/packages/param"
+	"github.com/stretchr/testify/require"
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/codes"
+	"go.opentelemetry.io/otel/sdk/trace"
+	oteltrace "go.opentelemetry.io/otel/trace"
+
+	"github.com/envoyproxy/ai-gateway/internal/testing/testotel"
+	tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api"
+	"github.com/envoyproxy/ai-gateway/internal/tracing/openinference"
+)
+
+var (
+	// Test data constants following the chat completion test pattern.
+	basicImageReq = &openaisdk.ImageGenerateParams{
+		Model:          openaisdk.ImageModelGPTImage1,
+		Prompt:         "a hummingbird",
+		Size:           openaisdk.ImageGenerateParamsSize1024x1024,
+		Quality:        openaisdk.ImageGenerateParamsQualityHigh,
+		ResponseFormat: openaisdk.ImageGenerateParamsResponseFormatB64JSON,
+		N:              openaiparam.NewOpt[int64](1),
+	}
+	basicImageReqBody = mustJSON(basicImageReq)
+
+	basicImageResp = &openaisdk.ImagesResponse{
+		Data: []openaisdk.Image{{URL: "https://example.com/img.png"}},
+		Size: openaisdk.ImagesResponseSize1024x1024,
+		Usage: openaisdk.ImagesResponseUsage{
+			InputTokens:  8,
+			OutputTokens: 1056,
+			TotalTokens:  1064,
+		},
+	}
+	basicImageRespBody = mustJSON(basicImageResp)
+)
+
+func TestImageGenerationRecorder_StartParams(t *testing.T) {
+	tests := []struct {
+		name             string
+		req              *openaisdk.ImageGenerateParams
+		reqBody          []byte
+		expectedSpanName string
+	}{
+		{
+			name:             "basic request",
+			req:              basicImageReq,
+			reqBody:          basicImageReqBody,
+			expectedSpanName: "ImageGeneration",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			recorder := NewImageGenerationRecorder(nil)
+
+			spanName, opts := recorder.StartParams(tt.req, tt.reqBody)
+			actualSpan := testotel.RecordNewSpan(t, spanName, opts...)
+
+			require.Equal(t, tt.expectedSpanName, actualSpan.Name)
+			// OpenInference spans are internal spans.
+			require.Equal(t, oteltrace.SpanKindInternal, actualSpan.SpanKind)
+		})
+	}
+}
+
+func TestImageGenerationRecorder_RecordRequest(t *testing.T) {
+	tests := []struct {
+		name          string
+		req           *openaisdk.ImageGenerateParams
+		reqBody       []byte
+		expectedAttrs []attribute.KeyValue
+	}{
+		{
+			name:    "basic request",
+			req:     basicImageReq,
+			reqBody: basicImageReqBody,
+			expectedAttrs: []attribute.KeyValue{
+				attribute.String(openinference.SpanKind, openinference.SpanKindLLM),
+				attribute.String(openinference.LLMSystem, openinference.LLMSystemOpenAI),
+				attribute.String(openinference.LLMModelName, openaisdk.ImageModelGPTImage1),
+				attribute.String(openinference.InputValue, string(basicImageReqBody)),
+				attribute.String(openinference.InputMimeType, openinference.MimeTypeJSON),
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			recorder := NewImageGenerationRecorder(nil)
+
+			actualSpan := testotel.RecordWithSpan(t, func(span oteltrace.Span) bool {
+				recorder.RecordRequest(span, tt.req, tt.reqBody)
+				return false
+			})
+
+			// Assert the exact attribute set (and order) so the expected
+			// attributes in the table stay authoritative.
+			require.Equal(t, tt.expectedAttrs, actualSpan.Attributes)
+		})
+	}
+}
+
+func TestImageGenerationRecorder_RecordResponse(t *testing.T) {
+	tests := []struct {
+		name           string
+		respBody       []byte
+		expectedAttrs  []attribute.KeyValue
+		expectedStatus trace.Status
+	}{
+		{
+			name:     "successful response",
+			respBody: basicImageRespBody,
+			// RecordResponse emits only the generic output attributes; there
+			// are no image-specific response attributes yet.
+			expectedAttrs: []attribute.KeyValue{
+				attribute.String(openinference.OutputMimeType, openinference.MimeTypeJSON),
+				attribute.String(openinference.OutputValue, string(basicImageRespBody)),
+			},
+			expectedStatus: trace.Status{Code: codes.Ok, Description: ""},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			recorder := NewImageGenerationRecorder(nil)
+
+			resp := &openaisdk.ImagesResponse{}
+			require.NoError(t, json.Unmarshal(tt.respBody, resp))
+
+			actualSpan := testotel.RecordWithSpan(t, func(span oteltrace.Span) bool {
+				recorder.RecordResponse(span, resp)
+				return false
+			})
+
+			require.Equal(t, tt.expectedAttrs, actualSpan.Attributes)
+			require.Equal(t, tt.expectedStatus, actualSpan.Status)
+		})
+	}
+}
+
+func TestImageGenerationRecorder_RecordResponseOnError(t *testing.T) {
+	recorder := NewImageGenerationRecorder(nil)
+
+	actualSpan := testotel.RecordWithSpan(t, func(span oteltrace.Span) bool {
+		recorder.RecordResponseOnError(span, 400, []byte(`{"error":{"message":"Invalid request","type":"invalid_request_error"}}`))
+		return false
+	})
+
+	require.Equal(t, trace.Status{
+		Code:        codes.Error,
+		Description: `Error code: 400 - {"error":{"message":"Invalid request","type":"invalid_request_error"}}`,
+	}, actualSpan.Status)
+}
+
+// attributesToMap converts attribute KeyValue to a simple map for assertions.
+func attributesToMap(kvs []attribute.KeyValue) map[string]string { + m := make(map[string]string, len(kvs)) + for _, kv := range kvs { + switch kv.Value.Type() { + case attribute.STRING: + m[string(kv.Key)] = kv.Value.AsString() + case attribute.BOOL: + m[string(kv.Key)] = strconv.FormatBool(kv.Value.AsBool()) + case attribute.INT64: + m[string(kv.Key)] = strconv.FormatInt(kv.Value.AsInt64(), 10) + case attribute.FLOAT64: + m[string(kv.Key)] = strconv.FormatFloat(kv.Value.AsFloat64(), 'f', -1, 64) + default: + m[string(kv.Key)] = kv.Value.AsString() + } + } + return m +} + +var _ tracing.ImageGenerationRecorder = (*ImageGenerationRecorder)(nil) diff --git a/internal/tracing/tracing.go b/internal/tracing/tracing.go index c2810ff563..3e2b839c3f 100644 --- a/internal/tracing/tracing.go +++ b/internal/tracing/tracing.go @@ -25,10 +25,11 @@ import ( var _ tracing.Tracing = (*tracingImpl)(nil) type tracingImpl struct { - chatCompletionTracer tracing.ChatCompletionTracer - completionTracer tracing.CompletionTracer - embeddingsTracer tracing.EmbeddingsTracer - mcpTracer tracing.MCPTracer + chatCompletionTracer tracing.ChatCompletionTracer + completionTracer tracing.CompletionTracer + imageGenerationTracer tracing.ImageGenerationTracer + embeddingsTracer tracing.EmbeddingsTracer + mcpTracer tracing.MCPTracer // shutdown is nil when we didn't create tp. shutdown func(context.Context) error } @@ -48,6 +49,11 @@ func (t *tracingImpl) EmbeddingsTracer() tracing.EmbeddingsTracer { return t.embeddingsTracer } +// ImageGenerationTracer implements the same method as documented on api.Tracing. +func (t *tracingImpl) ImageGenerationTracer() tracing.ImageGenerationTracer { + return t.imageGenerationTracer +} + func (t *tracingImpl) MCPTracer() tracing.MCPTracer { return t.mcpTracer } @@ -156,6 +162,7 @@ func NewTracingFromEnv(ctx context.Context, stdout io.Writer, headerAttributeMap // Default to OpenInference trace span semantic conventions. 
chatRecorder := openai.NewChatCompletionRecorderFromEnv() + imageRecorder := openai.NewImageGenerationRecorderFromEnv() completionRecorder := openai.NewCompletionRecorderFromEnv() embeddingsRecorder := openai.NewEmbeddingsRecorderFromEnv() @@ -167,6 +174,11 @@ func NewTracingFromEnv(ctx context.Context, stdout io.Writer, headerAttributeMap chatRecorder, headerAttrs, ), + imageGenerationTracer: newImageGenerationTracer( + tracer, + propagator, + imageRecorder, + ), completionTracer: newCompletionTracer( tracer, propagator, diff --git a/internal/tracing/tracing_test.go b/internal/tracing/tracing_test.go index e90c813adf..769cb4e129 100644 --- a/internal/tracing/tracing_test.go +++ b/internal/tracing/tracing_test.go @@ -553,6 +553,45 @@ func TestNewTracingFromEnv_OTLPHeaders(t *testing.T) { require.Equal(t, expectedAuthorization, <-actualAuthorization) } +// TestNewTracingFromEnv_HeaderAttributeMapping verifies that headerAttributeMapping +// passed to NewTracingFromEnv is applied by tracers to set span attributes. 
+func TestNewTracingFromEnv_HeaderAttributeMapping(t *testing.T) { + collector := testotel.StartOTLPCollector() + t.Cleanup(collector.Close) + collector.SetEnv(t.Setenv) + + mapping := map[string]string{ + "x-session-id": "session.id", + "x-user-id": "user.id", + } + + result, err := NewTracingFromEnv(t.Context(), io.Discard, mapping) + require.NoError(t, err) + t.Cleanup(func() { _ = result.Shutdown(context.Background()) }) + + headers := map[string]string{ + "x-session-id": "abc123", + "x-user-id": "user456", + } + headerMutation := &extprocv3.HeaderMutation{} + + tr := result.ChatCompletionTracer() + req := &openai.ChatCompletionRequest{Model: openai.ModelGPT5Nano} + span := tr.StartSpanAndInjectHeaders(t.Context(), headers, headerMutation, req, []byte("{}")) + require.NotNil(t, span) + span.EndSpan() + + v1Span := collector.TakeSpan() + require.NotNil(t, v1Span) + + attrs := make(map[string]string) + for _, kv := range v1Span.Attributes { + attrs[kv.Key] = kv.Value.GetStringValue() + } + require.Equal(t, "abc123", attrs["session.id"]) + require.Equal(t, "user456", attrs["user.id"]) +} + // TestNewTracingFromEnv_Embeddings_Redaction tests that the OpenInference // environment variables (OPENINFERENCE_HIDE_EMBEDDINGS_TEXT and OPENINFERENCE_HIDE_EMBEDDINGS_VECTORS) // work correctly to redact sensitive data from embeddings spans, following the OpenInference diff --git a/site/docs/capabilities/llm-integrations/supported-endpoints.md b/site/docs/capabilities/llm-integrations/supported-endpoints.md index fa1c4ac8ad..42d33e38c6 100644 --- a/site/docs/capabilities/llm-integrations/supported-endpoints.md +++ b/site/docs/capabilities/llm-integrations/supported-endpoints.md @@ -156,6 +156,40 @@ curl -H "Content-Type: application/json" \ - OpenAI - Any OpenAI-compatible provider that supports embeddings, including Azure OpenAI. 
+### Image Generation + +**Endpoint:** `POST /v1/images/generations` + +**Status:** ✅ Supported + +**Description:** Generate one or more images from a text prompt using OpenAI-compatible models. + +**Features:** + +- **Non-streaming responses**: Returns JSON payload with image URLs or base64 content +- **Model selection**: Via request body `model` or `x-ai-eg-model` header +- **Parameters**: `prompt`, `size`, `n`, `quality`, `response_format` +- **Metrics**: Records image count, model, and size; token usage when provided +- **Provider fallback and load balancing** + +**Supported Providers:** + +- OpenAI +- Any OpenAI-compatible provider that supports image generations + +**Example:** + +```bash +curl -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-image-1", + "prompt": "a serene mountain landscape at sunrise in watercolor", + "size": "1024x1024", + "n": 1 + }' \ + $GATEWAY_URL/v1/images/generations +``` + ### Models **Endpoint:** `GET /v1/models` @@ -194,26 +228,26 @@ curl $GATEWAY_URL/v1/models The following table summarizes which providers support which endpoints: -| Provider | Chat Completions | Completions | Embeddings | Anthropic Messages | Notes | -| ----------------------------------------------------------------------------------------------------- | :--------------: | :---------: | :--------: | :----------------: | -------------------------------------------------------------------------------------------------------------------- | -| [OpenAI](https://platform.openai.com/docs/api-reference) | ✅ | ✅ | ✅ | ❌ | | -| [AWS Bedrock](https://docs.aws.amazon.com/bedrock/latest/APIReference/) | ✅ | 🚧 | 🚧 | ❌ | Via API translation | -| [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) | ✅ | 🚧 | ✅ | ❌ | Via API translation or via [OpenAI-compatible API](https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest) | -| [Google Gemini](https://ai.google.dev/gemini-api/docs/openai) | ✅ | ⚠️ | ✅ | ❌ | Via 
OpenAI-compatible API | -| [Groq](https://console.groq.com/docs/openai) | ✅ | ❌ | ❌ | ❌ | Via OpenAI-compatible API | -| [Grok](https://docs.x.ai/docs/api-reference) | ✅ | ⚠️ | ❌ | ❌ | Via OpenAI-compatible API | -| [Together AI](https://docs.together.ai/docs/openai-api-compatibility) | ⚠️ | ⚠️ | ⚠️ | ❌ | Via OpenAI-compatible API | -| [Cohere](https://docs.cohere.com/v2/docs/compatibility-api) | ⚠️ | ⚠️ | ⚠️ | ❌ | Via OpenAI-compatible API | -| [Mistral](https://docs.mistral.ai/api/) | ⚠️ | ⚠️ | ⚠️ | ❌ | Via OpenAI-compatible API | -| [DeepInfra](https://deepinfra.com/docs/inference) | ✅ | ⚠️ | ✅ | ❌ | Via OpenAI-compatible API | -| [DeepSeek](https://api-docs.deepseek.com/) | ⚠️ | ⚠️ | ❌ | ❌ | Via OpenAI-compatible API | -| [Hunyuan](https://cloud.tencent.com/document/product/1729/111007) | ⚠️ | ⚠️ | ⚠️ | ❌ | Via OpenAI-compatible API | -| [Tencent LLM Knowledge Engine](https://www.tencentcloud.com/document/product/1255/70381) | ⚠️ | ❌ | ❌ | ❌ | Via OpenAI-compatible API | -| [Tetrate Agent Router Service (TARS)](https://router.tetrate.ai/) | ⚠️ | ⚠️ | ⚠️ | ❌ | Via OpenAI-compatible API | -| [Google Vertex AI](https://cloud.google.com/vertex-ai/docs/reference/rest) | ✅ | 🚧 | 🚧 | ❌ | Via OpenAI-compatible API | -| [Anthropic on Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude) | ✅ | ❌ | 🚧 | ✅ | Via OpenAI-compatible API and Native Anthropic API | -| [SambaNova](https://docs.sambanova.ai/sambastudio/latest/open-ai-api.html) | ✅ | ⚠️ | ✅ | ❌ | Via OpenAI-compatible API | -| [Anthropic](https://docs.claude.com/en/home) | ✅ | ❌ | ❌ | ✅ | Via OpenAI-compatible API and Native Anthropic API | +| Provider | Chat Completions | Completions | Embeddings | Image Generation | Anthropic Messages | Notes | +| ----------------------------------------------------------------------------------------------------- | :--------------: | :---------: | :--------: | :--------------: | :----------------: | 
-------------------------------------------------------------------------------------------------------------------- | +| [OpenAI](https://platform.openai.com/docs/api-reference) | ✅ | ✅ | ✅ | ✅ | ❌ | | +| [AWS Bedrock](https://docs.aws.amazon.com/bedrock/latest/APIReference/) | ✅ | 🚧 | 🚧 | ❌ | ❌ | Via API translation | +| [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) | ✅ | 🚧 | ✅ | ⚠️ | ❌ | Via API translation or via [OpenAI-compatible API](https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest) | +| [Google Gemini](https://ai.google.dev/gemini-api/docs/openai) | ✅ | ⚠️ | ✅ | ⚠️ | ❌ | Via OpenAI-compatible API | +| [Groq](https://console.groq.com/docs/openai) | ✅ | ❌ | ❌ | ❌ | ❌ | Via OpenAI-compatible API | +| [Grok](https://docs.x.ai/docs/api-reference) | ✅ | ⚠️ | ❌ | ⚠️ | ❌ | Via OpenAI-compatible API | +| [Together AI](https://docs.together.ai/docs/openai-api-compatibility) | ⚠️ | ⚠️ | ⚠️ | ⚠️ | ❌ | Via OpenAI-compatible API | +| [Cohere](https://docs.cohere.com/v2/docs/compatibility-api) | ⚠️ | ⚠️ | ⚠️ | ❌ | ❌ | Via OpenAI-compatible API | +| [Mistral](https://docs.mistral.ai/api/) | ⚠️ | ⚠️ | ⚠️ | ❌ | ❌ | Via OpenAI-compatible API | +| [DeepInfra](https://deepinfra.com/docs/inference) | ✅ | ⚠️ | ✅ | ⚠️ | ❌ | Via OpenAI-compatible API | +| [DeepSeek](https://api-docs.deepseek.com/) | ⚠️ | ⚠️ | ❌ | ❌ | ❌ | Via OpenAI-compatible API | +| [Hunyuan](https://cloud.tencent.com/document/product/1729/111007) | ⚠️ | ⚠️ | ⚠️ | ❌ | ❌ | Via OpenAI-compatible API | +| [Tencent LLM Knowledge Engine](https://www.tencentcloud.com/document/product/1255/70381) | ⚠️ | ❌ | ❌ | ❌ | ❌ | Via OpenAI-compatible API | +| [Tetrate Agent Router Service (TARS)](https://router.tetrate.ai/) | ⚠️ | ⚠️ | ⚠️ | ❌ | ❌ | Via OpenAI-compatible API | +| [Google Vertex AI](https://cloud.google.com/vertex-ai/docs/reference/rest) | ✅ | 🚧 | 🚧 | ❌ | ❌ | Via OpenAI-compatible API | +| [Anthropic on Vertex 
AI](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude) | ✅ | ❌ | 🚧 | ❌ | ✅ | Via OpenAI-compatible API and Native Anthropic API | +| [SambaNova](https://docs.sambanova.ai/sambastudio/latest/open-ai-api.html) | ✅ | ⚠️ | ✅ | ❌ | ❌ | Via OpenAI-compatible API | +| [Anthropic](https://docs.claude.com/en/home) | ✅ | ❌ | ❌ | ❌ | ✅ | Via OpenAI-compatible API and Native Anthropic API | - ✅ - Supported and Tested on Envoy AI Gateway CI - ⚠️️ - Expected to work based on provider documentation, but not tested on the CI. diff --git a/site/docs/capabilities/observability/metrics.md b/site/docs/capabilities/observability/metrics.md index 3ced359df1..d67e947411 100644 --- a/site/docs/capabilities/observability/metrics.md +++ b/site/docs/capabilities/observability/metrics.md @@ -35,6 +35,7 @@ Each metric comes with some default attributes such as: - `chat`: For `/v1/chat/completions` endpoint. - `completion`: For `/v1/completions` endpoint. - `embedding`: For `/v1/embeddings` endpoint. + - `image_generation`: For `/v1/images/generations` endpoint. - `messages`: For `/anthropic/v1/messages` endpoint. - `gen_ai.original.model` - The original model name from the request body - `gen_ai.request.model` - The model name requested (may be overridden) diff --git a/tests/extproc/testupstream_test.go b/tests/extproc/testupstream_test.go index 1bae78b8c0..124cb33b5e 100644 --- a/tests/extproc/testupstream_test.go +++ b/tests/extproc/testupstream_test.go @@ -141,6 +141,30 @@ func TestWithTestUpstream(t *testing.T) { // expResponseBodyFunc is a function to check the response body. This can be used instead of the expResponseBody field. 
expResponseBodyFunc func(require.TestingT, []byte) }{ + { + name: "openai - /v1/images/generations", + backend: "openai", + path: "/v1/images/generations", + method: http.MethodPost, + requestBody: `{"model":"gpt-image-1-mini","prompt":"a cat wearing sunglasses","size":"1024x1024","quality":"low"}`, + expPath: "/v1/images/generations", + responseBody: `{"created":1736890000,"data":[{"url":"https://example.com/image1.png"}],"model":"gpt-image-1-mini","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`, + expStatus: http.StatusOK, + expResponseBody: `{"created":1736890000,"data":[{"url":"https://example.com/image1.png"}],"model":"gpt-image-1-mini","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`, + }, + { + name: "openai - /v1/images/generations - non json upstream error mapped to OpenAI", + backend: "openai", + path: "/v1/images/generations", + method: http.MethodPost, + requestBody: `{"model":"dall-e-3","prompt":"a scenic beach"}`, + expPath: "/v1/images/generations", + responseHeaders: "content-type:text/plain", + responseStatus: strconv.Itoa(http.StatusServiceUnavailable), + responseBody: `backend timeout`, + expStatus: http.StatusServiceUnavailable, + expResponseBody: `{"error":{"type":"OpenAIBackendError","message":"backend timeout","code":"503"}}`, + }, { name: "unknown path", path: "/unknown", @@ -979,7 +1003,10 @@ data: {"type":"message_stop" } } defer func() { _ = resp.Body.Close() }() - failIf5xx(t, resp, &was5xx) + // Only fail-fast on unexpected 5xx. Some test cases intentionally expect 5xx. 
+ if tc.expStatus < http.StatusInternalServerError { + failIf5xx(t, resp, &was5xx) + } lastBody, lastErr = io.ReadAll(resp.Body) if lastErr != nil { diff --git a/tests/extproc/vcr/envoy.yaml b/tests/extproc/vcr/envoy.yaml index c0b735c925..4c76d3b527 100644 --- a/tests/extproc/vcr/envoy.yaml +++ b/tests/extproc/vcr/envoy.yaml @@ -96,6 +96,10 @@ static_resources: path: "/v1/embeddings" route: cluster: openai + - match: + path: "/v1/images/generations" + route: + cluster: openai - match: path: "/v1/models" route: diff --git a/tests/extproc/vcr/image_generation_test.go b/tests/extproc/vcr/image_generation_test.go new file mode 100644 index 0000000000..2b9e54c490 --- /dev/null +++ b/tests/extproc/vcr/image_generation_test.go @@ -0,0 +1,52 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package vcr + +import ( + "fmt" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/tests/internal/testopenai" +) + +func TestOpenAIImageGeneration(t *testing.T) { + env := startTestEnvironment(t, extprocBin, extprocConfig, nil, envoyConfig) + + listenerPort := env.EnvoyListenerPort() + + cassettes := testopenai.ImageCassettes() + + was5xx := false + for _, cassette := range cassettes { + if was5xx { + return // stop early on infrastructure failures to avoid cascading errors + } + t.Run(cassette.String(), func(t *testing.T) { + req, err := testopenai.NewRequest(t.Context(), fmt.Sprintf("http://localhost:%d", listenerPort), cassette) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + if resp.StatusCode == http.StatusBadGateway { + was5xx = true + } + + expectedBody := testopenai.ResponseBody(cassette) + 
assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, expectedBody, string(body)) + }) + } +} diff --git a/tests/extproc/vcr/otel_image_generation_metrics_test.go b/tests/extproc/vcr/otel_image_generation_metrics_test.go new file mode 100644 index 0000000000..a812821d61 --- /dev/null +++ b/tests/extproc/vcr/otel_image_generation_metrics_test.go @@ -0,0 +1,82 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package vcr + +import ( + "fmt" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/tests/internal/testopenai" +) + +// otelImageGenerationMetricsTestCase defines the expected behavior for each cassette. +type otelImageGenerationMetricsTestCase struct { + cassette testopenai.Cassette + isError bool // whether this is an error response. +} + +// buildOtelImageGenerationMetricsTestCases returns all test cases with their expected behaviors. +func buildOtelImageGenerationMetricsTestCases() []otelImageGenerationMetricsTestCase { + var cases []otelImageGenerationMetricsTestCase + for _, cassette := range testopenai.ImageCassettes() { + tc := otelImageGenerationMetricsTestCase{cassette: cassette} + // Currently we only have happy-path cassettes for image generation + cases = append(cases, tc) + } + return cases +} + +// TestOtelOpenAIImageGeneration_metrics tests that metrics are properly exported via OTLP for image generation requests. +func TestOtelOpenAIImageGeneration_metrics(t *testing.T) { + env := setupOtelTestEnvironment(t) + listenerPort := env.EnvoyListenerPort() + was5xx := false + + for _, tc := range buildOtelImageGenerationMetricsTestCases() { + if was5xx { + return // rather than also failing subsequent tests, which confuses root cause. + } + + t.Run(tc.cassette.String(), func(t *testing.T) { + // Send request. 
+ req, err := testopenai.NewRequest(t.Context(), fmt.Sprintf("http://localhost:%d", listenerPort), tc.cassette) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + failIf5xx(t, resp, &was5xx) + + // Always read the content. + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + // Get the span to extract duration and model attributes. + span := env.collector.TakeSpan() + require.NotNil(t, span) + + // Collect all metrics within the timeout period. + // Image generation should have 2 metrics: token usage + request duration + allMetrics := env.collector.TakeMetrics(2) + metrics := requireScopeMetrics(t, allMetrics) + + // For image generation spans, we record llm.model_name on request; no response model attribute exists. + // In non-override cases, original = request = response. + requestModel := getSpanAttributeString(span.Attributes, "llm.model_name") + originalModel := requestModel + responseModel := requestModel + + // Verify metrics. + verifyTokenUsageMetricsWithOriginal(t, "image_generation", metrics, span, originalModel, requestModel, responseModel, tc.isError) + verifyRequestDurationMetricsWithOriginal(t, "image_generation", metrics, span, originalModel, requestModel, responseModel, tc.isError) + }) + } +} diff --git a/tests/extproc/vcr/otel_image_generation_tracing_test.go b/tests/extproc/vcr/otel_image_generation_tracing_test.go new file mode 100644 index 0000000000..0cddd607ce --- /dev/null +++ b/tests/extproc/vcr/otel_image_generation_tracing_test.go @@ -0,0 +1,83 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package vcr + +import ( + "encoding/hex" + "fmt" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/tests/internal/testopenai" + "github.com/envoyproxy/ai-gateway/tests/internal/testopeninference" +) + +// TestOtelOpenAIImageGeneration_tracing validates that image generation spans +// emitted by the gateway match the OpenInference reference spans for the same cassette. +func TestOtelOpenAIImageGeneration_tracing(t *testing.T) { + env := setupOtelTestEnvironment(t) + listenerPort := env.EnvoyListenerPort() + + was5xx := false + for _, cassette := range testopenai.ImageCassettes() { + if was5xx { + return // avoid cascading failures obscuring the first root cause + } + + expected, err := testopeninference.GetSpan(t.Context(), io.Discard, cassette) + require.NoError(t, err) + + t.Run(cassette.String(), func(t *testing.T) { + // Send request. + req, err := testopenai.NewRequest(t.Context(), fmt.Sprintf("http://localhost:%d", listenerPort), cassette) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + failIf5xx(t, resp, &was5xx) + + // Always read the content. + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + span := env.collector.TakeSpan() + testopeninference.RequireSpanEqual(t, expected, span) + + // Also drain any metrics that might have been sent. + _ = env.collector.DrainMetrics() + }) + } +} + +// TestOtelOpenAIImageGeneration_propagation verifies that the image generation LLM span +// participates in the incoming trace when W3C trace context is provided. 
+func TestOtelOpenAIImageGeneration_propagation(t *testing.T) { + env := setupOtelTestEnvironment(t) + listenerPort := env.EnvoyListenerPort() + + req, err := testopenai.NewRequest(t.Context(), fmt.Sprintf("http://localhost:%d", listenerPort), testopenai.CassetteImageGenerationBasic) + require.NoError(t, err) + + traceID := "12345678901234567890123456789012" + req.Header.Add("traceparent", "00-"+traceID+"-1234567890123456-01") + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + + span := env.collector.TakeSpan() + require.NotNil(t, span) + actualTraceID := hex.EncodeToString(span.TraceId) + require.Equal(t, traceID, actualTraceID) +} diff --git a/tests/internal/testopenai/cassettes.go b/tests/internal/testopenai/cassettes.go index a79e583533..bf30f5f6c6 100644 --- a/tests/internal/testopenai/cassettes.go +++ b/tests/internal/testopenai/cassettes.go @@ -135,6 +135,11 @@ const ( // CassetteEmbeddingsBadRequest tests request with multiple validation errors. CassetteEmbeddingsBadRequest + // Cassettes for the OpenAI /v1/images/generations endpoint. + + // CassetteImageGenerationBasic is a basic image generation request with model and prompt. + CassetteImageGenerationBasic + // Cassettes for Azure OpenAI Service. // CassetteAzureChatBasic is the same as CassetteChatBasic, except using @@ -189,6 +194,8 @@ var stringValues = map[Cassette]string{ CassetteEmbeddingsMaxTokens: "embeddings-max-tokens", CassetteEmbeddingsWhitespace: "embeddings-whitespace", CassetteEmbeddingsBadRequest: "embeddings-bad-request", + + CassetteImageGenerationBasic: "image-generation-basic", } // String returns the string representation of the cassette name. 
@@ -213,6 +220,9 @@ func NewRequest(ctx context.Context, baseURL string, cassette Cassette) (*http.R } else if r, ok := embeddingsRequests[cassette]; ok { path := buildPath(cassette, "/embeddings", baseURL, r) return newRequest(ctx, cassette, path, r) + } else if r, ok := imageRequests[cassette]; ok { + path := buildPath(cassette, "/images/generations", baseURL, r) + return newRequest(ctx, cassette, path, r) } return nil, fmt.Errorf("unknown cassette: %s", cassette) } diff --git a/tests/internal/testopenai/cassettes/image-generation-basic.yaml b/tests/internal/testopenai/cassettes/image-generation-basic.yaml new file mode 100644 index 0000000000..e0f0c251ff --- /dev/null +++ b/tests/internal/testopenai/cassettes/image-generation-basic.yaml @@ -0,0 +1,60 @@ +--- +version: 2 +interactions: +- id: 0 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 139 + host: api.openai.com + body: "{\n \"model\": \"gpt-image-1-mini\",\n \"prompt\": \"A simple black-and-white line drawing of a cat playing with yarn\",\n \"size\": \"1024x1024\",\n \"quality\": \"low\"\n}" + headers: + Accept-Encoding: + - gzip + Content-Length: + - "139" + Content-Type: + - application/json + User-Agent: + - Go-http-client/1.1 + url: https://api.openai.com/v1/images/generations + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: 544 + body: "{\n \"created\": 1760092079,\n \"data\": [\n {\n \"url\": \"https://oaidalleapiprodscus.blob.core.windows.net/private/org-lKxIBdltrjcFbzboW0t1hthB/user-EdKNxj0PjYS29CfjUcR77ybc/img-3Ah2ERYmeJLYrRlxnuWPFpXG.png?st=2025-10-10T09%3A27%3A58Z&se=2025-10-10T11%3A27%3A58Z&sp=r&sv=2024-08-04&sr=b&rscd=inline&rsct=image/png&skoid=475fd488-6c59-44a5-9aa9-31c4db451bea&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2025-10-10T10%3A27%3A58Z&ske=2025-10-11T10%3A27%3A58Z&sks=b&skv=2024-08-04&sig=IlzMauzcIHqE5/JLQsiMglI0mqbGlrpQC/mvJz4TYuc%3D\"\n }\n ]\n}" + headers: + Alt-Svc: + - h3=":443"; ma=86400 + 
Cf-Cache-Status: + - DYNAMIC + Cf-Ray: + - 98c56d4f2dd889f6-BOM + Content-Length: + - "544" + Content-Type: + - application/json + Date: + - Fri, 10 Oct 2025 10:27:59 GMT + Openai-Processing-Ms: + - "13023" + Openai-Project: + - proj_Ro6stNUOKIiLVqRr1zdjaFoh + Openai-Version: + - "2020-10-01" + Server: + - cloudflare + Strict-Transport-Security: + - max-age=31536000; includeSubDomains; preload + X-Content-Type-Options: + - nosniff + X-Envoy-Upstream-Service-Time: + - "13027" + X-Request-Id: + - req_54999507941c487a8f1997287aaa8cf6 + status: 200 OK + code: 200 + duration: 14.290895958s diff --git a/tests/internal/testopenai/cassettes_test.go b/tests/internal/testopenai/cassettes_test.go index e7f9de3111..6c6ef99fa1 100644 --- a/tests/internal/testopenai/cassettes_test.go +++ b/tests/internal/testopenai/cassettes_test.go @@ -12,6 +12,7 @@ import ( "net/http" "os" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -27,7 +28,12 @@ func testNewRequest[R any](t *testing.T, tests []cassetteTestCase[R]) { // documented way to backfill cassettes. server, err := NewServer(os.Stdout, 0) require.NoError(t, err) - defer server.Close() + defer func() { + // This sleep is required to wait until large cassettes are recorded. + // Remove this sleep when there is a proper way to wait for cassettes to be recorded. + <-time.After(5 * time.Second) + server.Close() + }() baseURL := server.URL() diff --git a/tests/internal/testopenai/image_requests.go b/tests/internal/testopenai/image_requests.go new file mode 100644 index 0000000000..98b060cb66 --- /dev/null +++ b/tests/internal/testopenai/image_requests.go @@ -0,0 +1,31 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package testopenai + +// ImageCassettes returns a slice of all cassettes for image generation. 
+func ImageCassettes() []Cassette { + return cassettes(imageRequests) +} + +// imageGenerationRequest is a minimal request body for OpenAI image generation. +// We avoid importing the OpenAI SDK in tests to keep dependencies light. +type imageGenerationRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + // Optional fields like size/quality/response_format can be added later if needed. + Size string `json:"size,omitempty"` + Quality string `json:"quality,omitempty"` +} + +// imageRequests contains the actual request body for each image generation cassette. +var imageRequests = map[Cassette]*imageGenerationRequest{ + CassetteImageGenerationBasic: { + Model: "gpt-image-1-mini", + Prompt: "A simple black-and-white line drawing of a cat playing with yarn", + Size: "1024x1024", + Quality: "low", + }, +} diff --git a/tests/internal/testopenai/image_requests_test.go b/tests/internal/testopenai/image_requests_test.go new file mode 100644 index 0000000000..3309ab85a7 --- /dev/null +++ b/tests/internal/testopenai/image_requests_test.go @@ -0,0 +1,27 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package testopenai + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestImageCassettes(t *testing.T) { + testCassettes(t, ImageCassettes(), imageRequests) +} + +func TestNewRequestImages(t *testing.T) { + tests, err := buildTestCases(t, imageRequests) + require.NoError(t, err) + for i := range tests { + // Currently only basic happy-path image generation cassette + tests[i].expectedStatus = http.StatusOK + } + testNewRequest(t, tests) +} diff --git a/tests/internal/testopeninference/openai_proxy.py b/tests/internal/testopeninference/openai_proxy.py index 525c5ff8cf..533f252c45 100644 --- a/tests/internal/testopeninference/openai_proxy.py +++ b/tests/internal/testopeninference/openai_proxy.py @@ -86,6 +86,20 @@ async def azure_embeddings(deployment: str, request: Request) -> Response: client.embeddings.create ) +@app.post("/v1/images/generations") +async def images_generations(request: Request) -> Response: + return await handle_openai_request( + request, + client.images.generate + ) + +@app.post("/openai/deployments/{deployment}/images/generations") +async def azure_images_generations(deployment: str, request: Request) -> Response: + return await handle_openai_request( + request, + client.images.generate + ) + async def handle_openai_request( request: Request, client_method, @@ -95,7 +109,7 @@ async def handle_openai_request( try: if request_data is None: request_data = await request.json() - logger.info(f"Received request: {json.dumps(request_data)}") + logger.info(f"Received request: {json.dumps(request_data)[:600]}") cassette_name = request.headers.get('X-Cassette-Name') extra_headers = {"X-Cassette-Name": cassette_name} if cassette_name else {} diff --git a/tests/internal/testopeninference/spans/image-generation-basic.json b/tests/internal/testopeninference/spans/image-generation-basic.json new file mode 100644 index 0000000000..8e728ca871 --- /dev/null +++ 
b/tests/internal/testopeninference/spans/image-generation-basic.json @@ -0,0 +1,13 @@ +{ + "flags": 256, + "name": "ImageGeneration", + "kind": "SPAN_KIND_INTERNAL", + "attributes": [ + {"key": "llm.system", "value": {"stringValue": "openai"}}, + {"key": "llm.model_name", "value": {"stringValue": "gpt-image-1-mini"}}, + {"key": "openinference.span.kind", "value": {"stringValue": "LLM"}}, + {"key": "input.mime_type", "value": {"stringValue": "application/json"}}, + {"key": "input.value", "value": {"stringValue": "{\"model\":\"gpt-image-1-mini\",\"prompt\":\"A simple black-and-white line drawing of a cat playing with yarn\",\"quality\":\"low\",\"size\":\"1024x1024\"}"}} + ], + "status": {} +} diff --git a/tests/internal/testopeninference/spans_test.go b/tests/internal/testopeninference/spans_test.go index 1502e6c38e..5849698bb4 100644 --- a/tests/internal/testopeninference/spans_test.go +++ b/tests/internal/testopeninference/spans_test.go @@ -23,6 +23,7 @@ func TestGetAllSpans(t *testing.T) { {"ChatCompletion", testopenai.ChatCassettes()}, {"Completion", testopenai.CompletionCassettes()}, {"CreateEmbeddings", testopenai.EmbeddingsCassettes()}, + {"ImageGeneration", testopenai.ImageCassettes()}, } for _, tc := range tests {