Skip to content

Commit faa32e7

Browse files
test: enhance image generation processor tests
Signed-off-by: Hrushikesh Patil <[email protected]>
1 parent 74e990c commit faa32e7

File tree

2 files changed

+251
-33
lines changed

2 files changed

+251
-33
lines changed

internal/extproc/imagegeneration_processor_test.go

Lines changed: 243 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
package extproc
77

88
import (
9+
"context"
910
"encoding/json"
11+
"errors"
12+
"io"
1013
"log/slog"
1114
"testing"
1215

@@ -38,6 +41,64 @@ func TestImageGeneration_Schema(t *testing.T) {
3841
})
3942
}
4043

44+
func Test_imageGenerationProcessorUpstreamFilter_SelectTranslator(t *testing.T) {
45+
c := &imageGenerationProcessorUpstreamFilter{}
46+
t.Run("unsupported", func(t *testing.T) {
47+
err := c.selectTranslator(filterapi.VersionedAPISchema{Name: "Bar", Version: "v123"})
48+
require.ErrorContains(t, err, "unsupported API schema: backend={Bar v123}")
49+
})
50+
t.Run("supported openai", func(t *testing.T) {
51+
err := c.selectTranslator(filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI})
52+
require.NoError(t, err)
53+
require.NotNil(t, c.translator)
54+
})
55+
t.Run("supported aws bedrock", func(t *testing.T) {
56+
err := c.selectTranslator(filterapi.VersionedAPISchema{Name: filterapi.APISchemaAWSBedrock})
57+
require.NoError(t, err)
58+
require.Nil(t, c.translator) // Placeholder implementation
59+
})
60+
}
61+
62+
type mockImageGenerationTracer struct {
63+
tracing.NoopImageGenerationTracer
64+
startSpanCalled bool
65+
returnedSpan tracing.ImageGenerationSpan
66+
}
67+
68+
func (m *mockImageGenerationTracer) StartSpanAndInjectHeaders(_ context.Context, _ map[string]string, headerMutation *extprocv3.HeaderMutation, _ *openaisdk.ImageGenerateParams, _ []byte) tracing.ImageGenerationSpan {
69+
m.startSpanCalled = true
70+
headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{
71+
Header: &corev3.HeaderValue{
72+
Key: "tracing-header",
73+
Value: "1",
74+
},
75+
})
76+
if m.returnedSpan != nil {
77+
return m.returnedSpan
78+
}
79+
return nil
80+
}
81+
82+
// Mock span for image generation tests
83+
type mockImageGenerationSpan struct {
84+
endSpanCalled bool
85+
errorStatus int
86+
errBody string
87+
}
88+
89+
func (m *mockImageGenerationSpan) EndSpan() {
90+
m.endSpanCalled = true
91+
}
92+
93+
func (m *mockImageGenerationSpan) EndSpanOnError(status int, body []byte) {
94+
m.errorStatus = status
95+
m.errBody = string(body)
96+
}
97+
98+
func (m *mockImageGenerationSpan) RecordResponse(resp *openaisdk.ImagesResponse) {
99+
// Mock implementation
100+
}
101+
41102
func Test_imageGenerationProcessorRouterFilter_ProcessRequestBody(t *testing.T) {
42103
t.Run("body parser error", func(t *testing.T) {
43104
p := &imageGenerationProcessorRouterFilter{}
@@ -69,83 +130,188 @@ func Test_imageGenerationProcessorRouterFilter_ProcessRequestBody(t *testing.T)
69130
require.Equal(t, "x-ai-eg-original-path", setHeaders[1].Header.Key)
70131
require.Equal(t, "/v1/images/generations", string(setHeaders[1].Header.RawValue))
71132
})
133+
134+
t.Run("span creation", func(t *testing.T) {
135+
headers := map[string]string{":path": "/v1/images/generations"}
136+
const modelKey = "x-ai-gateway-model-key"
137+
span := &mockImageGenerationSpan{}
138+
mockTracerInstance := &mockImageGenerationTracer{returnedSpan: span}
139+
140+
p := &imageGenerationProcessorRouterFilter{
141+
config: &processorConfig{modelNameHeaderKey: modelKey},
142+
requestHeaders: headers,
143+
logger: slog.Default(),
144+
tracer: mockTracerInstance,
145+
}
146+
147+
body := imageGenerationBodyFromModel(t, "dall-e-3")
148+
resp, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: body})
149+
require.NoError(t, err)
150+
require.NotNil(t, resp)
151+
require.True(t, mockTracerInstance.startSpanCalled)
152+
require.Equal(t, span, p.span)
153+
154+
// Verify headers are injected.
155+
re, ok := resp.Response.(*extprocv3.ProcessingResponse_RequestBody)
156+
require.True(t, ok)
157+
headerMutation := re.RequestBody.GetResponse().GetHeaderMutation()
158+
require.Contains(t, headerMutation.SetHeaders, &corev3.HeaderValueOption{
159+
Header: &corev3.HeaderValue{
160+
Key: "tracing-header",
161+
Value: "1",
162+
},
163+
})
164+
})
72165
}
73166

