Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 10 additions & 86 deletions internal/extproc/embeddings_processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,14 +496,13 @@ func TestEmbeddingsProcessorRouterFilter_ProcessResponseHeaders_ProcessResponseB
}

func TestEmbeddingsProcessorUpstreamFilter_ProcessRequestHeaders_WithHeaderMutations(t *testing.T) {
const testModelKey = "x-ai-gateway-model-key"
t.Run("header mutations applied correctly", func(t *testing.T) {
headers := map[string]string{
":path": "/v1/embeddings",
testModelKey: "some-model",
"authorization": "bearer token123",
"x-api-key": "secret-key",
"x-custom": "custom-value",
":path": "/v1/embeddings",
internalapi.ModelNameHeaderKeyDefault: "some-model",
"authorization": "bearer token123",
"x-api-key": "secret-key",
"x-custom": "custom-value",
}
someBody := embeddingBodyFromModel(t, "some-model")
var body openai.EmbeddingRequest
Expand Down Expand Up @@ -557,86 +556,11 @@ func TestEmbeddingsProcessorUpstreamFilter_ProcessRequestHeaders_WithHeaderMutat
require.Equal(t, "custom-value", headers["x-custom"])
})

t.Run("header mutations restored on retry", func(t *testing.T) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this simply re-tests the headermutator internal so it's overlapping with the tests there, hence removed it.

headers := map[string]string{
":path": "/v1/embeddings",
testModelKey: "some-model",
// "x-custom" is not present in current headers, so it can be restored.
"x-new-header": "new-value", // Already set from previous mutation.
}
someBody := embeddingBodyFromModel(t, "some-model")
var body openai.EmbeddingRequest
require.NoError(t, json.Unmarshal(someBody, &body))

// Create header mutations that don't remove x-custom (so it can be restored).
headerMutations := &filterapi.HTTPHeaderMutation{
Remove: []string{"authorization", "x-api-key"},
Set: []filterapi.HTTPHeader{{Name: "x-new-header", Value: "updated-value"}},
}

mt := &mockEmbeddingTranslator{t: t, expRequestBody: &body}
mm := &mockEmbeddingsMetrics{}
p := &embeddingsProcessorUpstreamFilter{
config: &processorConfig{},
requestHeaders: headers,
logger: slog.Default(),
metrics: mm,
translator: mt,
originalRequestBodyRaw: someBody,
originalRequestBody: &body,
handler: &mockBackendAuthHandler{},
onRetry: true, // This is a retry request.
}

// Use the same headers map as the original headers (this simulates the router filter's requestHeaders).
originalHeaders := map[string]string{
":path": "/v1/embeddings",
testModelKey: "some-model",
"authorization": "bearer original-token", // This will be removed, so won't be restored.
"x-api-key": "original-secret", // This will be removed, so won't be restored.
"x-custom": "original-custom", // This won't be removed, so can be restored.
"x-new-header": "original-value", // This will be set, so won't be restored.
}
p.headerMutator = headermutator.NewHeaderMutator(headerMutations, originalHeaders)

resp, err := p.ProcessRequestHeaders(t.Context(), nil)
require.NoError(t, err)
require.NotNil(t, resp)

commonRes := resp.Response.(*extprocv3.ProcessingResponse_RequestHeaders).RequestHeaders.Response

// Check that header mutations were applied.
require.NotNil(t, commonRes.HeaderMutation)
// RemoveHeaders should be empty because authorization/x-api-key don't exist in current headers.
require.Empty(t, commonRes.HeaderMutation.RemoveHeaders)
require.Len(t, commonRes.HeaderMutation.SetHeaders, 2) // Updated header + restored header.

// Check that x-custom header was restored on retry (it's not being removed or set).
var restoredHeader *corev3.HeaderValueOption
var updatedHeader *corev3.HeaderValueOption
for _, h := range commonRes.HeaderMutation.SetHeaders {
switch h.Header.Key {
case "x-custom":
restoredHeader = h
case "x-new-header":
updatedHeader = h
}
}
require.NotNil(t, restoredHeader)
require.Equal(t, []byte("original-custom"), restoredHeader.Header.RawValue)
require.NotNil(t, updatedHeader)
require.Equal(t, []byte("updated-value"), updatedHeader.Header.RawValue)

// Check that headers were updated in the request headers.
require.Equal(t, "updated-value", headers["x-new-header"])
require.Equal(t, "original-custom", headers["x-custom"])
})

