Skip to content

Commit ec5629e

Browse files
feat(extproc): implement image generation processor with OpenAI integration
- Add comprehensive image generation processor with request/response handling
- Add server integration for image generation processing
- Include comprehensive test coverage for processor functionality

Signed-off-by: Hrushikesh Patil <[email protected]>
1 parent 24a5e94 commit ec5629e

File tree

3 files changed

+137
-106
lines changed

3 files changed

+137
-106
lines changed

internal/extproc/imagegeneration_processor.go

Lines changed: 90 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"context"
1111
"encoding/json"
1212
"fmt"
13+
"io"
1314
"log/slog"
1415
"strconv"
1516

@@ -18,14 +19,14 @@ import (
1819
extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
1920
"google.golang.org/protobuf/types/known/structpb"
2021

21-
"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
2222
"github.com/envoyproxy/ai-gateway/internal/extproc/backendauth"
2323
"github.com/envoyproxy/ai-gateway/internal/extproc/headermutator"
2424
"github.com/envoyproxy/ai-gateway/internal/extproc/translator"
2525
"github.com/envoyproxy/ai-gateway/internal/filterapi"
2626
"github.com/envoyproxy/ai-gateway/internal/internalapi"
2727
"github.com/envoyproxy/ai-gateway/internal/metrics"
2828
tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api"
29+
openaisdk "github.com/openai/openai-go/v2"
2930
)
3031

