Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions internal/extproc/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,13 @@ type mockBackendAuthHandler struct{}
func (m *mockBackendAuthHandler) Do(context.Context, map[string]string, []byte) ([]internalapi.Header, error) {
return []internalapi.Header{{"foo", "mock-auth-handler"}}, nil
}

// mockBackendAuthHandlerError implements [filterapi.BackendAuthHandler] for testing auth errors.
type mockBackendAuthHandlerError struct {
err error
}

// Do implements [filterapi.BackendAuthHandler.Do].
func (m *mockBackendAuthHandlerError) Do(context.Context, map[string]string, []byte) ([]internalapi.Header, error) {
return nil, m.err
}
40 changes: 37 additions & 3 deletions internal/extproc/processor_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extprocv3http "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3"
extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/types/known/structpb"

"github.com/envoyproxy/ai-gateway/internal/bodymutator"
Expand Down Expand Up @@ -165,7 +167,19 @@ func (r *routerProcessor[ReqT, RespT, RespChunkT, EndpointSpecT]) ProcessRespons
func (r *routerProcessor[ReqT, RespT, RespChunkT, EndpointSpecT]) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
originalModel, body, stream, mutatedOriginalBody, err := r.eh.ParseBody(rawBody.Body, len(r.config.RequestCosts) > 0)
if err != nil {
return nil, fmt.Errorf("failed to parse request body: %w", err)
// Return an immediate error response instead of nil
errorMsg := fmt.Sprintf("failed to parse request body: %v", err)

// 400 and 422 might be both reasonable here
return &extprocv3.ProcessingResponse{
Response: &extprocv3.ProcessingResponse_ImmediateResponse{
ImmediateResponse: &extprocv3.ImmediateResponse{
Status: &typev3.HttpStatus{Code: typev3.StatusCode(400)},
Body: []byte(errorMsg),
GrpcStatus: &extprocv3.GrpcStatus{Status: uint32(codes.InvalidArgument)},
},
},
}, fmt.Errorf("failed to parse request body: %w", err)
}
if mutatedOriginalBody != nil {
r.originalRequestBodyRaw = mutatedOriginalBody
Expand Down Expand Up @@ -243,7 +257,17 @@ func (u *upstreamProcessor[ReqT, RespT, RespChunkT, EndpointSpecT]) ProcessReque
forceBodyMutation := u.onRetry() || u.parent.forceBodyMutation
newHeaders, newBody, err := u.translator.RequestBody(u.parent.originalRequestBodyRaw, u.parent.originalRequestBody, forceBodyMutation)
if err != nil {
return nil, fmt.Errorf("failed to transform request: %w", err)
// Return an immediate error response instead of nil
errorMsg := fmt.Sprintf("failed to transform request: %v", err)
return &extprocv3.ProcessingResponse{
Response: &extprocv3.ProcessingResponse_ImmediateResponse{
ImmediateResponse: &extprocv3.ImmediateResponse{
Status: &typev3.HttpStatus{Code: typev3.StatusCode(422)},
Body: []byte(errorMsg),
GrpcStatus: &extprocv3.GrpcStatus{Status: uint32(codes.InvalidArgument)},
},
},
}, fmt.Errorf("failed to transform request: %w", err)
}

headerMutation, bodyMutation := mutationsFromTranslationResult(newHeaders, newBody)
Expand Down Expand Up @@ -280,7 +304,17 @@ func (u *upstreamProcessor[ReqT, RespT, RespChunkT, EndpointSpecT]) ProcessReque
var hdrs []internalapi.Header
hdrs, err = h.Do(ctx, u.requestHeaders, bodyMutation.GetBody())
if err != nil {
return nil, fmt.Errorf("failed to do auth request: %w", err)
// Return an immediate error response instead of nil
errorMsg := fmt.Sprintf("failed to do auth request: %v", err)
return &extprocv3.ProcessingResponse{
Response: &extprocv3.ProcessingResponse_ImmediateResponse{
ImmediateResponse: &extprocv3.ImmediateResponse{
Status: &typev3.HttpStatus{Code: typev3.StatusCode(401)},
Body: []byte(errorMsg),
GrpcStatus: &extprocv3.GrpcStatus{Status: uint32(codes.Unauthenticated)},
},
},
}, fmt.Errorf("failed to do auth request: %w", err)
}
for _, h := range hdrs {
headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{
Expand Down
68 changes: 64 additions & 4 deletions internal/extproc/processor_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extprocv3http "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3"
extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/propagation"
"google.golang.org/protobuf/types/known/structpb"
Expand Down Expand Up @@ -90,8 +91,16 @@ func Test_chatCompletionProcessorRouterFilter_ProcessRequestBody(t *testing.T) {
tracer: tracing.NoopTracer[openai.ChatCompletionRequest, openai.ChatCompletionResponse, openai.ChatCompletionResponseChunk]{},
config: &filterapi.RuntimeConfig{},
}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: []byte("nonjson")})
require.ErrorContains(t, err, "invalid character 'o' in literal null")
resp, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: []byte("nonjson")})
require.Error(t, err, "Should return an error along with immediate response")
require.Contains(t, err.Error(), "failed to parse request body")
require.NotNil(t, resp, "Response should not be nil")

