|
6 | 6 | package extproc |
7 | 7 |
|
8 | 8 | import ( |
| 9 | + "context" |
9 | 10 | "encoding/json" |
| 11 | + "errors" |
| 12 | + "io" |
10 | 13 | "log/slog" |
11 | 14 | "testing" |
12 | 15 |
|
@@ -38,6 +41,64 @@ func TestImageGeneration_Schema(t *testing.T) { |
38 | 41 | }) |
39 | 42 | } |
40 | 43 |
|
| 44 | +func Test_imageGenerationProcessorUpstreamFilter_SelectTranslator(t *testing.T) { |
| 45 | + c := &imageGenerationProcessorUpstreamFilter{} |
| 46 | + t.Run("unsupported", func(t *testing.T) { |
| 47 | + err := c.selectTranslator(filterapi.VersionedAPISchema{Name: "Bar", Version: "v123"}) |
| 48 | + require.ErrorContains(t, err, "unsupported API schema: backend={Bar v123}") |
| 49 | + }) |
| 50 | + t.Run("supported openai", func(t *testing.T) { |
| 51 | + err := c.selectTranslator(filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI}) |
| 52 | + require.NoError(t, err) |
| 53 | + require.NotNil(t, c.translator) |
| 54 | + }) |
| 55 | + t.Run("supported aws bedrock", func(t *testing.T) { |
| 56 | + err := c.selectTranslator(filterapi.VersionedAPISchema{Name: filterapi.APISchemaAWSBedrock}) |
| 57 | + require.NoError(t, err) |
| 58 | + require.Nil(t, c.translator) // Placeholder implementation |
| 59 | + }) |
| 60 | +} |
| 61 | + |
| 62 | +type mockImageGenerationTracer struct { |
| 63 | + tracing.NoopImageGenerationTracer |
| 64 | + startSpanCalled bool |
| 65 | + returnedSpan tracing.ImageGenerationSpan |
| 66 | +} |
| 67 | + |
| 68 | +func (m *mockImageGenerationTracer) StartSpanAndInjectHeaders(_ context.Context, _ map[string]string, headerMutation *extprocv3.HeaderMutation, _ *openaisdk.ImageGenerateParams, _ []byte) tracing.ImageGenerationSpan { |
| 69 | + m.startSpanCalled = true |
| 70 | + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ |
| 71 | + Header: &corev3.HeaderValue{ |
| 72 | + Key: "tracing-header", |
| 73 | + Value: "1", |
| 74 | + }, |
| 75 | + }) |
| 76 | + if m.returnedSpan != nil { |
| 77 | + return m.returnedSpan |
| 78 | + } |
| 79 | + return nil |
| 80 | +} |
| 81 | + |
// mockImageGenerationSpan is a span test double that remembers how it was
// ended so tests can inspect the outcome afterwards.
type mockImageGenerationSpan struct {
	// endSpanCalled is set to true by EndSpan.
	endSpanCalled bool
	// errorStatus holds the status passed to EndSpanOnError.
	errorStatus int
	// errBody holds the body passed to EndSpanOnError, converted to a string.
	errBody string
}
| 88 | + |
| 89 | +func (m *mockImageGenerationSpan) EndSpan() { |
| 90 | + m.endSpanCalled = true |
| 91 | +} |
| 92 | + |
| 93 | +func (m *mockImageGenerationSpan) EndSpanOnError(status int, body []byte) { |
| 94 | + m.errorStatus = status |
| 95 | + m.errBody = string(body) |
| 96 | +} |
| 97 | + |
| 98 | +func (m *mockImageGenerationSpan) RecordResponse(resp *openaisdk.ImagesResponse) { |
| 99 | + // Mock implementation |
| 100 | +} |
| 101 | + |
41 | 102 | func Test_imageGenerationProcessorRouterFilter_ProcessRequestBody(t *testing.T) { |
42 | 103 | t.Run("body parser error", func(t *testing.T) { |
43 | 104 | p := &imageGenerationProcessorRouterFilter{} |
@@ -69,83 +130,188 @@ func Test_imageGenerationProcessorRouterFilter_ProcessRequestBody(t *testing.T) |
69 | 130 | require.Equal(t, "x-ai-eg-original-path", setHeaders[1].Header.Key) |
70 | 131 | require.Equal(t, "/v1/images/generations", string(setHeaders[1].Header.RawValue)) |
71 | 132 | }) |
| 133 | + |
| 134 | + t.Run("span creation", func(t *testing.T) { |
| 135 | + headers := map[string]string{":path": "/v1/images/generations"} |
| 136 | + const modelKey = "x-ai-gateway-model-key" |
| 137 | + span := &mockImageGenerationSpan{} |
| 138 | + mockTracerInstance := &mockImageGenerationTracer{returnedSpan: span} |
| 139 | + |
| 140 | + p := &imageGenerationProcessorRouterFilter{ |
| 141 | + config: &processorConfig{modelNameHeaderKey: modelKey}, |
| 142 | + requestHeaders: headers, |
| 143 | + logger: slog.Default(), |
| 144 | + tracer: mockTracerInstance, |
| 145 | + } |
| 146 | + |
| 147 | + body := imageGenerationBodyFromModel(t, "dall-e-3") |
| 148 | + resp, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: body}) |
| 149 | + require.NoError(t, err) |
| 150 | + require.NotNil(t, resp) |
| 151 | + require.True(t, mockTracerInstance.startSpanCalled) |
| 152 | + require.Equal(t, span, p.span) |
| 153 | + |
| 154 | + // Verify headers are injected. |
| 155 | + re, ok := resp.Response.(*extprocv3.ProcessingResponse_RequestBody) |
| 156 | + require.True(t, ok) |
| 157 | + headerMutation := re.RequestBody.GetResponse().GetHeaderMutation() |
| 158 | + require.Contains(t, headerMutation.SetHeaders, &corev3.HeaderValueOption{ |
| 159 | + Header: &corev3.HeaderValue{ |
| 160 | + Key: "tracing-header", |
| 161 | + Value: "1", |
| 162 | + }, |
| 163 | + }) |
| 164 | + }) |
72 | 165 | } |
73 | 166 |
|
74 | 167 | func Test_imageGenerationProcessorUpstreamFilter_ProcessResponseHeaders(t *testing.T) { |
75 | | - t.Run("ok passthrough", func(t *testing.T) { |
76 | | - p := &imageGenerationProcessorUpstreamFilter{metrics: &mockImageGenerationMetrics{}} |
77 | | - res, err := p.ProcessResponseHeaders(t.Context(), &corev3.HeaderMap{Headers: []*corev3.HeaderValue{}}) |
| 168 | + t.Run("error translation", func(t *testing.T) { |
| 169 | + mm := &mockImageGenerationMetrics{} |
| 170 | + mt := &mockImageGenerationTranslator{t: t, expHeaders: make(map[string]string)} |
| 171 | + p := &imageGenerationProcessorUpstreamFilter{ |
| 172 | + translator: mt, |
| 173 | + metrics: mm, |
| 174 | + logger: slog.Default(), |
| 175 | + } |
| 176 | + mt.retErr = errors.New("test error") |
| 177 | + _, err := p.ProcessResponseHeaders(t.Context(), nil) |
| 178 | + require.ErrorContains(t, err, "test error") |
| 179 | + mm.RequireRequestFailure(t) |
| 180 | + }) |
| 181 | + t.Run("ok", func(t *testing.T) { |
| 182 | + inHeaders := &corev3.HeaderMap{ |
| 183 | + Headers: []*corev3.HeaderValue{{Key: "foo", Value: "bar"}, {Key: "dog", RawValue: []byte("cat")}}, |
| 184 | + } |
| 185 | + expHeaders := map[string]string{"foo": "bar", "dog": "cat"} |
| 186 | + mm := &mockImageGenerationMetrics{} |
| 187 | + mt := &mockImageGenerationTranslator{t: t, expHeaders: expHeaders} |
| 188 | + p := &imageGenerationProcessorUpstreamFilter{ |
| 189 | + translator: mt, |
| 190 | + metrics: mm, |
| 191 | + logger: slog.Default(), |
| 192 | + } |
| 193 | + res, err := p.ProcessResponseHeaders(t.Context(), inHeaders) |
78 | 194 | require.NoError(t, err) |
79 | | - require.NotNil(t, res) |
| 195 | + commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseHeaders).ResponseHeaders.Response |
| 196 | + require.Equal(t, mt.retHeaderMutation, commonRes.HeaderMutation) |
| 197 | + mm.RequireRequestNotCompleted(t) |
80 | 198 | }) |
81 | 199 | } |
82 | 200 |
|
83 | 201 | func Test_imageGenerationProcessorUpstreamFilter_ProcessResponseBody(t *testing.T) { |
84 | | - t.Run("non-2xx marks failure and returns mutations", func(t *testing.T) { |
| 202 | + t.Run("error translation", func(t *testing.T) { |
85 | 203 | mm := &mockImageGenerationMetrics{} |
| 204 | + mt := &mockImageGenerationTranslator{t: t} |
86 | 205 | p := &imageGenerationProcessorUpstreamFilter{ |
87 | | - metrics: mm, |
88 | | - responseHeaders: map[string]string{":status": "500"}, |
| 206 | + translator: mt, |
| 207 | + metrics: mm, |
| 208 | + logger: slog.Default(), |
89 | 209 | } |
90 | | - res, err := p.ProcessResponseBody(t.Context(), &extprocv3.HttpBody{Body: []byte("err"), EndOfStream: true}) |
91 | | - require.NoError(t, err) |
92 | | - require.NotNil(t, res) |
93 | | - commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseBody).ResponseBody.Response |
94 | | - require.NotNil(t, commonRes.HeaderMutation) |
95 | | - require.NotNil(t, commonRes.BodyMutation) |
| 210 | + mt.retErr = errors.New("test error") |
| 211 | + _, err := p.ProcessResponseBody(t.Context(), &extprocv3.HttpBody{}) |
| 212 | + require.ErrorContains(t, err, "test error") |
96 | 213 | mm.RequireRequestFailure(t) |
| 214 | + require.Zero(t, mm.tokenUsageCount) |
97 | 215 | }) |
98 | | - |
99 | | - t.Run("200 end-of-stream records success and metadata", func(t *testing.T) { |
100 | | - inBody := &extprocv3.HttpBody{Body: []byte(`{"created":1,"data":[],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`), EndOfStream: true} |
| 216 | + t.Run("ok", func(t *testing.T) { |
| 217 | + inBody := &extprocv3.HttpBody{Body: []byte("some-body"), EndOfStream: true} |
| 218 | + expBodyMut := &extprocv3.BodyMutation{} |
| 219 | + expHeadMut := &extprocv3.HeaderMutation{} |
101 | 220 | mm := &mockImageGenerationMetrics{} |
| 221 | + mt := &mockImageGenerationTranslator{ |
| 222 | + t: t, expResponseBody: inBody, |
| 223 | + retBodyMutation: expBodyMut, retHeaderMutation: expHeadMut, |
| 224 | + retUsedToken: translator.LLMTokenUsage{OutputTokens: 123, InputTokens: 1}, |
| 225 | + } |
102 | 226 |
|
103 | | - celProgInt, err := llmcostcel.NewProgram("123") |
| 227 | + celProgInt, err := llmcostcel.NewProgram("54321") |
| 228 | + require.NoError(t, err) |
| 229 | + celProgUint, err := llmcostcel.NewProgram("uint(9999)") |
104 | 230 | require.NoError(t, err) |
105 | | - |
106 | 231 | p := &imageGenerationProcessorUpstreamFilter{ |
107 | | - translator: translator.NewImageGenerationOpenAIToOpenAITranslator("v1", "some_model"), |
| 232 | + translator: mt, |
| 233 | + logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), |
108 | 234 | metrics: mm, |
109 | | - logger: slog.Default(), |
110 | 235 | config: &processorConfig{ |
111 | | - metadataNamespace: "ai_gateway_llm_ns", |
| 236 | + metadataNamespace: "ai_gateway_llm_ns", |
| 237 | + modelNameHeaderKey: "x-aigw-model", |
112 | 238 | requestCosts: []processorConfigRequestCost{ |
| 239 | + {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage"}}, |
113 | 240 | {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeInputToken, MetadataKey: "input_token_usage"}}, |
114 | | - {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeTotalToken, MetadataKey: "total_token_usage"}}, |
115 | | - {celProg: celProgInt, LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_int"}}, |
| 241 | + { |
| 242 | + celProg: celProgInt, |
| 243 | + LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_int"}, |
| 244 | + }, |
| 245 | + { |
| 246 | + celProg: celProgUint, |
| 247 | + LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_uint"}, |
| 248 | + }, |
116 | 249 | }, |
117 | 250 | }, |
118 | | - backendName: "some_backend", |
119 | | - modelNameOverride: "some_model", |
| 251 | + requestHeaders: map[string]string{"x-aigw-model": "ai_gateway_llm"}, |
120 | 252 | responseHeaders: map[string]string{":status": "200"}, |
| 253 | + backendName: "some_backend", |
| 254 | + modelNameOverride: "ai_gateway_llm", |
121 | 255 | } |
122 | 256 | res, err := p.ProcessResponseBody(t.Context(), inBody) |
123 | 257 | require.NoError(t, err) |
124 | | - require.NotNil(t, res) |
| 258 | + commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseBody).ResponseBody.Response |
| 259 | + require.Equal(t, expBodyMut, commonRes.BodyMutation) |
| 260 | + require.Equal(t, expHeadMut, commonRes.HeaderMutation) |
125 | 261 | mm.RequireRequestSuccess(t) |
| 262 | + require.Equal(t, 124, mm.tokenUsageCount) // 1 input + 123 output |
126 | 263 |
|
127 | 264 | md := res.DynamicMetadata |
128 | 265 | require.NotNil(t, md) |
129 | | - require.Equal(t, float64(0), md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["input_token_usage"].GetNumberValue()) |
130 | | - require.Equal(t, float64(0), md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["total_token_usage"].GetNumberValue()) |
131 | | - require.Equal(t, float64(123), md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["cel_int"].GetNumberValue()) |
| 266 | + require.Equal(t, float64(123), md.Fields["ai_gateway_llm_ns"]. |
| 267 | + GetStructValue().Fields["output_token_usage"].GetNumberValue()) |
| 268 | + require.Equal(t, float64(1), md.Fields["ai_gateway_llm_ns"]. |
| 269 | + GetStructValue().Fields["input_token_usage"].GetNumberValue()) |
| 270 | + require.Equal(t, float64(54321), md.Fields["ai_gateway_llm_ns"]. |
| 271 | + GetStructValue().Fields["cel_int"].GetNumberValue()) |
| 272 | + require.Equal(t, float64(9999), md.Fields["ai_gateway_llm_ns"]. |
| 273 | + GetStructValue().Fields["cel_uint"].GetNumberValue()) |
| 274 | + require.Equal(t, "ai_gateway_llm", md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["model_name_override"].GetStringValue()) |
132 | 275 | require.Equal(t, "some_backend", md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["backend_name"].GetStringValue()) |
133 | | - require.Equal(t, "some_model", md.Fields["ai_gateway_llm_ns"].GetStructValue().Fields["model_name_override"].GetStringValue()) |
| 276 | + }) |
| 277 | + |
| 278 | + // Verify we record failure for non-2xx responses and do it exactly once (defer suppressed). |
| 279 | + t.Run("non-2xx status failure once", func(t *testing.T) { |
| 280 | + inBody := &extprocv3.HttpBody{Body: []byte("error-body"), EndOfStream: true} |
| 281 | + expHeadMut := &extprocv3.HeaderMutation{} |
| 282 | + expBodyMut := &extprocv3.BodyMutation{} |
| 283 | + mm := &mockImageGenerationMetrics{} |
| 284 | + mt := &mockImageGenerationTranslator{t: t, expResponseBody: inBody, retHeaderMutation: expHeadMut, retBodyMutation: expBodyMut} |
| 285 | + p := &imageGenerationProcessorUpstreamFilter{ |
| 286 | + translator: mt, |
| 287 | + metrics: mm, |
| 288 | + responseHeaders: map[string]string{":status": "500"}, |
| 289 | + logger: slog.Default(), |
| 290 | + } |
| 291 | + res, err := p.ProcessResponseBody(t.Context(), inBody) |
| 292 | + require.NoError(t, err) |
| 293 | + commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseBody).ResponseBody.Response |
| 294 | + require.Equal(t, expBodyMut, commonRes.BodyMutation) |
| 295 | + require.Equal(t, expHeadMut, commonRes.HeaderMutation) |
| 296 | + mm.RequireRequestFailure(t) |
134 | 297 | }) |
135 | 298 | } |
136 | 299 |
|
137 | 300 | func Test_imageGenerationProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) { |
138 | 301 | t.Run("ok with auth handler and header mutator", func(t *testing.T) { |
139 | 302 | headers := map[string]string{":path": "/v1/images/generations", "x-model": "dall-e-3"} |
140 | 303 | mm := &mockImageGenerationMetrics{} |
| 304 | + body := &openaisdk.ImageGenerateParams{Model: openaisdk.ImageModel("dall-e-3"), Prompt: "a cat"} |
| 305 | + mt := &mockImageGenerationTranslator{t: t, expRequestBody: body} |
141 | 306 | p := &imageGenerationProcessorUpstreamFilter{ |
142 | 307 | config: &processorConfig{modelNameHeaderKey: "x-model"}, |
143 | 308 | requestHeaders: headers, |
144 | 309 | logger: slog.Default(), |
145 | 310 | metrics: mm, |
146 | 311 | originalRequestBodyRaw: imageGenerationBodyFromModel(t, "dall-e-3"), |
147 | | - originalRequestBody: &openaisdk.ImageGenerateParams{Model: openaisdk.ImageModel("dall-e-3"), Prompt: "a cat"}, |
| 312 | + originalRequestBody: body, |
148 | 313 | handler: &mockBackendAuthHandler{}, |
| 314 | + translator: mt, |
149 | 315 | } |
150 | 316 | resp, err := p.ProcessRequestHeaders(t.Context(), nil) |
151 | 317 | require.NoError(t, err) |
@@ -210,6 +376,52 @@ func TestImageGeneration_ParseBody(t *testing.T) { |
210 | 376 | }) |
211 | 377 | } |
212 | 378 |
|
| 379 | +// Mock translator for image generation tests |
| 380 | +type mockImageGenerationTranslator struct { |
| 381 | + t *testing.T |
| 382 | + expRequestBody *openaisdk.ImageGenerateParams |
| 383 | + expResponseBody *extprocv3.HttpBody |
| 384 | + expHeaders map[string]string |
| 385 | + expForceRequestBodyMutation bool |
| 386 | + retErr error |
| 387 | + retHeaderMutation *extprocv3.HeaderMutation |
| 388 | + retBodyMutation *extprocv3.BodyMutation |
| 389 | + retUsedToken translator.LLMTokenUsage |
| 390 | + retImageMetadata translator.ImageGenerationMetadata |
| 391 | +} |
| 392 | + |
| 393 | +func (m *mockImageGenerationTranslator) RequestBody(original []byte, req *openaisdk.ImageGenerateParams, forceBodyMutation bool) (*extprocv3.HeaderMutation, *extprocv3.BodyMutation, error) { |
| 394 | + if m.expRequestBody != nil { |
| 395 | + require.Equal(m.t, m.expRequestBody, req) |
| 396 | + } |
| 397 | + if m.expForceRequestBodyMutation { |
| 398 | + require.True(m.t, forceBodyMutation) |
| 399 | + } |
| 400 | + return m.retHeaderMutation, m.retBodyMutation, m.retErr |
| 401 | +} |
| 402 | + |
| 403 | +func (m *mockImageGenerationTranslator) ResponseHeaders(headers map[string]string) (*extprocv3.HeaderMutation, error) { |
| 404 | + if m.expHeaders != nil { |
| 405 | + for k, v := range m.expHeaders { |
| 406 | + require.Equal(m.t, v, headers[k]) |
| 407 | + } |
| 408 | + } |
| 409 | + return m.retHeaderMutation, m.retErr |
| 410 | +} |
| 411 | + |
| 412 | +func (m *mockImageGenerationTranslator) ResponseBody(headers map[string]string, body io.Reader, endOfStream bool) (*extprocv3.HeaderMutation, *extprocv3.BodyMutation, translator.LLMTokenUsage, translator.ImageGenerationMetadata, error) { |
| 413 | + if m.expResponseBody != nil { |
| 414 | + bodyBytes, _ := io.ReadAll(body) |
| 415 | + require.Equal(m.t, m.expResponseBody.Body, bodyBytes) |
| 416 | + require.Equal(m.t, m.expResponseBody.EndOfStream, endOfStream) |
| 417 | + } |
| 418 | + return m.retHeaderMutation, m.retBodyMutation, m.retUsedToken, m.retImageMetadata, m.retErr |
| 419 | +} |
| 420 | + |
| 421 | +func (m *mockImageGenerationTranslator) ResponseError(headers map[string]string, body io.Reader) (*extprocv3.HeaderMutation, *extprocv3.BodyMutation, error) { |
| 422 | + return m.retHeaderMutation, m.retBodyMutation, m.retErr |
| 423 | +} |
| 424 | + |
213 | 425 | // imageGenerationBodyFromModel returns a minimal valid image generation request for tests. |
214 | 426 | func imageGenerationBodyFromModel(t *testing.T, model string) []byte { |
215 | 427 | t.Helper() |
|
0 commit comments