Skip to content

Commit 6ae4e4a

Browse files
committed
wip
Signed-off-by: Takeshi Yoneda <[email protected]>
1 parent b122463 commit 6ae4e4a

File tree

4 files changed

+130
-110
lines changed

4 files changed

+130
-110
lines changed

cmd/dynamic_module/main.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"os"
1313
"time"
1414

15+
"github.com/envoyproxy/ai-gateway/internal/tracing"
1516
"github.com/prometheus/client_golang/prometheus"
1617
otelprom "go.opentelemetry.io/otel/exporters/prometheus"
1718

@@ -99,6 +100,17 @@ func (g *globalState) initializeEnv() error {
99100
if err != nil {
100101
return fmt.Errorf("failed to parse metrics header mapping: %w", err)
101102
}
103+
spanRequestHeaderAttributes, err := internalapi.ParseRequestHeaderAttributeMapping(os.Getenv(
104+
"AI_GATEWAY_DYNAMIC_MODULE_FILTER_TRACING_REQUEST_HEADER_ATTRIBUTES",
105+
))
106+
if err != nil {
107+
return fmt.Errorf("failed to parse tracing header mapping: %w", err)
108+
}
109+
110+
tracing, err := tracing.NewTracingFromEnv(ctx, os.Stdout, spanRequestHeaderAttributes)
111+
if err != nil {
112+
return err
113+
}
102114

103115
g.env = &dynamicmodule.Env{
104116
RootPrefix: os.Getenv("AI_GATEWAY_DYNAMIC_MODULE_ROOT_PREFIX"),
@@ -109,6 +121,7 @@ func (g *globalState) initializeEnv() error {
109121
EmbeddingsMetricsFactory: metrics.NewMetricsFactory(meter, metricsRequestHeaderAttributes, metrics.GenAIOperationEmbedding),
110122
ImageGenerationMetricsFactory: metrics.NewMetricsFactory(meter, metricsRequestHeaderAttributes, metrics.GenAIOperationImageGeneration),
111123
RerankMetricsFactory: metrics.NewMetricsFactory(meter, metricsRequestHeaderAttributes, metrics.GenAIOperationRerank),
124+
Tracing: tracing,
112125
}
113126
return nil
114127
}

internal/dynamicmodule/dynamic_module.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package dynamicmodule
88
import (
99
"github.com/envoyproxy/ai-gateway/internal/internalapi"
1010
"github.com/envoyproxy/ai-gateway/internal/metrics"
11+
tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api"
1112
)
1213

1314
// endpoint represents the type of the endpoint that the request is targeting.
@@ -40,4 +41,5 @@ type Env struct {
4041
EmbeddingsMetricsFactory,
4142
ImageGenerationMetricsFactory,
4243
RerankMetricsFactory metrics.Factory
44+
Tracing tracing.Tracing
4345
}

internal/dynamicmodule/router_filter.go

Lines changed: 115 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,22 @@
66
package dynamicmodule
77

88
import (
9+
"context"
910
"encoding/json"
1011
"fmt"
1112
"io"
1213
"path"
1314
"strings"
1415
"unsafe"
1516

16-
openaisdk "github.com/openai/openai-go/v2"
17-
1817
"github.com/envoyproxy/ai-gateway/internal/apischema/anthropic"
1918
cohereschema "github.com/envoyproxy/ai-gateway/internal/apischema/cohere"
2019
"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
2120
"github.com/envoyproxy/ai-gateway/internal/dynamicmodule/sdk"
21+
"github.com/envoyproxy/ai-gateway/internal/endpointspec"
2222
"github.com/envoyproxy/ai-gateway/internal/filterapi"
2323
"github.com/envoyproxy/ai-gateway/internal/internalapi"
24+
tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api"
2425
)
2526

2627
const routerFilterPointerDynamicMetadataKey = "router_filter_pointer"
@@ -33,17 +34,41 @@ type (
3334
routerFilterConfig struct {
3435
fcr **filterapi.RuntimeConfig
3536
prefixToEndpoint map[string]endpoint
37+
tracing tracing.Tracing
3638
}
3739
// routerFilter implements [sdk.HTTPFilter].
3840
routerFilter struct {
39-
routerFilterConfig *routerFilterConfig
41+
// prefixToEndpoint maps request path prefixes to endpoints. Shallow copy of
42+
// the one in routerFilterConfig at the time of filter creation.
43+
prefixToEndpoint map[string]endpoint
44+
// runtimeFilterConfig is the snapshot of the runtime filter configuration at the time of filter creation.
45+
runtimeFilterConfig *filterapi.RuntimeConfig
46+
// tracing is the tracing implementation inherited from the environment.
47+
tracing tracing.Tracing
48+
49+
// endpoint is the endpoint that the current request is targeting.
50+
endpoint endpoint
51+
// typedFilter is the typed router filter for the current request.
52+
typedFilter routerFilterTypedIface
53+
}
54+
55+
// routerFilterTypedIface is the interface for the typed router filter.
56+
routerFilterTypedIface interface {
57+
RequestBody(e sdk.EnvoyHTTPFilter, endOfStream bool) sdk.RequestBodyStatus
58+
}
59+
60+
// routerFilterTyped is the typed implementation of the router filter for a specific endpoint.
61+
routerFilterTyped[ReqT, RespT, RespChunkT any, EndpointSpec endpointspec.Spec[ReqT, RespT, RespChunkT]] struct {
4062
runtimeFilterConfig *filterapi.RuntimeConfig
41-
endpoint endpoint
42-
originalHeaders map[string]string
43-
originalRequestBody any
44-
originalRequestBodyRaw []byte
45-
span any
4663
attemptCount int
64+
ep EndpointSpec
65+
originalRequestHeaders map[string]string
66+
originalRequestBody *ReqT
67+
originalRequestBodyRaw []byte
68+
originalModel internalapi.OriginalModel
69+
stream bool
70+
tracer tracing.RequestTracer[ReqT, RespT, RespChunkT]
71+
span tracing.Span[RespT, RespChunkT]
4772
}
4873
)
4974

@@ -61,12 +86,17 @@ func NewRouterFilterConfig(env *Env, fcr **filterapi.RuntimeConfig) sdk.HTTPFilt
6186
return &routerFilterConfig{
6287
fcr: fcr,
6388
prefixToEndpoint: prefixToEndpoint,
89+
tracing: env.Tracing,
6490
}
6591
}
6692

