// Copyright Envoy AI Gateway Authors
// SPDX-License-Identifier: Apache-2.0
// The full text of the Apache license is available in the LICENSE file at
// the root of the repo.

package dynamicmodule

import (
	"encoding/json"
	"fmt"
	"io"
	"path"
	"strings"
	"unsafe"

	openaisdk "github.com/openai/openai-go/v2"

	"github.com/envoyproxy/ai-gateway/internal/apischema/anthropic"
	cohereschema "github.com/envoyproxy/ai-gateway/internal/apischema/cohere"
	"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
	"github.com/envoyproxy/ai-gateway/internal/dynamicmodule/sdk"
	"github.com/envoyproxy/ai-gateway/internal/filterapi"
	"github.com/envoyproxy/ai-gateway/internal/internalapi"
)

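// routerFilterPointerDynamicMetadataKey is the dynamic metadata key under which the router filter stores a pointer to itself so that the upstream filter can retrieve it later.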
const routerFilterPointerDynamicMetadataKey = "router_filter_pointer"

type (
	// routerFilterConfig implements [sdk.HTTPFilterConfig].
	//
	// It holds the state shared by all router filter instances: a pointer to the runtime configuration
	// pointer, dereferenced when each filter instance is created, and the mapping from request paths to
	// the endpoints supported by the gateway.
	routerFilterConfig struct {
		fcr **filterapi.RuntimeConfig
		prefixToEndpoint map[string]endpoint
	}
	// routerFilter implements [sdk.HTTPFilter].
	//
	// It resolves the target endpoint from the request path, parses the request body to extract the
	// model name, and keeps the original headers and body around; a pointer to the filter is stored in
	// dynamic metadata so that the upstream filter can pick this state up later.
	routerFilter struct {
		routerFilterConfig *routerFilterConfig
		runtimeFilterConfig *filterapi.RuntimeConfig
		endpoint endpoint
		originalHeaders map[string]string
		originalRequestBody any
		originalRequestBodyRaw []byte
		span any
		attemptCount int
	}
)

// NewRouterFilterConfig creates the [sdk.HTTPFilterConfig] implementation for the router filter.
func NewRouterFilterConfig(env *Env, fcr **filterapi.RuntimeConfig) sdk.HTTPFilterConfig {
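	// Note: despite the name, the keys below are full request paths (not prefixes); RequestHeaders strips any query string from :path and performs an exact lookup in this map.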
	prefixToEndpoint := map[string]endpoint{
		path.Join(env.RootPrefix, env.EndpointPrefixes.OpenAI, "/v1/chat/completions"): chatCompletionsEndpoint,
		path.Join(env.RootPrefix, env.EndpointPrefixes.OpenAI, "/v1/completions"): completionsEndpoint,
		path.Join(env.RootPrefix, env.EndpointPrefixes.OpenAI, "/v1/embeddings"): embeddingsEndpoint,
		path.Join(env.RootPrefix, env.EndpointPrefixes.OpenAI, "/v1/images/generations"): imagesGenerationsEndpoint,
		path.Join(env.RootPrefix, env.EndpointPrefixes.OpenAI, "/v1/models"): modelsEndpoint,
		path.Join(env.RootPrefix, env.EndpointPrefixes.Cohere, "/v2/rerank"): rerankEndpoint,
		path.Join(env.RootPrefix, env.EndpointPrefixes.Anthropic, "/v1/messages"): messagesEndpoint,
	}
	return &routerFilterConfig{
		fcr: fcr,
		prefixToEndpoint: prefixToEndpoint,
	}
}

// NewFilter implements [sdk.HTTPFilterConfig].
func (f *routerFilterConfig) NewFilter() sdk.HTTPFilter {
	return &routerFilter{routerFilterConfig: f, runtimeFilterConfig: *f.fcr}
}

// RequestHeaders implements [sdk.HTTPFilter].
func (f *routerFilter) RequestHeaders(e sdk.EnvoyHTTPFilter, _ bool) sdk.RequestHeadersStatus {
	p, _ := e.GetRequestHeader(":path") // The :path pseudo-header is always present.
	// Strip query parameters before looking up the endpoint.
	if queryIndex := strings.Index(p, "?"); queryIndex != -1 {
		p = p[:queryIndex]
	}
	ep, ok := f.routerFilterConfig.prefixToEndpoint[p]
	if !ok {
		e.SendLocalReply(404, nil, []byte(fmt.Sprintf("unsupported path: %s", p)))
		return sdk.RequestHeadersStatusStopIteration
	}
	f.endpoint = ep
	if f.endpoint == modelsEndpoint {
		return f.handleModelsEndpoint(e)
	}
	return sdk.RequestHeadersStatusContinue
}

// RequestBody implements [sdk.HTTPFilter].
func (f *routerFilter) RequestBody(e sdk.EnvoyHTTPFilter, endOfStream bool) sdk.RequestBodyStatus {
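	// Wait until the entire request body has been buffered before reading and parsing it.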
	if !endOfStream {
		return sdk.RequestBodyStatusStopIterationAndBuffer
	}
	b, ok := e.GetRequestBody()
	if !ok {
		e.SendLocalReply(400, nil, []byte("failed to read request body"))
		return sdk.RequestBodyStatusStopIterationAndBuffer
	}
	raw, err := io.ReadAll(b)
	if err != nil {
		e.SendLocalReply(400, nil, []byte("failed to read request body: "+err.Error()))
		return sdk.RequestBodyStatusStopIterationAndBuffer
	}
	f.originalRequestBodyRaw = raw
	var parsed any
	var modelName string
	switch f.endpoint {
	case chatCompletionsEndpoint:
		parsed, modelName, err = parseBodyWithModel(raw, func(req *openai.ChatCompletionRequest) string { return req.Model })
	case completionsEndpoint:
		parsed, modelName, err = parseBodyWithModel(raw, func(req *openai.CompletionRequest) string { return req.Model })
	case embeddingsEndpoint:
		parsed, modelName, err = parseBodyWithModel(raw, func(req *openai.EmbeddingRequest) string { return req.Model })
	case imagesGenerationsEndpoint:
		parsed, modelName, err = parseBodyWithModel(raw, func(req *openaisdk.ImageGenerateParams) string { return req.Model })
	case rerankEndpoint:
		parsed, modelName, err = parseBodyWithModel(raw, func(req *cohereschema.RerankV2Request) string { return req.Model })
	case messagesEndpoint:
		parsed, modelName, err = parseBodyWithModel(raw, func(req *anthropic.MessagesRequest) string { return req.GetModel() })
	default:
		e.SendLocalReply(500, nil, []byte(fmt.Sprintf("BUG: unsupported endpoint at body parsing: %d", f.endpoint)))
		return sdk.RequestBodyStatusStopIterationAndBuffer
	}
	if err != nil {
		e.SendLocalReply(400, nil, []byte("failed to parse request body: "+err.Error()))
		return sdk.RequestBodyStatusStopIterationAndBuffer
	}
	f.originalRequestBody = parsed
	if !e.SetRequestHeader(internalapi.ModelNameHeaderKeyDefault, []byte(modelName)) {
		e.SendLocalReply(500, nil, []byte("failed to set model name header"))
		return sdk.RequestBodyStatusStopIterationAndBuffer
	}
	// Store the pointer to the filter in dynamic metadata for later retrieval in the upstream filter.
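	// The pointer value is encoded as a decimal string since dynamic metadata values are set as strings.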
	e.SetDynamicMetadataString(internalapi.AIGatewayFilterMetadataNamespace, routerFilterPointerDynamicMetadataKey,
		fmt.Sprintf("%d", uintptr(unsafe.Pointer(f))))

	f.originalHeaders = multiValueHeadersToSingleValue(e.GetRequestHeaders())
	return sdk.RequestBodyStatusContinue
}

// ResponseHeaders implements [sdk.HTTPFilter].
func (f *routerFilter) ResponseHeaders(sdk.EnvoyHTTPFilter, bool) sdk.ResponseHeadersStatus {
	return sdk.ResponseHeadersStatusContinue
}

// ResponseBody implements [sdk.HTTPFilter].
func (f *routerFilter) ResponseBody(sdk.EnvoyHTTPFilter, bool) sdk.ResponseBodyStatus {
	return sdk.ResponseBodyStatusContinue
}

| 153 | +// |
| 154 | +// This is called on request headers phase. |
| 155 | +func (f *routerFilter) handleModelsEndpoint(e sdk.EnvoyHTTPFilter) sdk.RequestHeadersStatus { |
| 156 | + config := f.runtimeFilterConfig |
| 157 | + models := openai.ModelList{ |
| 158 | + Object: "list", |
| 159 | + Data: make([]openai.Model, 0, len(config.DeclaredModels)), |
| 160 | + } |
| 161 | + for _, m := range config.DeclaredModels { |
| 162 | + models.Data = append(models.Data, openai.Model{ |
| 163 | + ID: m.Name, |
| 164 | + Object: "model", |
| 165 | + OwnedBy: m.OwnedBy, |
| 166 | + Created: openai.JSONUNIXTime(m.CreatedAt), |
| 167 | + }) |
| 168 | + } |
| 169 | + |
| 170 | + body, _ := json.Marshal(models) |
| 171 | + e.SendLocalReply(200, [][2]string{ |
| 172 | + {"content-type", "application/json"}, |
| 173 | + }, body) |
| 174 | + return sdk.RequestHeadersStatusStopIteration |
| 175 | +} |
| 176 | + |
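// parseBodyWithModel unmarshals the request body into T and extracts the model name using modelExtractFn.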
func parseBodyWithModel[T any](body []byte, modelExtractFn func(req *T) string) (any, string, error) {
	var req T
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, "", fmt.Errorf("failed to unmarshal body: %w", err)
	}
	return req, modelExtractFn(&req), nil
}

// multiValueHeadersToSingleValue converts multi-valued headers to single-valued ones by keeping only the first value of each header.
//
// TODO: this exists purely for feature parity with the old filter, which ignored any additional values of a multi-valued header.
func multiValueHeadersToSingleValue(headers map[string][]string) map[string]string {
	singleValueHeaders := make(map[string]string, len(headers))
	for k, v := range headers {
		// Guard against an empty value slice to avoid an index-out-of-range panic.
		if len(v) > 0 {
			singleValueHeaders[k] = v[0]
		}
	}
	return singleValueHeaders
}