66package dynamicmodule
77
88import (
9+ "context"
910 "encoding/json"
1011 "fmt"
1112 "io"
1213 "path"
1314 "strings"
1415 "unsafe"
1516
16- openaisdk "github.com/openai/openai-go/v2"
17-
1817 "github.com/envoyproxy/ai-gateway/internal/apischema/anthropic"
1918 cohereschema "github.com/envoyproxy/ai-gateway/internal/apischema/cohere"
2019 "github.com/envoyproxy/ai-gateway/internal/apischema/openai"
2120 "github.com/envoyproxy/ai-gateway/internal/dynamicmodule/sdk"
21+ "github.com/envoyproxy/ai-gateway/internal/endpointspec"
2222 "github.com/envoyproxy/ai-gateway/internal/filterapi"
2323 "github.com/envoyproxy/ai-gateway/internal/internalapi"
24+ tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api"
2425)
2526
// routerFilterPointerDynamicMetadataKey is the dynamic metadata key under which the
// router filter stores its own address, formatted as a decimal string, so that the
// upstream filter can retrieve it later.
const routerFilterPointerDynamicMetadataKey = "router_filter_pointer"
@@ -33,17 +34,41 @@ type (
3334 routerFilterConfig struct {
3435 fcr * * filterapi.RuntimeConfig
3536 prefixToEndpoint map [string ]endpoint
37+ tracing tracing.Tracing
3638 }
3739 // routerFilter implements [sdk.HTTPFilter].
3840 routerFilter struct {
39- routerFilterConfig * routerFilterConfig
41+ // prefixToEndpoint maps request path prefixes to endpoints. Shallow copy of
42+ // the one in routerFilterConfig at the time of filter creation.
43+ prefixToEndpoint map [string ]endpoint
44+ // runtimeFilterConfig is the snapshot of the runtime filter configuration at the time of filter creation.
45+ runtimeFilterConfig * filterapi.RuntimeConfig
46+ // tracing is the tracing implementation inherited from the environment.
47+ tracing tracing.Tracing
48+
49+ // endpoint is the endpoint that the current request is targeting.
50+ endpoint endpoint
51+ // typedFilter is the typed router filter for the current request.
52+ typedFilter routerFilterTypedIface
53+ }
54+
55+ // routerFilterTypedIface is the interface for the typed router filter.
56+ routerFilterTypedIface interface {
57+ RequestBody (e sdk.EnvoyHTTPFilter , endOfStream bool ) sdk.RequestBodyStatus
58+ }
59+
60+ // routerFilter typed is the typed implementation of the router filter for a specific endpoint.
61+ routerFilterTyped [ReqT , RespT , RespChunkT any , EndpointSpec endpointspec.Spec [ReqT , RespT , RespChunkT ]] struct {
4062 runtimeFilterConfig * filterapi.RuntimeConfig
41- endpoint endpoint
42- originalHeaders map [string ]string
43- originalRequestBody any
44- originalRequestBodyRaw []byte
45- span any
4663 attemptCount int
64+ ep EndpointSpec
65+ originalRequestHeaders map [string ]string
66+ originalRequestBody * ReqT
67+ originalRequestBodyRaw []byte
68+ originalModel internalapi.OriginalModel
69+ stream bool
70+ tracer tracing.RequestTracer [ReqT , RespT , RespChunkT ]
71+ span tracing.Span [RespT , RespChunkT ]
4772 }
4873)
4974
@@ -61,12 +86,17 @@ func NewRouterFilterConfig(env *Env, fcr **filterapi.RuntimeConfig) sdk.HTTPFilt
6186 return & routerFilterConfig {
6287 fcr : fcr ,
6388 prefixToEndpoint : prefixToEndpoint ,
89+ tracing : env .Tracing ,
6490 }
6591}
6692
6793// NewFilter implements [sdk.HTTPFilterConfig].
6894func (f * routerFilterConfig ) NewFilter () sdk.HTTPFilter {
69- return & routerFilter {routerFilterConfig : f , runtimeFilterConfig : * f .fcr }
95+ return & routerFilter {
96+ prefixToEndpoint : f .prefixToEndpoint ,
97+ runtimeFilterConfig : * f .fcr ,
98+ tracing : f .tracing ,
99+ }
70100}
71101
72102// RequestHeaders implements [sdk.HTTPFilter].
@@ -76,7 +106,7 @@ func (f *routerFilter) RequestHeaders(e sdk.EnvoyHTTPFilter, _ bool) sdk.Request
76106 if queryIndex := strings .Index (p , "?" ); queryIndex != - 1 {
77107 p = p [:queryIndex ]
78108 }
79- ep , ok := f .routerFilterConfig . prefixToEndpoint [p ]
109+ ep , ok := f .prefixToEndpoint [p ]
80110 if ! ok {
81111 e .SendLocalReply (404 , nil , []byte (fmt .Sprintf ("unsupported path: %s" , p )))
82112 return sdk .RequestHeadersStatusStopIteration
@@ -90,6 +120,45 @@ func (f *routerFilter) RequestHeaders(e sdk.EnvoyHTTPFilter, _ bool) sdk.Request
90120
91121// RequestBody implements [sdk.HTTPFilter].
92122func (f * routerFilter ) RequestBody (e sdk.EnvoyHTTPFilter , endOfStream bool ) sdk.RequestBodyStatus {
123+ switch f .endpoint {
124+ case chatCompletionsEndpoint :
125+ f .typedFilter = & routerFilterTyped [openai.ChatCompletionRequest , openai.ChatCompletionResponse , openai.ChatCompletionResponseChunk , endpointspec.ChatCompletionsEndpointSpec ]{
126+ runtimeFilterConfig : f .runtimeFilterConfig ,
127+ tracer : f .tracing .ChatCompletionTracer (),
128+ }
129+ case completionsEndpoint :
130+ f .typedFilter = & routerFilterTyped [openai.CompletionRequest , openai.CompletionResponse , openai.CompletionResponse , endpointspec.CompletionsEndpointSpec ]{
131+ runtimeFilterConfig : f .runtimeFilterConfig ,
132+ tracer : f .tracing .CompletionTracer (),
133+ }
134+ case embeddingsEndpoint :
135+ f .typedFilter = & routerFilterTyped [openai.EmbeddingRequest , openai.EmbeddingResponse , struct {}, endpointspec.EmbeddingsEndpointSpec ]{
136+ runtimeFilterConfig : f .runtimeFilterConfig ,
137+ tracer : f .tracing .EmbeddingsTracer (),
138+ }
139+ case imagesGenerationsEndpoint :
140+ f .typedFilter = & routerFilterTyped [openai.ImageGenerationRequest , openai.ImageGenerationResponse , struct {}, endpointspec.ImageGenerationEndpointSpec ]{
141+ runtimeFilterConfig : f .runtimeFilterConfig ,
142+ tracer : f .tracing .ImageGenerationTracer (),
143+ }
144+ case rerankEndpoint :
145+ f .typedFilter = & routerFilterTyped [cohereschema.RerankV2Request , cohereschema.RerankV2Response , struct {}, endpointspec.RerankEndpointSpec ]{
146+ runtimeFilterConfig : f .runtimeFilterConfig ,
147+ tracer : f .tracing .RerankTracer (),
148+ }
149+ case messagesEndpoint :
150+ f .typedFilter = & routerFilterTyped [anthropic.MessagesRequest , anthropic.MessagesResponse , anthropic.MessagesStreamChunk , endpointspec.MessagesEndpointSpec ]{
151+ runtimeFilterConfig : f .runtimeFilterConfig ,
152+ tracer : f .tracing .MessageTracer (),
153+ }
154+ default :
155+ e .SendLocalReply (500 , nil , []byte ("BUG: unsupported endpoint at body parsing: " + fmt .Sprintf ("%d" , f .endpoint )))
156+ return sdk .RequestBodyStatusStopIterationAndBuffer
157+ }
158+ return f .typedFilter .RequestBody (e , endOfStream )
159+ }
160+
161+ func (f * routerFilterTyped [ReqT , RespT , RespChunkT , EndpointSpecT ]) RequestBody (e sdk.EnvoyHTTPFilter , endOfStream bool ) sdk.RequestBodyStatus {
93162 if ! endOfStream {
94163 return sdk .RequestBodyStatusStopIterationAndBuffer
95164 }
@@ -103,39 +172,31 @@ func (f *routerFilter) RequestBody(e sdk.EnvoyHTTPFilter, endOfStream bool) sdk.
103172 e .SendLocalReply (400 , nil , []byte ("failed to read request body: " + err .Error ()))
104173 return sdk .RequestBodyStatusStopIterationAndBuffer
105174 }
175+
106176 f .originalRequestBodyRaw = raw
107- var parsed any
108- var modelName string
109- switch f .endpoint {
110- case chatCompletionsEndpoint :
111- parsed , modelName , err = parseBodyWithModel (raw , func (req * openai.ChatCompletionRequest ) string { return req .Model })
112- case completionsEndpoint :
113- parsed , modelName , err = parseBodyWithModel (raw , func (req * openai.CompletionRequest ) string { return req .Model })
114- case embeddingsEndpoint :
115- parsed , modelName , err = parseBodyWithModel (raw , func (req * openai.EmbeddingRequest ) string { return req .Model })
116- case imagesGenerationsEndpoint :
117- parsed , modelName , err = parseBodyWithModel (raw , func (req * openaisdk.ImageGenerateParams ) string { return req .Model })
118- case rerankEndpoint :
119- parsed , modelName , err = parseBodyWithModel (raw , func (req * cohereschema.RerankV2Request ) string { return req .Model })
120- case messagesEndpoint :
121- parsed , modelName , err = parseBodyWithModel (raw , func (req * anthropic.MessagesRequest ) string { return req .GetModel () })
122- default :
123- e .SendLocalReply (500 , nil , []byte ("BUG: unsupported endpoint at body parsing: " + fmt .Sprintf ("%d" , f .endpoint )))
124- }
125- if err != nil {
126- e .SendLocalReply (400 , nil , []byte ("failed to parse request body: " + err .Error ()))
127- return sdk .RequestBodyStatusStopIterationAndBuffer
177+ var maybeMutatedOriginalBodyRaw []byte
178+ f .originalModel , f .originalRequestBody , f .stream , maybeMutatedOriginalBodyRaw , err =
179+ f .ep .ParseBody (raw , len (f .runtimeFilterConfig .RequestCosts ) > 0 )
180+ if len (maybeMutatedOriginalBodyRaw ) > 0 {
181+ f .originalRequestBodyRaw = maybeMutatedOriginalBodyRaw
128182 }
129- f .originalRequestBody = parsed
130- if ! e .SetRequestHeader (internalapi .ModelNameHeaderKeyDefault , []byte (modelName )) {
183+ if ! e .SetRequestHeader (internalapi .ModelNameHeaderKeyDefault , []byte (f .originalModel )) {
131184 e .SendLocalReply (500 , nil , []byte ("failed to set model name header" ))
132185 return sdk .RequestBodyStatusStopIterationAndBuffer
133186 }
134187 // Store the pointer to the filter in dynamic metadata for later retrieval in the upstream filter.
135188 e .SetDynamicMetadataString (internalapi .AIGatewayFilterMetadataNamespace , routerFilterPointerDynamicMetadataKey ,
136189 fmt .Sprintf ("%d" , uintptr (unsafe .Pointer (f ))))
137190
138- f .originalHeaders = multiValueHeadersToSingleValue (e .GetRequestHeaders ())
191+ f .originalRequestHeaders = multiValueHeadersToSingleValue (e .GetRequestHeaders ())
192+
193+ f .span = f .tracer .StartSpanAndInjectHeaders (
194+ context .Background (),
195+ f .originalRequestHeaders ,
196+ & headerMutationCarrier {e : e },
197+ f .originalRequestBody ,
198+ f .originalRequestBodyRaw ,
199+ )
139200 return sdk .RequestBodyStatusContinue
140201}
141202
@@ -174,14 +235,6 @@ func (f *routerFilter) handleModelsEndpoint(e sdk.EnvoyHTTPFilter) sdk.RequestHe
174235 return sdk .RequestHeadersStatusStopIteration
175236}
176237
177- func parseBodyWithModel [T any ](body []byte , modelExtractFn func (req * T ) string ) (interface {}, string , error ) {
178- var req T
179- if err := json .Unmarshal (body , & req ); err != nil {
180- return nil , "" , fmt .Errorf ("failed to unmarshal body: %w" , err )
181- }
182- return req , modelExtractFn (& req ), nil
183- }
184-
185238// multiValueHeadersToSingleValue converts a map of headers with multiple values to a map of headers with single values by taking the first value for each header.
186239//
187240// TODO: this is purely for feature parity with the old filter where we ignore the case of multiple header values.
@@ -192,3 +245,23 @@ func multiValueHeadersToSingleValue(headers map[string][]string) map[string]stri
192245 }
193246 return singleValueHeaders
194247}
248+
249+ // headerMutationCarrier implements [propagation.TextMapCarrier].
250+ type headerMutationCarrier struct {
251+ e sdk.EnvoyHTTPFilter
252+ }
253+
254+ // Get implements the same method as defined on propagation.TextMapCarrier.
255+ func (c * headerMutationCarrier ) Get (string ) string {
256+ panic ("unexpected as this carrier is write-only for injection" )
257+ }
258+
259+ // Set adds a key-value pair to the HeaderMutation.
260+ func (c * headerMutationCarrier ) Set (key , value string ) {
261+ _ = c .e .SetResponseHeader (key , []byte (value ))
262+ }
263+
264+ // Keys implements the same method as defined on propagation.TextMapCarrier.
265+ func (c * headerMutationCarrier ) Keys () []string {
266+ panic ("unexpected as this carrier is write-only for injection" )
267+ }
0 commit comments