// Verify it's an immediate response with the correct error message
immediateResp, ok := resp.Response.(*extprocv3.ProcessingResponse_ImmediateResponse)
require.True(t, ok, "Response should be an immediate response")
require.Equal(t, typev3.StatusCode(400), immediateResp.ImmediateResponse.Status.Code)
require.Contains(t, string(immediateResp.ImmediateResponse.Body), "failed to parse request body")
})

t.Run("ok", func(t *testing.T) {
Expand Down Expand Up @@ -442,14 +451,65 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing
metrics: mm,
translator: tr,
}
_, err := p.ProcessRequestHeaders(t.Context(), nil)
require.ErrorContains(t, err, "failed to transform request: test error")
resp, err := p.ProcessRequestHeaders(t.Context(), nil)
require.Error(t, err, "Should return an error along with immediate response")
require.Contains(t, err.Error(), "failed to transform request: test error")
require.NotNil(t, resp, "Response should not be nil")

// Verify it's an immediate response with the correct error message
immediateResp, ok := resp.Response.(*extprocv3.ProcessingResponse_ImmediateResponse)
require.True(t, ok, "Response should be an immediate response")
require.Equal(t, typev3.StatusCode(422), immediateResp.ImmediateResponse.Status.Code)
require.Contains(t, string(immediateResp.ImmediateResponse.Body), "failed to transform request")
require.Contains(t, string(immediateResp.ImmediateResponse.Body), "test error")

mm.RequireRequestFailure(t)
require.Zero(t, mm.inputTokenCount)
// Verify models were set even though processing failed
require.Equal(t, "some-model", mm.originalModel)
require.Equal(t, "some-model", mm.requestModel)
})
t.Run("auth handler error", func(t *testing.T) {
headers := map[string]string{":path": "/foo", internalapi.ModelNameHeaderKeyDefault: "some-model"}
someBody := bodyFromModel(t, "some-model", tc.stream, nil)
var body openai.ChatCompletionRequest
require.NoError(t, json.Unmarshal(someBody, &body))
tr := mockTranslator{t: t, expRequestBody: &body}
mm := &mockMetrics{}
// Create a mock auth handler that returns an error
authHandler := &mockBackendAuthHandlerError{err: errors.New("authentication failed")}
p := &chatCompletionProcessorUpstreamFilter{
parent: &chatCompletionProcessorRouterFilter{
config: &filterapi.RuntimeConfig{},
logger: slog.Default(),
originalRequestBodyRaw: someBody,
originalRequestBody: &body,
originalModel: "some-model",
stream: tc.stream,
},
requestHeaders: headers,
metrics: mm,
translator: tr,
handler: authHandler,
}
resp, err := p.ProcessRequestHeaders(t.Context(), nil)
require.Error(t, err, "Should return an error along with immediate response")
require.Contains(t, err.Error(), "failed to do auth request: authentication failed")
require.NotNil(t, resp, "Response should not be nil")

// Verify it's an immediate response with the correct error message
immediateResp, ok := resp.Response.(*extprocv3.ProcessingResponse_ImmediateResponse)
require.True(t, ok, "Response should be an immediate response")
require.Equal(t, typev3.StatusCode(401), immediateResp.ImmediateResponse.Status.Code)
require.Contains(t, string(immediateResp.ImmediateResponse.Body), "failed to do auth request")
require.Contains(t, string(immediateResp.ImmediateResponse.Body), "authentication failed")

mm.RequireRequestFailure(t)
require.Zero(t, mm.inputTokenCount)
// Verify models were set even though authentication failed
require.Equal(t, "some-model", mm.originalModel)
require.Equal(t, "some-model", mm.requestModel)
})
t.Run("ok", func(t *testing.T) {
someBody := bodyFromModel(t, "some-model", tc.stream, nil)
headers := map[string]string{":path": "/foo", internalapi.ModelNameHeaderKeyDefault: "some-model"}
Expand Down