74167
func Test_imageGenerationProcessorUpstreamFilter_ProcessResponseHeaders(t *testing.T) {
75-
t.Run("ok passthrough", func(t *testing.T) {
76-
p := &imageGenerationProcessorUpstreamFilter{metrics: &mockImageGenerationMetrics{}}
77-
res, err := p.ProcessResponseHeaders(t.Context(), &corev3.HeaderMap{Headers: []*corev3.HeaderValue{}})
168+
t.Run("error translation", func(t *testing.T) {
169+
mm := &mockImageGenerationMetrics{}
170+
mt := &mockImageGenerationTranslator{t: t, expHeaders: make(map[string]string)}
171+
p := &imageGenerationProcessorUpstreamFilter{
172+
translator: mt,
173+
metrics: mm,
174+
logger: slog.Default(),
175+
}
176+
mt.retErr = errors.New("test error")
177+
_, err := p.ProcessResponseHeaders(t.Context(), nil)
178+
require.ErrorContains(t, err, "test error")
179+
mm.RequireRequestFailure(t)
180+
})
181+
t.Run("ok", func(t *testing.T) {
182+
inHeaders := &corev3.HeaderMap{
183+
Headers: []*corev3.HeaderValue{{Key: "foo", Value: "bar"}, {Key: "dog", RawValue: []byte("cat")}},
184+
}
185+
expHeaders := map[string]string{"foo": "bar", "dog": "cat"}
186+
mm := &mockImageGenerationMetrics{}
187+
mt := &mockImageGenerationTranslator{t: t, expHeaders: expHeaders}
188+
p := &imageGenerationProcessorUpstreamFilter{
189+
translator: mt,
190+
metrics: mm,
191+
logger: slog.Default(),
192+
}
193+
res, err := p.ProcessResponseHeaders(t.Context(), inHeaders)
78194
require.NoError(t, err)
79-
require.NotNil(t, res)
195+
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseHeaders).ResponseHeaders.Response
196+
require.Equal(t, mt.retHeaderMutation, commonRes.HeaderMutation)
197+
mm.RequireRequestNotCompleted(t)
80198
})
81199
}
82200

83201
func Test_imageGenerationProcessorUpstreamFilter_ProcessResponseBody(t *testing.T) {
84-
t.Run("non-2xx marks failure and returns mutations", func(t *testing.T) {
202+
t.Run("error translation", func(t *testing.T) {
85203
mm := &mockImageGenerationMetrics{}
204+
mt := &mockImageGenerationTranslator{t: t}
86205
p := &imageGenerationProcessorUpstreamFilter{
87-
metrics: mm,
88-
responseHeaders: map[string]string{":status": "500"},
206+
translator: mt,
207+
metrics: mm,
208+
logger: slog.Default(),
89209
}
90-
res, err := p.ProcessResponseBody(t.Context(), &extprocv3.HttpBody{Body: []byte("err"), EndOfStream: true})
91-
require.NoError(t, err)
92-
require.NotNil(t, res)
93-
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseBody).ResponseBody.Response
94-
require.NotNil(t, commonRes.HeaderMutation)
95-
require.NotNil(t, commonRes.BodyMutation)
210+
mt.retErr = errors.New("test error")
211+
_, err := p.ProcessResponseBody(t.Context(), &extprocv3.HttpBody{})
212+
require.ErrorContains(t, err, "test error")
96213
mm.RequireRequestFailure(t)
214+
require.Zero(t, mm.tokenUsageCount)
97215
})
98-
99-
t.Run("200 end-of-stream records success and metadata", func(t *testing.T) {
100-
inBody := &extprocv3.HttpBody{Body: []byte(`{"created":1,"data":[],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`), EndOfStream: true}
216+
t.Run("ok", func(t *testing.T) {
217+
inBody := &extprocv3.HttpBody{Body: []byte("some-body"), EndOfStream: true}
218+
expBodyMut := &extprocv3.BodyMutation{}
219+
expHeadMut := &extprocv3.HeaderMutation{}
101220
mm := &mockImageGenerationMetrics{}
221+
mt := &mockImageGenerationTranslator{
222+
t: t, expResponseBody: inBody,
223+
retBodyMutation: expBodyMut, retHeaderMutation: expHeadMut,
224+
retUsedToken: translator.LLMTokenUsage{OutputTokens: 123, InputTokens: 1},
225+
}
102226

103-
celProgInt, err := llmcostcel.NewProgram("123")
227+
celProgInt, err := llmcostcel.NewProgram("54321")
228+
require.NoError(t, err)
229+
celProgUint, err := llmcostcel.NewProgram("uint(9999)")
104230
require.NoError(t, err)
105-
106231
p := &imageGenerationProcessorUpstreamFilter{
107-
translator: translator.NewImageGenerationOpenAIToOpenAITranslator("v1", "some_model"),
232+
translator: mt,
233+
logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})),
108234
metrics: mm,
109-
logger: slog.Default(),
110235
config: &processorConfig{
111-
metadataNamespace: "ai_gateway_llm_ns",
236+
metadataNamespace: "ai_gateway_llm_ns",
237+
modelNameHeaderKey: "x-aigw-model",
112238
requestCosts: []processorConfigRequestCost{
239+
{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage"}},
113240
{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeInputToken, MetadataKey: "input_token_usage"}},
114-
{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeTotalToken, MetadataKey: "total_token_usage"}},
115-
{celProg: celProgInt, LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_int"}},
241+
{
242+
celProg: celProgInt,
243+
LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_int"},
244+
},
245+
{
246+
celProg: celProgUint,
247+
LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_uint"},
248+
},
116249
},
117250
},
118-
backendName: "some_backend",
119-
modelNameOverride: "some_model",
251+
requestHeaders: map[string]string{"x-aigw-model": "ai_gateway_llm"},
120252
responseHeaders: map[string]string{":status": "200"},
253+
backendName: "some_backend",
254+
modelNameOverride: "ai_gateway_llm",
121255
}
122256
res, err := p.ProcessResponseBody(t.Context(), inBody)
123257
require.NoError(t, err)
124-
require.NotNil(t, res)
258+
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseBody).ResponseBody.Response
259+
require.Equal(t, expBodyMut, commonRes.BodyMutation)
260+
require.Equal(t, expHeadMut, commonRes.HeaderMutation)
125261
mm.RequireRequestSuccess(t)
262+
require.Equal(t, 124, mm.tokenUsageCount) // 1 input + 123 output
126263

127264
md := res.DynamicMetadata
128265
require.NotNil(t, md)
129-
require.Equal(t, float64(0), md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["input_token_usage"].GetNumberValue())
130-
require.Equal(t, float64(0), md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["total_token_usage"].GetNumberValue())
131-
require.Equal(t, float64(123), md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["cel_int"].GetNumberValue())
266+
require.Equal(t, float64(123), md.Fields["ai_gateway_llm_ns"].
267+
GetStructValue().Fields["output_token_usage"].GetNumberValue())
268+
require.Equal(t, float64(1), md.Fields["ai_gateway_llm_ns"].
269+
GetStructValue().Fields["input_token_usage"].GetNumberValue())
270+
require.Equal(t, float64(54321), md.Fields["ai_gateway_llm_ns"].
271+
GetStructValue().Fields["cel_int"].GetNumberValue())
272+
require.Equal(t, float64(9999), md.Fields["ai_gateway_llm_ns"].
273+
GetStructValue().Fields["cel_uint"].GetNumberValue())
274+
require.Equal(t, "ai_gateway_llm", md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["model_name_override"].GetStringValue())
132275
require.Equal(t, "some_backend", md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["backend_name"].GetStringValue())
133-
require.Equal(t, "some_model", md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["model_name_override"].GetStringValue())
276+
})
277+
278+
// Verify we record failure for non-2xx responses and do it exactly once (defer suppressed).
279+
t.Run("non-2xx status failure once", func(t *testing.T) {
280+
inBody := &extprocv3.HttpBody{Body: []byte("error-body"), EndOfStream: true}
281+
expHeadMut := &extprocv3.HeaderMutation{}
282+
expBodyMut := &extprocv3.BodyMutation{}
283+
mm := &mockImageGenerationMetrics{}
284+
mt := &mockImageGenerationTranslator{t: t, expResponseBody: inBody, retHeaderMutation: expHeadMut, retBodyMutation: expBodyMut}
285+
p := &imageGenerationProcessorUpstreamFilter{
286+
translator: mt,
287+
metrics: mm,
288+
responseHeaders: map[string]string{":status": "500"},
289+
logger: slog.Default(),
290+
}
291+
res, err := p.ProcessResponseBody(t.Context(), inBody)
292+
require.NoError(t, err)
293+
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseBody).ResponseBody.Response
294+
require.Equal(t, expBodyMut, commonRes.BodyMutation)
295+
require.Equal(t, expHeadMut, commonRes.HeaderMutation)
296+
mm.RequireRequestFailure(t)
134297
})
135298
}
136299