6793
// NewFilter implements [sdk.HTTPFilterConfig].
6894
func (f *routerFilterConfig) NewFilter() sdk.HTTPFilter {
69-
return &routerFilter{routerFilterConfig: f, runtimeFilterConfig: *f.fcr}
95+
return &routerFilter{
96+
prefixToEndpoint: f.prefixToEndpoint,
97+
runtimeFilterConfig: *f.fcr,
98+
tracing: f.tracing,
99+
}
70100
}
71101

72102
// RequestHeaders implements [sdk.HTTPFilter].
@@ -76,7 +106,7 @@ func (f *routerFilter) RequestHeaders(e sdk.EnvoyHTTPFilter, _ bool) sdk.Request
76106
if queryIndex := strings.Index(p, "?"); queryIndex != -1 {
77107
p = p[:queryIndex]
78108
}
79-
ep, ok := f.routerFilterConfig.prefixToEndpoint[p]
109+
ep, ok := f.prefixToEndpoint[p]
80110
if !ok {
81111
e.SendLocalReply(404, nil, []byte(fmt.Sprintf("unsupported path: %s", p)))
82112
return sdk.RequestHeadersStatusStopIteration
@@ -90,6 +120,45 @@ func (f *routerFilter) RequestHeaders(e sdk.EnvoyHTTPFilter, _ bool) sdk.Request
90120

91121
// RequestBody implements [sdk.HTTPFilter].
92122
func (f *routerFilter) RequestBody(e sdk.EnvoyHTTPFilter, endOfStream bool) sdk.RequestBodyStatus {
123+
switch f.endpoint {
124+
case chatCompletionsEndpoint:
125+
f.typedFilter = &routerFilterTyped[openai.ChatCompletionRequest, openai.ChatCompletionResponse, openai.ChatCompletionResponseChunk, endpointspec.ChatCompletionsEndpointSpec]{
126+
runtimeFilterConfig: f.runtimeFilterConfig,
127+
tracer: f.tracing.ChatCompletionTracer(),
128+
}
129+
case completionsEndpoint:
130+
f.typedFilter = &routerFilterTyped[openai.CompletionRequest, openai.CompletionResponse, openai.CompletionResponse, endpointspec.CompletionsEndpointSpec]{
131+
runtimeFilterConfig: f.runtimeFilterConfig,
132+
tracer: f.tracing.CompletionTracer(),
133+
}
134+
case embeddingsEndpoint:
135+
f.typedFilter = &routerFilterTyped[openai.EmbeddingRequest, openai.EmbeddingResponse, struct{}, endpointspec.EmbeddingsEndpointSpec]{
136+
runtimeFilterConfig: f.runtimeFilterConfig,
137+
tracer: f.tracing.EmbeddingsTracer(),
138+
}
139+
case imagesGenerationsEndpoint:
140+
f.typedFilter = &routerFilterTyped[openai.ImageGenerationRequest, openai.ImageGenerationResponse, struct{}, endpointspec.ImageGenerationEndpointSpec]{
141+
runtimeFilterConfig: f.runtimeFilterConfig,
142+
tracer: f.tracing.ImageGenerationTracer(),
143+
}
144+
case rerankEndpoint:
145+
f.typedFilter = &routerFilterTyped[cohereschema.RerankV2Request, cohereschema.RerankV2Response, struct{}, endpointspec.RerankEndpointSpec]{
146+
runtimeFilterConfig: f.runtimeFilterConfig,
147+
tracer: f.tracing.RerankTracer(),
148+
}
149+
case messagesEndpoint:
150+
f.typedFilter = &routerFilterTyped[anthropic.MessagesRequest, anthropic.MessagesResponse, anthropic.MessagesStreamChunk, endpointspec.MessagesEndpointSpec]{
151+
runtimeFilterConfig: f.runtimeFilterConfig,
152+
tracer: f.tracing.MessageTracer(),
153+
}
154+
default:
155+
e.SendLocalReply(500, nil, []byte("BUG: unsupported endpoint at body parsing: "+fmt.Sprintf("%d", f.endpoint)))
156+
return sdk.RequestBodyStatusStopIterationAndBuffer
157+
}
158+
return f.typedFilter.RequestBody(e, endOfStream)
159+
}
160+
161+
func (f *routerFilterTyped[ReqT, RespT, RespChunkT, EndpointSpecT]) RequestBody(e sdk.EnvoyHTTPFilter, endOfStream bool) sdk.RequestBodyStatus {
93162
if !endOfStream {
94163
return sdk.RequestBodyStatusStopIterationAndBuffer
95164
}
@@ -103,39 +172,31 @@ func (f *routerFilter) RequestBody(e sdk.EnvoyHTTPFilter, endOfStream bool) sdk.
103172
e.SendLocalReply(400, nil, []byte("failed to read request body: "+err.Error()))
104173
return sdk.RequestBodyStatusStopIterationAndBuffer
105174
}
175+
106176
f.originalRequestBodyRaw = raw
107-
var parsed any
108-
var modelName string
109-
switch f.endpoint {
110-
case chatCompletionsEndpoint:
111-
parsed, modelName, err = parseBodyWithModel(raw, func(req *openai.ChatCompletionRequest) string { return req.Model })
112-
case completionsEndpoint:
113-
parsed, modelName, err = parseBodyWithModel(raw, func(req *openai.CompletionRequest) string { return req.Model })
114-
case embeddingsEndpoint:
115-
parsed, modelName, err = parseBodyWithModel(raw, func(req *openai.EmbeddingRequest) string { return req.Model })
116-
case imagesGenerationsEndpoint:
117-
parsed, modelName, err = parseBodyWithModel(raw, func(req *openaisdk.ImageGenerateParams) string { return req.Model })
118-
case rerankEndpoint:
119-
parsed, modelName, err = parseBodyWithModel(raw, func(req *cohereschema.RerankV2Request) string { return req.Model })
120-
case messagesEndpoint:
121-
parsed, modelName, err = parseBodyWithModel(raw, func(req *anthropic.MessagesRequest) string { return req.GetModel() })
122-
default:
123-
e.SendLocalReply(500, nil, []byte("BUG: unsupported endpoint at body parsing: "+fmt.Sprintf("%d", f.endpoint)))
124-
}
125-
if err != nil {
126-
e.SendLocalReply(400, nil, []byte("failed to parse request body: "+err.Error()))
127-
return sdk.RequestBodyStatusStopIterationAndBuffer
177+
var maybeMutatedOriginalBodyRaw []byte
178+
f.originalModel, f.originalRequestBody, f.stream, maybeMutatedOriginalBodyRaw, err =
179+
f.ep.ParseBody(raw, len(f.runtimeFilterConfig.RequestCosts) > 0)
180+
if len(maybeMutatedOriginalBodyRaw) > 0 {
181+
f.originalRequestBodyRaw = maybeMutatedOriginalBodyRaw
128182
}
129-
f.originalRequestBody = parsed
130-
if !e.SetRequestHeader(internalapi.ModelNameHeaderKeyDefault, []byte(modelName)) {
183+
if !e.SetRequestHeader(internalapi.ModelNameHeaderKeyDefault, []byte(f.originalModel)) {
131184
e.SendLocalReply(500, nil, []byte("failed to set model name header"))
132185
return sdk.RequestBodyStatusStopIterationAndBuffer
133186
}
134187
// Store the pointer to the filter in dynamic metadata for later retrieval in the upstream filter.
135188
e.SetDynamicMetadataString(internalapi.AIGatewayFilterMetadataNamespace, routerFilterPointerDynamicMetadataKey,
136189
fmt.Sprintf("%d", uintptr(unsafe.Pointer(f))))
137190

138-
f.originalHeaders = multiValueHeadersToSingleValue(e.GetRequestHeaders())
191+
f.originalRequestHeaders = multiValueHeadersToSingleValue(e.GetRequestHeaders())
192+
193+
f.span = f.tracer.StartSpanAndInjectHeaders(
194+
context.Background(),
195+
f.originalRequestHeaders,
196+
&headerMutationCarrier{e: e},
197+
f.originalRequestBody,
198+
f.originalRequestBodyRaw,
199+
)
139200
return sdk.RequestBodyStatusContinue
140201
}
141202

@@ -174,14 +235,6 @@ func (f *routerFilter) handleModelsEndpoint(e sdk.EnvoyHTTPFilter) sdk.RequestHe
174235
return sdk.RequestHeadersStatusStopIteration
175236
}
176237

177-
func parseBodyWithModel[T any](body []byte, modelExtractFn func(req *T) string) (interface{}, string, error) {
178-
var req T
179-
if err := json.Unmarshal(body, &req); err != nil {
180-
return nil, "", fmt.Errorf("failed to unmarshal body: %w", err)
181-
}
182-
return req, modelExtractFn(&req), nil
183-
}
184-
185238
// multiValueHeadersToSingleValue converts a map of headers with multiple values to a map of headers with single values by taking the first value for each header.
186239
//
187240
// TODO: this is purely for feature parity with the old filter where we ignore the case of multiple header values.
@@ -192,3 +245,23 @@ func multiValueHeadersToSingleValue(headers map[string][]string) map[string]stri
192245
}
193246
return singleValueHeaders
194247
}
248+
249+
// headerMutationCarrier implements [propagation.TextMapCarrier].
250+
type headerMutationCarrier struct {
251+
e sdk.EnvoyHTTPFilter
252+
}
253+
254+
// Get implements the same method as defined on propagation.TextMapCarrier.
255+
func (c *headerMutationCarrier) Get(string) string {
256+
panic("unexpected as this carrier is write-only for injection")
257+
}
258+
259+
// Set adds a key-value pair to the HeaderMutation.
260+
func (c *headerMutationCarrier) Set(key, value string) {
261+
_ = c.e.SetResponseHeader(key, []byte(value))
262+
}
263+
264+
// Keys implements the same method as defined on propagation.TextMapCarrier.
265+
func (c *headerMutationCarrier) Keys() []string {
266+
panic("unexpected as this carrier is write-only for injection")
267+
}

