Skip to content

Commit 29905d5

Browse files
yuzisunaabchoo
andauthored
fix: use internal id for routerProcessorsPerReqID map (#1344)
**Description** The fix ensures that even if multiple requests arrive with identical x-request-id headers, they will have unique internal IDs in the `routerProcessorsPerReqID` map, preventing processor overwrites and incorrect cleanup. **Related Issues/PRs (if applicable)** Fixes #1340 **Special notes for reviewers (if applicable)** - Created internal request ID header constant: Added internalReqIDHeader = "x-ai-eg-internal-req-id" - Modified request ID generation logic: - Router filter: Creates unique internal request ID by appending UUID suffix: `originalReqID + "-" + uuid.NewString()` - Upstream filter: Uses the internal request ID passed from router filter via header --------- Signed-off-by: Dan Sun <[email protected]> Co-authored-by: achoo30 <[email protected]>
1 parent b678b35 commit 29905d5

File tree

2 files changed

+95
-21
lines changed

2 files changed

+95
-21
lines changed

internal/extproc/server.go

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
2121
typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
2222
"github.com/google/cel-go/cel"
23+
"github.com/google/uuid"
2324
"google.golang.org/grpc/codes"
2425
"google.golang.org/grpc/health/grpc_health_v1"
2526
"google.golang.org/grpc/status"
@@ -45,6 +46,7 @@ type Server struct {
4546
processorFactories map[string]ProcessorFactory
4647
routerProcessorsPerReqID map[string]Processor
4748
routerProcessorsPerReqIDMutex sync.RWMutex
49+
uuidFn func() string
4850
}
4951

5052
// NewServer creates a new external processor server.
@@ -54,6 +56,7 @@ func NewServer(logger *slog.Logger, tracing tracing.Tracing) (*Server, error) {
5456
tracing: tracing,
5557
processorFactories: make(map[string]ProcessorFactory),
5658
routerProcessorsPerReqID: make(map[string]Processor),
59+
uuidFn: uuid.NewString,
5760
}
5861
return srv, nil
5962
}
@@ -131,6 +134,10 @@ func (s *Server) processorForPath(requestHeaders map[string]string, isUpstreamFi
131134
// This is used in the upstream filter level to determine the original path of the request on retry.
132135
const originalPathHeader = "x-ai-eg-original-path"
133136

137+
// internalReqIDHeader is the header used to pass the unique internal request ID to the upstream filter.
138+
// This ensures that the upstream filter uses the same unique ID as the router filter to avoid race conditions.
139+
const internalReqIDHeader = "x-ai-eg-internal-req-id"
140+
134141
// Process implements [extprocv3.ExternalProcessorServer].
135142
func (s *Server) Process(stream extprocv3.ExternalProcessor_ProcessServer) error {
136143
s.logger.Debug("handling a new stream", slog.Any("config_uuid", s.config.uuid))
@@ -145,13 +152,14 @@ func (s *Server) Process(stream extprocv3.ExternalProcessor_ProcessServer) error
145152
// to pass the request through without any processing as there would be nothing to process from AI Gateway's perspective.
146153
var p Processor = passThroughProcessor{}
147154
var isUpstreamFilter bool
148-
var reqID string
155+
var internalReqID string
156+
var originalReqID string
149157
var logger *slog.Logger
150158
defer func() {
151159
if !isUpstreamFilter {
152160
s.routerProcessorsPerReqIDMutex.Lock()
153161
defer s.routerProcessorsPerReqIDMutex.Unlock()
154-
delete(s.routerProcessorsPerReqID, reqID)
162+
delete(s.routerProcessorsPerReqID, internalReqID)
155163
}
156164
}()
157165

@@ -177,9 +185,21 @@ func (s *Server) Process(stream extprocv3.ExternalProcessor_ProcessServer) error
177185
// request, and the processor will be instantiated only once.
178186
if headers := req.GetRequestHeaders().GetHeaders(); headers != nil {
179187
headersMap := headersToMap(headers)
180-
reqID = headersMap["x-request-id"]
188+
originalReqID = headersMap["x-request-id"]
181189
// Assume that when attributes are set, this stream is for the upstream filter level.
182190
isUpstreamFilter = req.GetAttributes() != nil
191+
192+
if isUpstreamFilter {
193+
// For upstream filter, use the internal request ID passed from the router filter
194+
internalReqID = headersMap[internalReqIDHeader]
195+
if internalReqID == "" {
196+
return status.Errorf(codes.Internal, "missing internal request ID header from router filter")
197+
}
198+
} else {
199+
// For router filter, create a unique internal request ID to avoid race conditions
200+
// with duplicate x-request-id values by appending a UUID suffix to the original request ID
201+
internalReqID = originalReqID + "-" + s.uuidFn()
202+
}
183203
p, err = s.processorForPath(headersMap, isUpstreamFilter)
184204
if err != nil {
185205
if errors.Is(err, errNoProcessor) {
@@ -200,22 +220,22 @@ func (s *Server) Process(stream extprocv3.ExternalProcessor_ProcessServer) error
200220
}
201221
_, isEndpoinPicker := headersMap[internalapi.EndpointPickerHeaderKey]
202222
if isUpstreamFilter {
203-
if err = s.setBackend(ctx, p, reqID, isEndpoinPicker, req); err != nil {
223+
if err = s.setBackend(ctx, p, internalReqID, isEndpoinPicker, req); err != nil {
204224
s.logger.Error("error processing request message", slog.String("error", err.Error()))
205225
return status.Errorf(codes.Unknown, "error processing request message: %v", err)
206226
}
207227
} else {
208228
s.routerProcessorsPerReqIDMutex.Lock()
209-
s.routerProcessorsPerReqID[reqID] = p
229+
s.routerProcessorsPerReqID[internalReqID] = p
210230
s.routerProcessorsPerReqIDMutex.Unlock()
211231
}
212232
}
213233
if logger == nil {
214-
logger = s.logger.With("request_id", reqID, "is_upstream_filter", isUpstreamFilter)
234+
logger = s.logger.With("request_id", originalReqID, "is_upstream_filter", isUpstreamFilter)
215235
}
216236

217237
// At this point, p is guaranteed to be a valid processor either from the concrete processor or the passThroughProcessor.
218-
resp, err := s.processMsg(ctx, logger, p, req)
238+
resp, err := s.processMsg(ctx, logger, p, req, internalReqID, isUpstreamFilter)
219239
if err != nil {
220240
s.logger.Error("error processing request message", slog.String("error", err.Error()))
221241
return status.Errorf(codes.Unknown, "error processing request message: %v", err)
@@ -227,7 +247,7 @@ func (s *Server) Process(stream extprocv3.ExternalProcessor_ProcessServer) error
227247
}
228248
}
229249

230-
func (s *Server) processMsg(ctx context.Context, l *slog.Logger, p Processor, req *extprocv3.ProcessingRequest) (*extprocv3.ProcessingResponse, error) {
250+
func (s *Server) processMsg(ctx context.Context, l *slog.Logger, p Processor, req *extprocv3.ProcessingRequest, internalReqID string, isUpstreamFilter bool) (*extprocv3.ProcessingResponse, error) {
231251
switch value := req.Request.(type) {
232252
case *extprocv3.ProcessingRequest_RequestHeaders:
233253
requestHdrs := req.GetRequestHeaders().Headers
@@ -240,6 +260,35 @@ func (s *Server) processMsg(ctx context.Context, l *slog.Logger, p Processor, re
240260
if err != nil {
241261
return nil, fmt.Errorf("cannot process request headers: %w", err)
242262
}
263+
264+
// For router filter, inject the internal request ID header so upstream filter can use it
265+
if !isUpstreamFilter && resp != nil {
266+
if requestHeaders, ok := resp.Response.(*extprocv3.ProcessingResponse_RequestHeaders); ok {
267+
// Ensure we have header mutation to add the internal request ID
268+
if requestHeaders.RequestHeaders == nil {
269+
requestHeaders.RequestHeaders = &extprocv3.HeadersResponse{}
270+
}
271+
if requestHeaders.RequestHeaders.Response == nil {
272+
requestHeaders.RequestHeaders.Response = &extprocv3.CommonResponse{}
273+
}
274+
if requestHeaders.RequestHeaders.Response.HeaderMutation == nil {
275+
requestHeaders.RequestHeaders.Response.HeaderMutation = &extprocv3.HeaderMutation{}
276+
}
277+
278+
// Add the internal request ID header
279+
internalReqIDHeaderValue := &corev3.HeaderValueOption{
280+
Header: &corev3.HeaderValue{
281+
Key: internalReqIDHeader,
282+
RawValue: []byte(internalReqID),
283+
},
284+
}
285+
requestHeaders.RequestHeaders.Response.HeaderMutation.SetHeaders = append(
286+
requestHeaders.RequestHeaders.Response.HeaderMutation.SetHeaders,
287+
internalReqIDHeaderValue,
288+
)
289+
}
290+
}
291+
243292
l.Debug("request headers processed", slog.Any("response", resp))
244293
return resp, nil
245294
case *extprocv3.ProcessingRequest_RequestBody:
@@ -279,7 +328,7 @@ func (s *Server) processMsg(ctx context.Context, l *slog.Logger, p Processor, re
279328

280329
// setBackend retrieves the backend from the request attributes and sets it in the processor. This is only called
281330
// if the processor is an upstream filter.
282-
func (s *Server) setBackend(ctx context.Context, p Processor, reqID string, isEndpointPicker bool, req *extprocv3.ProcessingRequest) error {
331+
func (s *Server) setBackend(ctx context.Context, p Processor, internalReqID string, isEndpointPicker bool, req *extprocv3.ProcessingRequest) error {
283332
attributes := req.GetAttributes()["envoy.filters.http.ext_proc"]
284333
if attributes == nil || len(attributes.Fields) == 0 { // coverage-ignore
285334
return status.Error(codes.Internal, "missing attributes in request")
@@ -317,10 +366,10 @@ func (s *Server) setBackend(ctx context.Context, p Processor, reqID string, isEn
317366

318367
s.routerProcessorsPerReqIDMutex.RLock()
319368
defer s.routerProcessorsPerReqIDMutex.RUnlock()
320-
routerProcessor, ok := s.routerProcessorsPerReqID[reqID]
369+
routerProcessor, ok := s.routerProcessorsPerReqID[internalReqID]
321370
if !ok {
322371
return status.Errorf(codes.Internal, "no router processor found, request_id=%s, backend=%s",
323-
reqID, backendName.GetStringValue())
372+
internalReqID, backendName.GetStringValue())
324373
}
325374

326375
if err := p.SetBackend(ctx, backend.b, backend.handler, routerProcessor); err != nil {

internal/extproc/server_test.go

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func TestServer_List(t *testing.T) {
120120
func TestServer_processMsg(t *testing.T) {
121121
t.Run("unknown request type", func(t *testing.T) {
122122
s, p := requireNewServerWithMockProcessor(t)
123-
_, err := s.processMsg(t.Context(), slog.Default(), p, &extprocv3.ProcessingRequest{})
123+
_, err := s.processMsg(t.Context(), slog.Default(), p, &extprocv3.ProcessingRequest{}, "test-req-id", false)
124124
require.ErrorContains(t, err, "unknown request type")
125125
})
126126
t.Run("request headers", func(t *testing.T) {
@@ -134,7 +134,7 @@ func TestServer_processMsg(t *testing.T) {
134134
req := &extprocv3.ProcessingRequest{
135135
Request: &extprocv3.ProcessingRequest_RequestHeaders{RequestHeaders: &extprocv3.HttpHeaders{Headers: hm}},
136136
}
137-
resp, err := s.processMsg(t.Context(), slog.Default(), p, req)
137+
resp, err := s.processMsg(t.Context(), slog.Default(), p, req, "test-req-id", false)
138138
require.NoError(t, err)
139139
require.NotNil(t, resp)
140140
require.Equal(t, expResponse, resp)
@@ -150,7 +150,7 @@ func TestServer_processMsg(t *testing.T) {
150150
req := &extprocv3.ProcessingRequest{
151151
Request: &extprocv3.ProcessingRequest_RequestBody{RequestBody: reqBody},
152152
}
153-
resp, err := s.processMsg(t.Context(), slog.Default(), p, req)
153+
resp, err := s.processMsg(t.Context(), slog.Default(), p, req, "test-req-id", false)
154154
require.NoError(t, err)
155155
require.NotNil(t, resp)
156156
require.Equal(t, expResponse, resp)
@@ -166,7 +166,7 @@ func TestServer_processMsg(t *testing.T) {
166166
req := &extprocv3.ProcessingRequest{
167167
Request: &extprocv3.ProcessingRequest_ResponseHeaders{ResponseHeaders: &extprocv3.HttpHeaders{Headers: hm}},
168168
}
169-
resp, err := s.processMsg(t.Context(), slog.Default(), p, req)
169+
resp, err := s.processMsg(t.Context(), slog.Default(), p, req, "test-req-id", false)
170170
require.NoError(t, err)
171171
require.NotNil(t, resp)
172172
require.Equal(t, expResponse, resp)
@@ -182,7 +182,7 @@ func TestServer_processMsg(t *testing.T) {
182182
req := &extprocv3.ProcessingRequest{
183183
Request: &extprocv3.ProcessingRequest_ResponseHeaders{ResponseHeaders: &extprocv3.HttpHeaders{Headers: hm}},
184184
}
185-
resp, err := s.processMsg(t.Context(), slog.Default(), p, req)
185+
resp, err := s.processMsg(t.Context(), slog.Default(), p, req, "test-req-id", false)
186186
require.NoError(t, err)
187187
require.NotNil(t, resp)
188188
require.Equal(t, expResponse, resp)
@@ -198,7 +198,7 @@ func TestServer_processMsg(t *testing.T) {
198198
req := &extprocv3.ProcessingRequest{
199199
Request: &extprocv3.ProcessingRequest_ResponseBody{ResponseBody: reqBody},
200200
}
201-
resp, err := s.processMsg(t.Context(), slog.Default(), p, req)
201+
resp, err := s.processMsg(t.Context(), slog.Default(), p, req, "test-req-id", false)
202202
require.NoError(t, err)
203203
require.NotNil(t, resp)
204204
require.Equal(t, expResponse, resp)
@@ -236,7 +236,7 @@ func TestServer_Process(t *testing.T) {
236236
t.Run("upstream filter", func(t *testing.T) {
237237
s, p := requireNewServerWithMockProcessor(t)
238238

239-
hm := &corev3.HeaderMap{Headers: []*corev3.HeaderValue{{Key: originalPathHeader, Value: "/"}, {Key: "foo", Value: "bar"}}}
239+
hm := &corev3.HeaderMap{Headers: []*corev3.HeaderValue{{Key: originalPathHeader, Value: "/"}, {Key: internalReqIDHeader, Value: "test-req-id-123"}, {Key: "foo", Value: "bar"}}}
240240
p.t = t
241241
p.expHeaderMap = hm
242242
req := &extprocv3.ProcessingRequest{
@@ -351,7 +351,7 @@ func TestServer_ProcessorSelection(t *testing.T) {
351351
s.Register("/two", func(*processorConfig, map[string]string, *slog.Logger, tracing.Tracing, bool) (Processor, error) {
352352
return &mockProcessor{
353353
t: t,
354-
expHeaderMap: &corev3.HeaderMap{Headers: []*corev3.HeaderValue{{Key: ":path", Value: "/two"}}},
354+
expHeaderMap: &corev3.HeaderMap{Headers: []*corev3.HeaderValue{{Key: ":path", Value: "/two"}, {Key: "x-request-id", Value: "original-req-id"}}},
355355
retProcessingResponse: &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_RequestHeaders{}},
356356
}, nil
357357
})
@@ -388,11 +388,36 @@ func TestServer_ProcessorSelection(t *testing.T) {
388388
req := &extprocv3.ProcessingRequest{
389389
Request: &extprocv3.ProcessingRequest_RequestHeaders{
390390
RequestHeaders: &extprocv3.HttpHeaders{
391-
Headers: &corev3.HeaderMap{Headers: []*corev3.HeaderValue{{Key: ":path", Value: "/two"}}},
391+
Headers: &corev3.HeaderMap{Headers: []*corev3.HeaderValue{
392+
{Key: ":path", Value: "/two"},
393+
{Key: "x-request-id", Value: "original-req-id"},
394+
}},
392395
},
393396
},
394397
}
395-
expResponse := &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_RequestHeaders{}}
398+
expResponse := &extprocv3.ProcessingResponse{
399+
Response: &extprocv3.ProcessingResponse_RequestHeaders{
400+
RequestHeaders: &extprocv3.HeadersResponse{
401+
Response: &extprocv3.CommonResponse{
402+
HeaderMutation: &extprocv3.HeaderMutation{
403+
SetHeaders: []*corev3.HeaderValueOption{
404+
{
405+
Header: &corev3.HeaderValue{
406+
Key: internalReqIDHeader,
407+
RawValue: []byte("original-req-id-test-internal-req-id"),
408+
},
409+
},
410+
},
411+
},
412+
},
413+
},
414+
},
415+
}
416+
417+
s.uuidFn = func() string {
418+
return "test-internal-req-id"
419+
}
420+
396421
ms := &mockExternalProcessingStream{t: t, ctx: ctx, retRecv: req, expResponseOnSend: expResponse}
397422

398423
err = s.Process(ms)

0 commit comments

Comments
 (0)