137300
func Test_imageGenerationProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) {
138301
t.Run("ok with auth handler and header mutator", func(t *testing.T) {
139302
headers := map[string]string{":path": "/v1/images/generations", "x-model": "dall-e-3"}
140303
mm := &mockImageGenerationMetrics{}
304+
body := &openaisdk.ImageGenerateParams{Model: openaisdk.ImageModel("dall-e-3"), Prompt: "a cat"}
305+
mt := &mockImageGenerationTranslator{t: t, expRequestBody: body}
141306
p := &imageGenerationProcessorUpstreamFilter{
142307
config: &processorConfig{modelNameHeaderKey: "x-model"},
143308
requestHeaders: headers,
144309
logger: slog.Default(),
145310
metrics: mm,
146311
originalRequestBodyRaw: imageGenerationBodyFromModel(t, "dall-e-3"),
147-
originalRequestBody: &openaisdk.ImageGenerateParams{Model: openaisdk.ImageModel("dall-e-3"), Prompt: "a cat"},
312+
originalRequestBody: body,
148313
handler: &mockBackendAuthHandler{},
314+
translator: mt,
149315
}
150316
resp, err := p.ProcessRequestHeaders(t.Context(), nil)
151317
require.NoError(t, err)
@@ -210,6 +376,52 @@ func TestImageGeneration_ParseBody(t *testing.T) {
210376
})
211377
}
212378

379+
// Mock translator for image generation tests
380+
type mockImageGenerationTranslator struct {
381+
t *testing.T
382+
expRequestBody *openaisdk.ImageGenerateParams
383+
expResponseBody *extprocv3.HttpBody
384+
expHeaders map[string]string
385+
expForceRequestBodyMutation bool
386+
retErr error
387+
retHeaderMutation *extprocv3.HeaderMutation
388+
retBodyMutation *extprocv3.BodyMutation
389+
retUsedToken translator.LLMTokenUsage
390+
retImageMetadata translator.ImageGenerationMetadata
391+
}
392+
393+
func (m *mockImageGenerationTranslator) RequestBody(original []byte, req *openaisdk.ImageGenerateParams, forceBodyMutation bool) (*extprocv3.HeaderMutation, *extprocv3.BodyMutation, error) {
394+
if m.expRequestBody != nil {
395+
require.Equal(m.t, m.expRequestBody, req)
396+
}
397+
if m.expForceRequestBodyMutation {
398+
require.True(m.t, forceBodyMutation)
399+
}
400+
return m.retHeaderMutation, m.retBodyMutation, m.retErr
401+
}
402+
403+
func (m *mockImageGenerationTranslator) ResponseHeaders(headers map[string]string) (*extprocv3.HeaderMutation, error) {
404+
if m.expHeaders != nil {
405+
for k, v := range m.expHeaders {
406+
require.Equal(m.t, v, headers[k])
407+
}
408+
}
409+
return m.retHeaderMutation, m.retErr
410+
}
411+
412+
func (m *mockImageGenerationTranslator) ResponseBody(headers map[string]string, body io.Reader, endOfStream bool) (*extprocv3.HeaderMutation, *extprocv3.BodyMutation, translator.LLMTokenUsage, translator.ImageGenerationMetadata, error) {
413+
if m.expResponseBody != nil {
414+
bodyBytes, _ := io.ReadAll(body)
415+
require.Equal(m.t, m.expResponseBody.Body, bodyBytes)
416+
require.Equal(m.t, m.expResponseBody.EndOfStream, endOfStream)
417+
}
418+
return m.retHeaderMutation, m.retBodyMutation, m.retUsedToken, m.retImageMetadata, m.retErr
419+
}
420+
421+
func (m *mockImageGenerationTranslator) ResponseError(headers map[string]string, body io.Reader) (*extprocv3.HeaderMutation, *extprocv3.BodyMutation, error) {
422+
return m.retHeaderMutation, m.retBodyMutation, m.retErr
423+
}
424+
213425
// imageGenerationBodyFromModel returns a minimal valid image generation request for tests.
214426
func imageGenerationBodyFromModel(t *testing.T, model string) []byte {
215427
t.Helper()

internal/extproc/mocks_test.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,12 +407,18 @@ type mockImageGenerationMetrics struct {
407407
}
408408

409409
func (m *mockImageGenerationMetrics) StartRequest(map[string]string) {}
410+
func (m *mockImageGenerationMetrics) SetRequestModel(requestModel string) {
411+
m.model = requestModel
412+
}
413+
func (m *mockImageGenerationMetrics) SetResponseModel(responseModel string) {
414+
m.model = responseModel
415+
}
410416
func (m *mockImageGenerationMetrics) SetModel(requestModel, responseModel string) {
411417
m.model = responseModel
412418
}
413419
func (m *mockImageGenerationMetrics) SetBackend(b *filterapi.Backend) { m.backend = b.Name }
414-
func (m *mockImageGenerationMetrics) RecordTokenUsage(_ context.Context, _ uint32, _ uint32, _ map[string]string) {
415-
m.tokenUsageCount++
420+
func (m *mockImageGenerationMetrics) RecordTokenUsage(_ context.Context, input, output uint32, _ map[string]string) {
421+
m.tokenUsageCount += int(input + output)
416422
}
417423
func (m *mockImageGenerationMetrics) RecordRequestCompletion(_ context.Context, success bool, _ map[string]string) {
418424
if success {

0 commit comments

Comments
 (0)