diff --git a/internal/extproc/mocks_test.go b/internal/extproc/mocks_test.go index 053676287..167d5e64f 100644 --- a/internal/extproc/mocks_test.go +++ b/internal/extproc/mocks_test.go @@ -305,3 +305,13 @@ type mockBackendAuthHandler struct{} func (m *mockBackendAuthHandler) Do(context.Context, map[string]string, []byte) ([]internalapi.Header, error) { return []internalapi.Header{{"foo", "mock-auth-handler"}}, nil } + +// mockBackendAuthHandlerError implements [filterapi.BackendAuthHandler] for testing auth errors. +type mockBackendAuthHandlerError struct { + err error +} + +// Do implements [filterapi.BackendAuthHandler.Do]. +func (m *mockBackendAuthHandlerError) Do(context.Context, map[string]string, []byte) ([]internalapi.Header, error) { + return nil, m.err +} diff --git a/internal/extproc/processor_impl.go b/internal/extproc/processor_impl.go index 0c9e85993..e45dd323b 100644 --- a/internal/extproc/processor_impl.go +++ b/internal/extproc/processor_impl.go @@ -15,6 +15,8 @@ import ( corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extprocv3http "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3" extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3" + "google.golang.org/grpc/codes" "google.golang.org/protobuf/types/known/structpb" "github.com/envoyproxy/ai-gateway/internal/bodymutator" @@ -165,7 +167,19 @@ func (r *routerProcessor[ReqT, RespT, RespChunkT, EndpointSpecT]) ProcessRespons func (r *routerProcessor[ReqT, RespT, RespChunkT, EndpointSpecT]) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) { originalModel, body, stream, mutatedOriginalBody, err := r.eh.ParseBody(rawBody.Body, len(r.config.RequestCosts) > 0) if err != nil { - return nil, fmt.Errorf("failed to parse request body: %w", err) + // Return an immediate error response instead of nil + errorMsg := fmt.Sprintf("failed to parse request body: %v", err) + + // 400 and 422 might be both reasonable here + return &extprocv3.ProcessingResponse{ + Response: &extprocv3.ProcessingResponse_ImmediateResponse{ + ImmediateResponse: &extprocv3.ImmediateResponse{ + Status: &typev3.HttpStatus{Code: typev3.StatusCode(400)}, + Body: []byte(errorMsg), + GrpcStatus: &extprocv3.GrpcStatus{Status: uint32(codes.InvalidArgument)}, + }, + }, + }, fmt.Errorf("failed to parse request body: %w", err) } if mutatedOriginalBody != nil { r.originalRequestBodyRaw = mutatedOriginalBody @@ -243,7 +257,17 @@ func (u *upstreamProcessor[ReqT, RespT, RespChunkT, EndpointSpecT]) ProcessReque forceBodyMutation := u.onRetry() || u.parent.forceBodyMutation newHeaders, newBody, err := u.translator.RequestBody(u.parent.originalRequestBodyRaw, u.parent.originalRequestBody, forceBodyMutation) if err != nil { - return nil, fmt.Errorf("failed to transform request: %w", err) + // Return an immediate error response instead of nil + errorMsg := fmt.Sprintf("failed to transform request: %v", err) + return &extprocv3.ProcessingResponse{ + Response: &extprocv3.ProcessingResponse_ImmediateResponse{ + ImmediateResponse: &extprocv3.ImmediateResponse{ + Status: &typev3.HttpStatus{Code: typev3.StatusCode(422)}, + Body: []byte(errorMsg), + GrpcStatus: &extprocv3.GrpcStatus{Status: uint32(codes.InvalidArgument)}, + }, + }, + }, fmt.Errorf("failed to transform request: %w", err) } headerMutation, bodyMutation := mutationsFromTranslationResult(newHeaders, newBody) diff --git a/internal/extproc/processor_impl_test.go b/internal/extproc/processor_impl_test.go index 66513ccdb..96dc988e3 100644 --- a/internal/extproc/processor_impl_test.go +++ b/internal/extproc/processor_impl_test.go @@ -16,6 +16,7 @@ import ( corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extprocv3http "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3" extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/propagation" "google.golang.org/protobuf/types/known/structpb" @@ -90,8 +91,16 @@ func Test_chatCompletionProcessorRouterFilter_ProcessRequestBody(t *testing.T) { tracer: tracing.NoopTracer[openai.ChatCompletionRequest, openai.ChatCompletionResponse, openai.ChatCompletionResponseChunk]{}, config: &filterapi.RuntimeConfig{}, } - _, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: []byte("nonjson")}) - require.ErrorContains(t, err, "invalid character 'o' in literal null") + resp, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: []byte("nonjson")}) + require.Error(t, err, "Should return an error along with immediate response") + require.Contains(t, err.Error(), "failed to parse request body") + require.NotNil(t, resp, "Response should not be nil") + + // Verify it's an immediate response with the correct error message + immediateResp, ok := resp.Response.(*extprocv3.ProcessingResponse_ImmediateResponse) + require.True(t, ok, "Response should be an immediate response") + require.Equal(t, typev3.StatusCode(400), immediateResp.ImmediateResponse.Status.Code) + require.Contains(t, string(immediateResp.ImmediateResponse.Body), "failed to parse request body") }) t.Run("ok", func(t *testing.T) { @@ -442,14 +451,58 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing metrics: mm, translator: tr, } - _, err := p.ProcessRequestHeaders(t.Context(), nil) - require.ErrorContains(t, err, "failed to transform request: test error") + resp, err := p.ProcessRequestHeaders(t.Context(), nil) + require.Error(t, err, "Should return an error along with immediate response") + require.Contains(t, err.Error(), "failed to transform request: test error") + require.NotNil(t, resp, "Response should not be nil") + + // Verify it's an immediate response with the correct error message + immediateResp, ok := resp.Response.(*extprocv3.ProcessingResponse_ImmediateResponse) + require.True(t, ok, "Response should be an immediate response") + require.Equal(t, typev3.StatusCode(422), immediateResp.ImmediateResponse.Status.Code) + require.Contains(t, string(immediateResp.ImmediateResponse.Body), "failed to transform request") + require.Contains(t, string(immediateResp.ImmediateResponse.Body), "test error") + mm.RequireRequestFailure(t) require.Zero(t, mm.inputTokenCount) // Verify models were set even though processing failed require.Equal(t, "some-model", mm.originalModel) require.Equal(t, "some-model", mm.requestModel) }) + t.Run("auth handler error", func(t *testing.T) { + headers := map[string]string{":path": "/foo", internalapi.ModelNameHeaderKeyDefault: "some-model"} + someBody := bodyFromModel(t, "some-model", tc.stream, nil) + var body openai.ChatCompletionRequest + require.NoError(t, json.Unmarshal(someBody, &body)) + tr := mockTranslator{t: t, expRequestBody: &body} + mm := &mockMetrics{} + // Create a mock auth handler that returns an error + authHandler := &mockBackendAuthHandlerError{err: errors.New("authentication failed")} + p := &chatCompletionProcessorUpstreamFilter{ + parent: &chatCompletionProcessorRouterFilter{ + config: &filterapi.RuntimeConfig{}, + logger: slog.Default(), + originalRequestBodyRaw: someBody, + originalRequestBody: &body, + originalModel: "some-model", + stream: tc.stream, + }, + requestHeaders: headers, + metrics: mm, + translator: tr, + handler: authHandler, + } + resp, err := p.ProcessRequestHeaders(t.Context(), nil) + require.Error(t, err, "Should return an error") + require.Contains(t, err.Error(), "failed to do auth request: authentication failed") + require.Nil(t, resp, "Response should be nil when auth request fails") + + mm.RequireRequestFailure(t) + require.Zero(t, mm.inputTokenCount) + // Verify models were set even though authentication failed + require.Equal(t, "some-model", mm.originalModel) + require.Equal(t, "some-model", mm.requestModel) + }) t.Run("ok", func(t *testing.T) { someBody := bodyFromModel(t, "some-model", tc.stream, nil) headers := map[string]string{":path": "/foo", internalapi.ModelNameHeaderKeyDefault: "some-model"}