internal/dynamicmodule/upstream_filter.go

Lines changed: 0 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -194,71 +194,3 @@ func (f *upstreamFilter) ResponseHeaders(e sdk.EnvoyHTTPFilter, _ bool) sdk.Resp
194194
func (f *upstreamFilter) ResponseBody(sdk.EnvoyHTTPFilter, bool) sdk.ResponseBodyStatus {
195195
return sdk.ResponseBodyStatusContinue
196196
}
197-
198-
func (f *upstreamFilter) initializeTranslatorMetrics(b *filterapi.Backend) error {
199-
out := b.Schema
200-
modelNameOverride := b.ModelNameOverride
201-
switch f.rf.endpoint {
202-
case chatCompletionsEndpoint:
203-
switch out.Name {
204-
case filterapi.APISchemaOpenAI:
205-
f.translator = translator.NewChatCompletionOpenAIToOpenAITranslator(out.Version, modelNameOverride)
206-
case filterapi.APISchemaAWSBedrock:
207-
f.translator = translator.NewChatCompletionOpenAIToAWSBedrockTranslator(modelNameOverride)
208-
case filterapi.APISchemaAzureOpenAI:
209-
f.translator = translator.NewChatCompletionOpenAIToAzureOpenAITranslator(out.Version, modelNameOverride)
210-
case filterapi.APISchemaGCPVertexAI:
211-
f.translator = translator.NewChatCompletionOpenAIToGCPVertexAITranslator(modelNameOverride)
212-
case filterapi.APISchemaGCPAnthropic:
213-
f.translator = translator.NewChatCompletionOpenAIToGCPAnthropicTranslator(out.Version, modelNameOverride)
214-
default:
215-
return fmt.Errorf("unsupported API schema: backend=%s", out)
216-
}
217-
f.metrics = f.env.ChatCompletionMetricsFactory.NewMetrics()
218-
case completionsEndpoint:
219-
switch out.Name {
220-
case filterapi.APISchemaOpenAI:
221-
f.translator = translator.NewChatCompletionOpenAIToOpenAITranslator(out.Version, modelNameOverride)
222-
default:
223-
return fmt.Errorf("unsupported API schema: backend=%s", out)
224-
}
225-
f.metrics = f.env.CompletionMetricsFactory.NewMetrics()
226-
case embeddingsEndpoint:
227-
switch out.Name {
228-
case filterapi.APISchemaOpenAI:
229-
f.translator = translator.NewEmbeddingOpenAIToOpenAITranslator(out.Version, modelNameOverride)
230-
case filterapi.APISchemaAzureOpenAI:
231-
f.translator = translator.NewEmbeddingOpenAIToAzureOpenAITranslator(out.Version, modelNameOverride)
232-
default:
233-
return fmt.Errorf("unsupported API schema: backend=%s", out)
234-
}
235-
f.metrics = f.env.CompletionMetricsFactory.NewMetrics()
236-
case imagesGenerationsEndpoint:
237-
switch out.Name {
238-
case filterapi.APISchemaOpenAI:
239-
f.translator = translator.NewImageGenerationOpenAIToOpenAITranslator(out.Version, modelNameOverride)
240-
default:
241-
return fmt.Errorf("unsupported API schema: backend=%s", out)
242-
}
243-
f.metrics = f.env.CompletionMetricsFactory.NewMetrics()
244-
case rerankEndpoint:
245-
switch out.Name {
246-
case filterapi.APISchemaCohere:
247-
f.translator = translator.NewRerankCohereToCohereTranslator(out.Version, modelNameOverride)
248-
default:
249-
return fmt.Errorf("unsupported API schema: backend=%s", out)
250-
}
251-
f.metrics = f.env.RerankMetricsFactory.NewMetrics()
252-
case messagesEndpoint:
253-
switch out.Name {
254-
case filterapi.APISchemaAnthropic:
255-
f.translator = translator.NewAnthropicToAnthropicTranslator(out.Version, modelNameOverride)
256-
default:
257-
return fmt.Errorf("unsupported API schema: backend=%s", out)
258-
}
259-
f.metrics = f.env.MessagesMetricsFactory.NewMetrics()
260-
default:
261-
return fmt.Errorf("unsupported endpoint for per-route upstream filter: %v", f.rf.endpoint)
262-
}
263-
return nil
264-
}

0 commit comments

Comments
 (0)