t.Run("no header mutations when mutator is nil", func(t *testing.T) {
headers := map[string]string{
":path": "/v1/embeddings",
testModelKey: "some-model",
"authorization": "bearer token123",
":path": "/v1/embeddings",
internalapi.ModelNameHeaderKeyDefault: "some-model",
"authorization": "bearer token123",
}
someBody := embeddingBodyFromModel(t, "some-model")
var body openai.EmbeddingRequest
Expand Down Expand Up @@ -756,7 +680,7 @@ func TestEmbeddingsProcessorUpstreamFilter_SetBackend_WithHeaderMutations(t *tes

// Test retry scenario - original headers should be restored.
testHeaders := map[string]string{
"x-existing": "current-value", // This exists, so won't be restored.
"x-existing": "previously-set-value",
}
mutation := p.headerMutator.Mutate(testHeaders, true) // onRetry = true.

Expand All @@ -776,6 +700,6 @@ func TestEmbeddingsProcessorUpstreamFilter_SetBackend_WithHeaderMutations(t *tes
require.Equal(t, []byte("original-value"), restoredHeader.Header.RawValue)
require.Equal(t, "original-value", testHeaders["x-custom"])
// x-existing should not be restored because it already exists.
require.Equal(t, "current-value", testHeaders["x-existing"])
require.Equal(t, "existing-value", testHeaders["x-existing"])
})
}
59 changes: 55 additions & 4 deletions internal/extproc/headermutator/header_mutator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ import (
extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"

"github.com/envoyproxy/ai-gateway/internal/filterapi"
"github.com/envoyproxy/ai-gateway/internal/internalapi"
)

type HeaderMutator struct {
// getOrignalHeaders Callback to get removed sensitive headers from the router filter.
// getOriginalHeaders Callback to get removed sensitive headers from the router filter.
originalHeaders map[string]string

// headerMutations is a list of header mutations to apply.
Expand All @@ -40,6 +41,9 @@ func (h *HeaderMutator) Mutate(headers map[string]string, onRetry bool) *extproc
if !skipRemove {
for _, h := range h.headerMutations.Remove {
key := strings.ToLower(h)
if shouldIgnoreHeader(key) {
continue
}
removedHeadersSet[key] = struct{}{}
if _, ok := headers[key]; ok {
// Do NOT delete from the local headers map so metrics can still read it.
Expand All @@ -54,6 +58,9 @@ func (h *HeaderMutator) Mutate(headers map[string]string, onRetry bool) *extproc
if !skipSet {
for _, h := range h.headerMutations.Set {
key := strings.ToLower(h.Name)
if shouldIgnoreHeader(key) {
continue
}
setHeadersSet[key] = struct{}{}
headers[key] = h.Value
headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{
Expand All @@ -62,21 +69,65 @@ func (h *HeaderMutator) Mutate(headers map[string]string, onRetry bool) *extproc
}
}

// Restore original headers on retry, only if not being removed, set or not already present.
if onRetry && h.originalHeaders != nil {
if onRetry {
// Restore original headers on retry, only if not being removed, set or not already present.
for h, v := range h.originalHeaders {
key := strings.ToLower(h)
if shouldIgnoreHeader(key) {
continue
}
_, isRemoved := removedHeadersSet[key]
_, isSet := setHeadersSet[key]
_, exists := headers[key]
if !isRemoved && !exists && !isSet {
headers[h] = v
setHeadersSet[key] = struct{}{}
headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{
Header: &corev3.HeaderValue{Key: h, RawValue: []byte(v)},
})
}
}
// 1. Remove any headers that were added in the previous attempt (not part of original headers and not being set now).
// 2. Restore any original headers that were modified in the previous attempt (and not being set now).
for key := range headers {
key = strings.ToLower(key)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just TOL, should we merge this loop and the loop over original headers? If I understand this loop (if no removals or additions) will make headers equal to orignalHeaders. Should we start with headers = originalHeaders and go ahead with remaining mutations?
Initially I skipped this to avoid any header that was already added and was different from original header before entering mutation but as we are making them same now just wondering if extra loop is required

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do it in a follow up later

if shouldIgnoreHeader(key) {
continue
}
if _, set := setHeadersSet[key]; set {
continue
}
originalValue, exists := h.originalHeaders[key]
if !exists {
delete(headers, key)
headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, key)
} else {
// Restore original value.
headers[key] = originalValue
headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{
Header: &corev3.HeaderValue{Key: key, RawValue: []byte(originalValue)},
})
}
}
}

return headerMutation
}

// shouldIgnoreHeader returns true if the header key should be ignored for mutation.
//
// Skip Envoy AI Gateway headers since some of them are populated after the originalHeaders are captured.
// This should be safe since these headers are managed by Envoy AI Gateway itself, not expected to be
// modified by users via header mutation API.
//
// Also, skip Envoy pseudo-headers beginning with ':'.
func shouldIgnoreHeader(key string) bool {
// Ignore Envoy pseudo-headers beginning with ':'.
if strings.HasPrefix(key, ":") {
return true
}
// Ignore internal headers beginning with Envoy AI Gateway prefix.
if strings.HasPrefix(key, internalapi.EnvoyAIGatewayHeaderPrefix) {
return true
}
return false
}
32 changes: 28 additions & 4 deletions internal/extproc/headermutator/header_mutator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/envoyproxy/ai-gateway/internal/filterapi"
"github.com/envoyproxy/ai-gateway/internal/internalapi"
)

func TestHeaderMutator_Mutate(t *testing.T) {
Expand Down Expand Up @@ -41,13 +42,23 @@ func TestHeaderMutator_Mutate(t *testing.T) {

t.Run("restore original headers on retry", func(t *testing.T) {
originalHeaders := map[string]string{
"authorization": "secret",
"x-api-key": "key123",
"other": "value",
"authorization": "secret",
"x-api-key": "key123",
"other": "value",
"only-in-original": "original",
"in-original-too-but-previous-attempt-set": "pikachu",
// Envoy pseudo-header should be ignored.
":path": "/v1/endpoint",
// Internal headers should be ignored.
internalapi.EnvoyAIGatewayHeaderPrefix + "-foo-bar": "should-not-be-included",
}
headers := map[string]string{
"other": "value",
"authorization": "secret",
"in-original-too-but-previous-attempt-set": "charmander",
"only-set-previously": "bulbasaur",
// Internal headers should be ignored.
internalapi.EnvoyAIGatewayHeaderPrefix + "-dog-cat": "should-not-be-included",
}
mutations := &filterapi.HTTPHeaderMutation{
Remove: []string{"authorization"},
Expand All @@ -57,9 +68,22 @@ func TestHeaderMutator_Mutate(t *testing.T) {
mutation := mutator.Mutate(headers, true)

require.NotNil(t, mutation)
require.ElementsMatch(t, []string{"authorization"}, mutation.RemoveHeaders)
require.ElementsMatch(t, []string{"authorization", "only-set-previously"}, mutation.RemoveHeaders)
require.Len(t, mutation.SetHeaders, 5)
setHeadersMap := make(map[string]string)
for _, h := range mutation.SetHeaders {
setHeadersMap[h.Header.Key] = string(h.Header.RawValue)
}
require.Equal(t, "key123", setHeadersMap["x-api-key"])
require.Equal(t, "value", setHeadersMap["other"])
require.Equal(t, "secret", setHeadersMap["authorization"])
require.Equal(t, "original", setHeadersMap["only-in-original"])
require.Equal(t, "pikachu", setHeadersMap["in-original-too-but-previous-attempt-set"])
// Check the final headers map too.
require.Equal(t, "key123", headers["x-api-key"])
require.Equal(t, "value", headers["other"])
require.Equal(t, "secret", headers["authorization"])
require.Equal(t, "original", headers["only-in-original"])
require.Equal(t, "pikachu", headers["in-original-too-but-previous-attempt-set"])
})
}
Loading