Skip to content

Commit 9e45c62

Browse files
feat: align image generation processor with embeddings model handling patterns
- Update the modelNameOverride field type from string to internalapi.ModelNameOverride for type consistency.
- Implement proper original model tracking by calling SetOriginalModel with the request body model before any overrides are applied.

Signed-off-by: Hrushikesh Patil <[email protected]>
1 parent d9da239 commit 9e45c62

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

internal/extproc/imagegeneration_processor.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package extproc
77

88
import (
99
"bytes"
10+
"cmp"
1011
"context"
1112
"encoding/json"
1213
"fmt"
@@ -164,7 +165,7 @@ type imageGenerationProcessorUpstreamFilter struct {
164165
requestHeaders map[string]string
165166
responseHeaders map[string]string
166167
responseEncoding string
167-
modelNameOverride string
168+
modelNameOverride internalapi.ModelNameOverride
168169
backendName string
169170
handler backendauth.Handler
170171
headerMutator *headermutator.HeaderMutator
@@ -188,7 +189,7 @@ type imageGenerationProcessorUpstreamFilter struct {
188189
func (i *imageGenerationProcessorUpstreamFilter) selectTranslator(out filterapi.VersionedAPISchema) error {
189190
switch out.Name {
190191
case filterapi.APISchemaOpenAI:
191-
i.translator = translator.NewImageGenerationOpenAIToOpenAITranslator(out.Version, i.modelNameOverride)
192+
i.translator = translator.NewImageGenerationOpenAIToOpenAITranslator(out.Version, i.modelNameOverride, i.span)
192193
case filterapi.APISchemaAWSBedrock:
193194
// i.translator = translator.NewImageGenerationOpenAIToAWSBedrockTranslator(i.modelNameOverride)
194195
i.translator = nil // Placeholder
@@ -213,12 +214,12 @@ func (i *imageGenerationProcessorUpstreamFilter) ProcessRequestHeaders(ctx conte
213214

214215
// Start tracking metrics for this request.
215216
i.metrics.StartRequest(i.requestHeaders)
216-
// For image generation we generally expect request and response model to match.
217-
// If a backend override occurs, response model may be updated downstream via headers but we keep
218-
// metrics consistent with the selected model header.
219-
m := i.requestHeaders[i.config.modelNameHeaderKey]
220-
i.metrics.SetRequestModel(internalapi.RequestModel(m))
221-
i.metrics.SetResponseModel(internalapi.ResponseModel(m))
217+
// Set the original model from the request body before any overrides
218+
i.metrics.SetOriginalModel(internalapi.OriginalModel(i.originalRequestBody.Model))
219+
// Set the request model for metrics: prefer the model name header (which carries any backend override), falling back to the original request body model when the header is empty.
220+
reqModel := cmp.Or(i.requestHeaders[i.config.modelNameHeaderKey], string(i.originalRequestBody.Model))
221+
i.metrics.SetRequestModel(internalapi.RequestModel(reqModel))
222+
i.metrics.SetResponseModel(internalapi.ResponseModel(reqModel))
222223

223224
// We force the body mutation in the following cases:
224225
// * The request is a retry request because the body mutation might have happened the previous iteration.
@@ -470,7 +471,9 @@ func (i *imageGenerationProcessorUpstreamFilter) SetBackend(ctx context.Context,
470471
i.headerMutator = headermutator.NewHeaderMutator(b.HeaderMutation, rp.requestHeaders)
471472
// Sync header with backend model so header-derived labels/CEL use the actual model.
472473
if i.modelNameOverride != "" {
473-
i.requestHeaders[i.config.modelNameHeaderKey] = i.modelNameOverride
474+
i.requestHeaders[i.config.modelNameHeaderKey] = string(i.modelNameOverride)
475+
// Update metrics with the overridden model
476+
i.metrics.SetRequestModel(internalapi.RequestModel(i.modelNameOverride))
474477
}
475478
i.originalRequestBody = rp.originalRequestBody
476479
i.originalRequestBodyRaw = rp.originalRequestBodyRaw

internal/extproc/mocks_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,9 @@ type mockImageGenerationMetrics struct {
407407
}
408408

409409
func (m *mockImageGenerationMetrics) StartRequest(map[string]string) {}
410+
func (m *mockImageGenerationMetrics) SetOriginalModel(originalModel string) {
411+
m.model = originalModel
412+
}
410413
func (m *mockImageGenerationMetrics) SetRequestModel(requestModel string) {
411414
m.model = requestModel
412415
}

0 commit comments

Comments
 (0)