Skip to content

Commit db8fa38

Browse files
committed
more
Signed-off-by: Takeshi Yoneda <[email protected]>
1 parent 9455f12 commit db8fa38

File tree

2 files changed

+82
-9
lines changed

2 files changed

+82
-9
lines changed

internal/dynamic_module/router_filter.go

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ type (
3131
routerFilter struct {
3232
fc *routerFilterConfig
3333
endpoint endpoint
34+
originalHeaders map[string]string
3435
originalRequestBody interface{}
3536
originalRequestBodyRaw []byte
37+
attemptCount int
3638
}
3739
)
3840

@@ -118,17 +120,17 @@ func (f *routerFilter) RequestBody(e sdk.EnvoyHTTPFilter, endOfStream bool) sdk.
118120
var modelName string
119121
switch f.endpoint {
120122
case chatCompletionsEndpoint:
121-
parsed, modelName, err = modelBodyParser(raw, func(req *openai.ChatCompletionRequest) string { return req.Model })
123+
parsed, modelName, err = parseBodyWithModel(raw, func(req *openai.ChatCompletionRequest) string { return req.Model })
122124
case completionsEndpoint:
123-
parsed, modelName, err = modelBodyParser(raw, func(req *openai.CompletionRequest) string { return req.Model })
125+
parsed, modelName, err = parseBodyWithModel(raw, func(req *openai.CompletionRequest) string { return req.Model })
124126
case embeddingsEndpoint:
125-
parsed, modelName, err = modelBodyParser(raw, func(req *openai.EmbeddingRequest) string { return req.Model })
127+
parsed, modelName, err = parseBodyWithModel(raw, func(req *openai.EmbeddingRequest) string { return req.Model })
126128
case imagesGenerationsEndpoint:
127-
parsed, modelName, err = modelBodyParser(raw, func(req *openaisdk.ImageGenerateParams) string { return req.Model })
129+
parsed, modelName, err = parseBodyWithModel(raw, func(req *openaisdk.ImageGenerateParams) string { return req.Model })
128130
case rerankEndpoint:
129-
parsed, modelName, err = modelBodyParser(raw, func(req *cohereschema.RerankV2Request) string { return req.Model })
131+
parsed, modelName, err = parseBodyWithModel(raw, func(req *cohereschema.RerankV2Request) string { return req.Model })
130132
case messagesEndpoint:
131-
parsed, modelName, err = modelBodyParser(raw, func(req *anthropic.MessagesRequest) string { return req.GetModel() })
133+
parsed, modelName, err = parseBodyWithModel(raw, func(req *anthropic.MessagesRequest) string { return req.GetModel() })
132134
default:
133135
e.SendLocalReply(500, nil, []byte("BUG: unsupported endpoint at body parsing: "+fmt.Sprintf("%d", f.endpoint)))
134136
}
@@ -144,6 +146,8 @@ func (f *routerFilter) RequestBody(e sdk.EnvoyHTTPFilter, endOfStream bool) sdk.
144146
// Store the pointer to the filter in dynamic metadata for later retrieval in the upstream filter.
145147
e.SetDynamicMetadataString(internalapi.AIGatewayFilterMetadataNamespace, routerFilterPointerDynamicMetadataKey,
146148
fmt.Sprintf("%d", uintptr(unsafe.Pointer(f))))
149+
150+
f.originalHeaders = multiValueHeadersToSingleValue(e.GetRequestHeaders())
147151
return sdk.RequestBodyStatusContinue
148152
}
149153

@@ -168,10 +172,21 @@ func (f *routerFilter) handleModelsEndpoint(e sdk.EnvoyHTTPFilter) sdk.RequestHe
168172
return sdk.RequestHeadersStatusStopIteration
169173
}
170174