3132
// ImageGenerationProcessorFactory returns a factory method to instantiate the image generation processor.
@@ -67,7 +68,7 @@ type imageGenerationProcessorRouterFilter struct {
6768
// originalRequestBody is the original request body that is passed to the upstream filter.
6869
// This is used to perform the transformation of the request body on the original input
6970
// when the request is retried.
70-
originalRequestBody *openai.ImageGenerationRequest
71+
originalRequestBody *openaisdk.ImageGenerateParams
7172
originalRequestBodyRaw []byte
7273
// tracer is the tracer used for requests.
7374
tracer tracing.ImageGenerationTracer
@@ -107,11 +108,8 @@ func (i *imageGenerationProcessorRouterFilter) ProcessRequestBody(ctx context.Co
107108
return nil, fmt.Errorf("failed to parse request body: %w", err)
108109
}
109110

110-
// Handle streaming requests for consistency with chat completion
111-
// Unlike chat completion which uses StreamOptions.IncludeUsage for cost tracking,
112-
// image generation streaming (only supported by gpt-image-1) doesn't have
113-
// the same usage tracking mechanism, but we still need to detect and flag streaming requests
114-
isStreamingRequest := body.Stream != nil && *body.Stream && body.Model == openai.ModelGPTImage1
111+
// OpenAI SDK doesn't expose a generic Stream flag for image generation; keep false for now.
112+
isStreamingRequest := false
115113

116114
i.requestHeaders[i.config.modelNameHeaderKey] = model
117115

@@ -171,7 +169,7 @@ type imageGenerationProcessorUpstreamFilter struct {
171169
handler backendauth.Handler
172170
headerMutator *headermutator.HeaderMutator
173171
originalRequestBodyRaw []byte
174-
originalRequestBody *openai.ImageGenerationRequest
172+
originalRequestBody *openaisdk.ImageGenerateParams
175173
translator translator.ImageGenerationTranslator
176174
// onRetry is true if this is a retry request at the upstream filter.
177175
onRetry bool
@@ -215,18 +213,19 @@ func (i *imageGenerationProcessorUpstreamFilter) ProcessRequestHeaders(ctx conte
215213

216214
// Start tracking metrics for this request.
217215
i.metrics.StartRequest(i.requestHeaders)
218-
i.metrics.SetModel(i.requestHeaders[i.config.modelNameHeaderKey])
216+
// For image generation we generally expect request and response model to match.
217+
// If a backend override occurs, response model may be updated downstream via headers but we keep
218+
// metrics consistent with the selected model header.
219+
m := i.requestHeaders[i.config.modelNameHeaderKey]
220+
i.metrics.SetRequestModel(internalapi.RequestModel(m))
221+
i.metrics.SetResponseModel(internalapi.ResponseModel(m))
219222

220223
// We force the body mutation in the following cases:
221224
// * The request is a retry request because the body mutation might have happened the previous iteration.
222-
var headerMutation *extprocv3.HeaderMutation
223-
var bodyMutation *extprocv3.BodyMutation
224-
if i.translator != nil {
225-
forceBodyMutation := i.onRetry
226-
headerMutation, bodyMutation, err = i.translator.RequestBody(i.originalRequestBodyRaw, i.originalRequestBody, forceBodyMutation)
227-
if err != nil {
228-
return nil, fmt.Errorf("failed to transform request: %w", err)
229-
}
225+
forceBodyMutation := i.onRetry
226+
headerMutation, bodyMutation, err := i.translator.RequestBody(i.originalRequestBodyRaw, i.originalRequestBody, forceBodyMutation)
227+
if err != nil {
228+
return nil, fmt.Errorf("failed to transform request: %w", err)
230229
}
231230
if headerMutation == nil {
232231
headerMutation = &extprocv3.HeaderMutation{}
@@ -284,12 +283,19 @@ func (i *imageGenerationProcessorUpstreamFilter) ProcessResponseHeaders(ctx cont
284283
if enc := i.responseHeaders["content-encoding"]; enc != "" {
285284
i.responseEncoding = enc
286285
}
287-
var headerMutation *extprocv3.HeaderMutation
288-
if i.translator != nil {
289-
headerMutation, err = i.translator.ResponseHeaders(i.responseHeaders)
290-
if err != nil {
291-
return nil, fmt.Errorf("failed to transform response headers: %w", err)
292-
}
286+
287+
// Debug logging for response headers
288+
if i.logger.Enabled(ctx, slog.LevelDebug) {
289+
i.logger.Debug("response headers received",
290+
slog.String("content-type", i.responseHeaders["content-type"]),
291+
slog.String("content-encoding", i.responseHeaders["content-encoding"]),
292+
slog.String("status", i.responseHeaders[":status"]),
293+
slog.String("response_encoding", i.responseEncoding))
294+
}
295+
296+
headerMutation, err := i.translator.ResponseHeaders(i.responseHeaders)
297+
if err != nil {
298+
return nil, fmt.Errorf("failed to transform response headers: %w", err)
293299
}
294300

295301
var mode *extprocv3http.ProcessingMode
@@ -318,21 +324,48 @@ func (i *imageGenerationProcessorUpstreamFilter) ProcessResponseBody(ctx context
318324
}
319325
}()
320326

327+
// Debug logging for raw response body
328+
if i.logger.Enabled(ctx, slog.LevelDebug) {
329+
bodyPreview := string(body.Body)
330+
if len(bodyPreview) > 100 {
331+
bodyPreview = bodyPreview[:100] + "..."
332+
}
333+
i.logger.Debug("raw response body received",
334+
slog.Int("body_length", len(body.Body)),
335+
slog.String("body_preview", bodyPreview),
336+
slog.String("content_encoding", i.responseEncoding),
337+
slog.Bool("end_of_stream", body.EndOfStream))
338+
}
339+
340+
// Decompress the body if needed using common utility.
341+
decodingResult, err := decodeContentIfNeeded(body.Body, i.responseEncoding)
342+
if err != nil {
343+
return nil, err
344+
}
345+
346+
// Debug logging for decoded response body
347+
if i.logger.Enabled(ctx, slog.LevelDebug) {
348+
decodedBytes, _ := io.ReadAll(decodingResult.reader)
349+
decodedPreview := string(decodedBytes)
350+
if len(decodedPreview) > 100 {
351+
decodedPreview = decodedPreview[:100] + "..."
352+
}
353+
i.logger.Debug("decoded response body",
354+
slog.Int("decoded_length", len(decodedBytes)),
355+
slog.String("decoded_preview", decodedPreview),
356+
slog.Bool("was_encoded", decodingResult.isEncoded))
357+
358+
// Reset reader for translator
359+
decodingResult.reader = bytes.NewReader(decodedBytes)
360+
}
361+
321362
// Assume all responses have a valid status code header.
322363
if code, _ := strconv.Atoi(i.responseHeaders[":status"]); !isGoodStatusCode(code) {
323364
var headerMutation *extprocv3.HeaderMutation
324365
var bodyMutation *extprocv3.BodyMutation
325-
if i.translator != nil {
326-
headerMutation, bodyMutation, err = i.translator.ResponseError(i.responseHeaders, bytes.NewReader(body.Body))
327-
if err != nil {
328-
return nil, fmt.Errorf("failed to transform response error: %w", err)
329-
}
330-
}
331-
if headerMutation == nil {
332-
headerMutation = &extprocv3.HeaderMutation{}
333-
}
334-
if bodyMutation == nil {
335-
bodyMutation = &extprocv3.BodyMutation{}
366+
headerMutation, bodyMutation, err = i.translator.ResponseError(i.responseHeaders, decodingResult.reader)
367+
if err != nil {
368+
return nil, fmt.Errorf("failed to transform response error: %w", err)
336369
}
337370
if i.span != nil {
338371
b := bodyMutation.GetBody()
@@ -360,14 +393,13 @@ func (i *imageGenerationProcessorUpstreamFilter) ProcessResponseBody(ctx context
360393
var bodyMutation *extprocv3.BodyMutation
361394
var tokenUsage translator.LLMTokenUsage
362395
var imageMetadata translator.ImageGenerationMetadata
363-
if i.translator != nil {
364-
headerMutation, bodyMutation, tokenUsage, imageMetadata, err = i.translator.ResponseBody(i.responseHeaders, bytes.NewReader(body.Body), body.EndOfStream)
365-
if err != nil {
366-
return nil, fmt.Errorf("failed to transform response: %w", err)
367-
}
396+
headerMutation, bodyMutation, tokenUsage, imageMetadata, err = i.translator.ResponseBody(i.responseHeaders, decodingResult.reader, body.EndOfStream)
397+
if err != nil {
398+
return nil, fmt.Errorf("failed to transform response: %w", err)
368399
}
369400

370-
// TODO: Implement gzip handling when bodyMutation is non-nil and response is gzipped
401+
// Remove the content-encoding header if the original body was encoded but has been mutated by the processor.
402+
headerMutation = removeContentEncodingIfNeeded(headerMutation, bodyMutation, decodingResult.isEncoded)
371403

372404
resp := &extprocv3.ProcessingResponse{
373405
Response: &extprocv3.ProcessingResponse_ResponseBody{
@@ -385,14 +417,14 @@ func (i *imageGenerationProcessorUpstreamFilter) ProcessResponseBody(ctx context
385417
i.costs.OutputTokens += tokenUsage.OutputTokens
386418
i.costs.TotalTokens += tokenUsage.TotalTokens
387419

388-
// Update metrics with token usage.
389-
i.metrics.RecordTokenUsage(ctx, tokenUsage.InputTokens, tokenUsage.OutputTokens, tokenUsage.TotalTokens, i.requestHeaders)
420+
// Update metrics with token usage (input/output only per OTEL spec).
421+
i.metrics.RecordTokenUsage(ctx, tokenUsage.InputTokens, tokenUsage.OutputTokens, i.requestHeaders)
390422

391423
// Record image generation metrics
392424
i.metrics.RecordImageGeneration(ctx, imageMetadata.ImageCount, imageMetadata.Model, imageMetadata.Size, i.requestHeaders)
393425

394426
if body.EndOfStream && len(i.config.requestCosts) > 0 {
395-
metadata, err := buildDynamicMetadata(i.config, &i.costs, i.requestHeaders, i.modelNameOverride, i.backendName)
427+
metadata, err := buildDynamicMetadata(i.config, &i.costs, i.requestHeaders, i.backendName)
396428
if err != nil {
397429
return nil, fmt.Errorf("failed to build dynamic metadata: %w", err)
398430
}
@@ -424,6 +456,16 @@ func (i *imageGenerationProcessorUpstreamFilter) SetBackend(ctx context.Context,
424456
if err = i.selectTranslator(b.Schema); err != nil {
425457
return fmt.Errorf("failed to select translator: %w", err)
426458
}
459+
460+
// Debug logging for backend selection
461+
if i.logger.Enabled(ctx, slog.LevelDebug) {
462+
i.logger.Debug("backend selected for image generation",
463+
slog.String("backend_name", b.Name),
464+
slog.String("schema_name", string(b.Schema.Name)),
465+
slog.String("schema_version", b.Schema.Version),
466+
slog.String("model_override", i.modelNameOverride),
467+
slog.Bool("translator_set", i.translator != nil))
468+
}
427469
i.handler = backendHandler
428470
i.headerMutator = headermutator.NewHeaderMutator(b.HeaderMutation, rp.requestHeaders)
429471
// Sync header with backend model so header-derived labels/CEL use the actual model.
@@ -435,10 +477,8 @@ func (i *imageGenerationProcessorUpstreamFilter) SetBackend(ctx context.Context,
435477
i.onRetry = rp.upstreamFilterCount > 1
436478

437479
// Set the streaming flag for image generation requests.
438-
i.stream = i.originalRequestBody != nil &&
439-
i.originalRequestBody.Stream != nil &&
440-
*i.originalRequestBody.Stream &&
441-
i.originalRequestBody.Model == openai.ModelGPTImage1
480+
// Image generation streaming not supported in current SDK params; keep false.
481+
i.stream = false
442482

443483
if isEndpointPicker {
444484
if i.logger.Enabled(ctx, slog.LevelDebug) {
@@ -450,10 +490,10 @@ func (i *imageGenerationProcessorUpstreamFilter) SetBackend(ctx context.Context,
450490
return
451491
}
452492

453-
func parseOpenAIImageGenerationBody(body *extprocv3.HttpBody) (modelName string, rb *openai.ImageGenerationRequest, err error) {
454-
var openAIReq openai.ImageGenerationRequest
493+
func parseOpenAIImageGenerationBody(body *extprocv3.HttpBody) (modelName string, rb *openaisdk.ImageGenerateParams, err error) {
494+
var openAIReq openaisdk.ImageGenerateParams
455495
if err := json.Unmarshal(body.Body, &openAIReq); err != nil {
456496
return "", nil, fmt.Errorf("failed to unmarshal body: %w", err)
457497
}
458-
return openAIReq.Model, &openAIReq, nil
498+
return string(openAIReq.Model), &openAIReq, nil
459499
}

internal/extproc/imagegeneration_processor_test.go

Lines changed: 4 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
package extproc
77

88
import (
9-
"context"
109
"encoding/json"
1110
"log/slog"
1211
"testing"
@@ -15,12 +14,11 @@ import (
1514
extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
1615
"github.com/stretchr/testify/require"
1716

18-
"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
1917
"github.com/envoyproxy/ai-gateway/internal/extproc/translator"
2018
"github.com/envoyproxy/ai-gateway/internal/filterapi"
2119
"github.com/envoyproxy/ai-gateway/internal/llmcostcel"
22-
"github.com/envoyproxy/ai-gateway/internal/metrics"
2320
tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api"
21+
openaisdk "github.com/openai/openai-go/v2"
2422
)
2523

2624
func TestImageGeneration_Schema(t *testing.T) {
@@ -146,7 +144,7 @@ func Test_imageGenerationProcessorUpstreamFilter_ProcessRequestHeaders(t *testin
146144
logger: slog.Default(),
147145
metrics: mm,
148146
originalRequestBodyRaw: imageGenerationBodyFromModel(t, "dall-e-3"),
149-
originalRequestBody: &openai.ImageGenerationRequest{Model: "dall-e-3", Prompt: "a cat"},
147+
originalRequestBody: &openaisdk.ImageGenerateParams{Model: openaisdk.ImageModel("dall-e-3"), Prompt: "a cat"},
150148
handler: &mockBackendAuthHandler{},
151149
}
152150
resp, err := p.ProcessRequestHeaders(t.Context(), nil)
@@ -178,7 +176,7 @@ func Test_imageGenerationProcessorUpstreamFilter_SetBackend(t *testing.T) {
178176
mm.RequireSelectedBackend(t, "some-backend")
179177

180178
// Supported OpenAI schema.
181-
rp := &imageGenerationProcessorRouterFilter{originalRequestBody: &openai.ImageGenerationRequest{}}
179+
rp := &imageGenerationProcessorRouterFilter{originalRequestBody: &openaisdk.ImageGenerateParams{}}
182180
p2 := &imageGenerationProcessorUpstreamFilter{
183181
config: &processorConfig{modelNameHeaderKey: "x-model-name"},
184182
requestHeaders: map[string]string{"x-model-name": "dall-e-2"},
@@ -215,57 +213,7 @@ func TestImageGeneration_ParseBody(t *testing.T) {
215213
// imageGenerationBodyFromModel returns a minimal valid image generation request for tests.
216214
func imageGenerationBodyFromModel(t *testing.T, model string) []byte {
217215
t.Helper()
218-
b, err := json.Marshal(&openai.ImageGenerationRequest{Model: model, Prompt: "a cat"})
216+
b, err := json.Marshal(&openaisdk.ImageGenerateParams{Model: openaisdk.ImageModel(model), Prompt: "a cat"})
219217
require.NoError(t, err)
220218
return b
221219
}
222-
223-
// mockImageGenerationMetrics implements [metrics.ImageGenerationMetrics] for testing.
224-
type mockImageGenerationMetrics struct {
225-
requestSuccessCount int
226-
requestErrorCount int
227-
model string
228-
backend string
229-
tokenUsageCount int
230-
}
231-
232-
func (m *mockImageGenerationMetrics) StartRequest(map[string]string) {}
233-
func (m *mockImageGenerationMetrics) SetModel(model string) { m.model = model }
234-
func (m *mockImageGenerationMetrics) SetBackend(b *filterapi.Backend) { m.backend = b.Name }
235-
func (m *mockImageGenerationMetrics) RecordTokenUsage(_ context.Context, _ uint32, _ uint32, _ uint32, _ map[string]string) {
236-
m.tokenUsageCount++
237-
}
238-
func (m *mockImageGenerationMetrics) RecordRequestCompletion(_ context.Context, success bool, _ map[string]string) {
239-
if success {
240-
m.requestSuccessCount++
241-
} else {
242-
m.requestErrorCount++
243-
}
244-
}
245-
func (m *mockImageGenerationMetrics) RecordImageGeneration(_ context.Context, _ int, _ string, _ string, _ map[string]string) {
246-
}
247-
248-
func (m *mockImageGenerationMetrics) RequireRequestFailure(t *testing.T) {
249-
require.Equal(t, 0, m.requestSuccessCount)
250-
require.Equal(t, 1, m.requestErrorCount)
251-
}
252-
func (m *mockImageGenerationMetrics) RequireRequestNotCompleted(t *testing.T) {
253-
require.Equal(t, 0, m.requestSuccessCount)
254-
require.Equal(t, 0, m.requestErrorCount)
255-
}
256-
func (m *mockImageGenerationMetrics) RequireRequestSuccess(t *testing.T) {
257-
require.Equal(t, 1, m.requestSuccessCount)
258-
require.Equal(t, 0, m.requestErrorCount)
259-
}
260-
func (m *mockImageGenerationMetrics) RequireSelectedModel(t *testing.T, model string) {
261-
require.Equal(t, model, m.model)
262-
}
263-
func (m *mockImageGenerationMetrics) RequireSelectedBackend(t *testing.T, backend string) {
264-
require.Equal(t, backend, m.backend)
265-
}
266-
func (m *mockImageGenerationMetrics) RequireTokensRecorded(t *testing.T, count int) {
267-
require.Equal(t, count, m.tokenUsageCount)
268-
}
269-
270-
// Ensure mock implements the interface at compile-time.
271-
var _ metrics.ImageGenerationMetrics = &mockImageGenerationMetrics{}

0 commit comments

Comments
 (0)