171-
func modelBodyParser[T any](body []byte, modelExtractFn func(req *T) string) (interface{}, string, error) {
175+
func parseBodyWithModel[T any](body []byte, modelExtractFn func(req *T) string) (interface{}, string, error) {
172176
var req T
173177
if err := json.Unmarshal(body, &req); err != nil {
174178
return nil, "", fmt.Errorf("failed to unmarshal body: %w", err)
175179
}
176180
return req, modelExtractFn(&req), nil
177181
}
182+
183+
// multiValueHeadersToSingleValue converts a map of headers with multiple values to a map of headers with single values by taking the first value for each header.
184+
//
185+
// TODO: this is purely for feature parity with the old filter where we ignore the case of multiple header values.
186+
func multiValueHeadersToSingleValue(headers map[string][]string) map[string]string {
187+
singleValueHeaders := make(map[string]string, len(headers))
188+
for k, v := range headers {
189+
singleValueHeaders[k] = v[0]
190+
}
191+
return singleValueHeaders
192+
}

internal/dynamic_module/upstream_filter.go

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package main
22

33
import (
4+
"context"
45
"fmt"
56
"strconv"
67
"unsafe"
78

89
"github.com/envoyproxy/ai-gateway/internal/dynamic_module/sdk"
10+
"github.com/envoyproxy/ai-gateway/internal/filterapi"
11+
"github.com/envoyproxy/ai-gateway/internal/headermutator"
912
"github.com/envoyproxy/ai-gateway/internal/internalapi"
1013
)
1114

@@ -14,7 +17,10 @@ type (
1417
upstreamFilterConfig struct{}
1518
// upstreamFilter implements [sdk.HTTPFilter].
1619
upstreamFilter struct {
17-
rf *routerFilter
20+
rf *routerFilter
21+
backend *filterapi.RuntimeBackend
22+
reqHeaders map[string]string
23+
onRetry bool
1824
}
1925
)
2026

@@ -38,6 +44,8 @@ func (f *upstreamFilter) RequestHeaders(e sdk.EnvoyHTTPFilter, _ bool) sdk.Reque
3844
return sdk.RequestHeadersStatusStopIteration
3945
}
4046
f.rf = (*routerFilter)(unsafe.Pointer(uintptr(rfPtr)))
47+
f.rf.attemptCount++
48+
f.onRetry = f.rf.attemptCount > 1
4149

4250
backend, ok := e.GetUpstreamHostMetadataString(internalapi.AIGatewayFilterMetadataNamespace, internalapi.InternalMetadataBackendNameKey)
4351
if !ok {
@@ -49,11 +57,61 @@ func (f *upstreamFilter) RequestHeaders(e sdk.EnvoyHTTPFilter, _ bool) sdk.Reque
4957
e.SendLocalReply(500, nil, []byte(fmt.Sprintf("backend %s not found in filter config", backend)))
5058
return sdk.RequestHeadersStatusStopIteration
5159
}
60+
61+
f.backend = b
62+
f.reqHeaders = multiValueHeadersToSingleValue(e.GetRequestHeaders())
63+
64+
// Now mutate the headers based on the backend configuration.
65+
if hm := b.Backend.HeaderMutation; hm != nil {
66+
sets, removes := headermutator.NewHeaderMutator(b.Backend.HeaderMutation, f.rf.originalHeaders).Mutate(f.reqHeaders, f.onRetry)
67+
for _, h := range sets {
68+
if !e.SetRequestHeader(h.Key(), []byte(h.Value())) {
69+
e.SendLocalReply(500, nil, []byte(fmt.Sprintf("failed to set header %s", h.Key())))
70+
return sdk.RequestHeadersStatusStopIteration
71+
}
72+
f.reqHeaders[h.Key()] = h.Value()
73+
}
74+
for _, key := range removes {
75+
if !e.SetRequestHeader(key, nil) {
76+
e.SendLocalReply(500, nil, []byte(fmt.Sprintf("failed to remove header %s", key)))
77+
return sdk.RequestHeadersStatusStopIteration
78+
}
79+
delete(f.reqHeaders, key)
80+
}
81+
}
5282
return sdk.RequestHeadersStatusContinue
5383
}
5484

5585
// RequestBody implements [sdk.HTTPFilter].
56-
func (f *upstreamFilter) RequestBody(sdk.EnvoyHTTPFilter, bool) sdk.RequestBodyStatus {
86+
func (f *upstreamFilter) RequestBody(e sdk.EnvoyHTTPFilter, endOfStream bool) sdk.RequestBodyStatus {
87+
if !endOfStream {
88+
// TODO: ideally, we should not buffer the entire body for the passthrough case.
89+
return sdk.RequestBodyStatusStopIterationAndBuffer
90+
}
91+
92+
b := f.backend
93+
94+
// TODO: endpoint specific logic such as translation.
95+
96+
if bm := b.Backend.BodyMutation; bm != nil {
97+
// TODO: body mutation if needed.
98+
_ = bm
99+
}
100+
101+
// Next is to do the upstream auth if needed.
102+
if b.Handler != nil {
103+
authHeaders, err := b.Handler.Do(context.Background(), f.reqHeaders, f.rf.originalRequestBodyRaw)
104+
if err != nil {
105+
e.SendLocalReply(500, nil, []byte(fmt.Sprintf("failed to do backend auth: %v", err)))
106+
return sdk.RequestBodyStatusStopIterationAndBuffer
107+
}
108+
for _, h := range authHeaders {
109+
if !e.SetRequestHeader(h.Key(), []byte(h.Value())) {
110+
e.SendLocalReply(500, nil, []byte(fmt.Sprintf("failed to set auth header %s", h.Key())))
111+
return sdk.RequestBodyStatusStopIterationAndBuffer
112+
}
113+
}
114+
}
57115
return sdk.RequestBodyStatusContinue
58116
}
59117

0 commit comments

Comments
 (0)