Skip to content

Commit 95fce20

Browse files
committed
init dynamic modules
Signed-off-by: Takeshi Yoneda <[email protected]>
1 parent 288c118 commit 95fce20

File tree

8 files changed

+1193
-0
lines changed

8 files changed

+1193
-0
lines changed

Makefile

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,11 @@ build.%: ## Build a binary for the given command under the internal/cmd director
270270
build: ## Build all binaries under cmd/ directory.
271271
@$(foreach COMMAND_NAME,$(COMMANDS),$(MAKE) build.$(COMMAND_NAME);)
272272

273+
# This builds the dynamic module filter for Envoy. This is the shared library that can be loaded by Envoy to run the AI Gateway filter.
274+
.PHONE: build-dm
275+
build-dm: ## Build the dynamic module for Envoy.
276+
CGO_ENABLED=1 go build -tags "envoy_1.36" -buildmode=c-shared -o $(OUTPUT_DIR)/libaigateway.so ./internal/dynamic_module
277+
273278
# This builds the docker images for the controller, extproc and testupstream for the e2e tests.
274279
.PHONY: build-e2e
275280
build-e2e: ## Build the docker images for the controller, extproc and testupstream for the e2e tests.

internal/dynamic_module/main.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log/slog"
7+
"os"
8+
"time"
9+
10+
"github.com/envoyproxy/ai-gateway/internal/backendauth"
11+
"github.com/envoyproxy/ai-gateway/internal/dynamic_module/sdk"
12+
"github.com/envoyproxy/ai-gateway/internal/filterapi"
13+
)
14+
15+
func main() {} // This must be present to make a shared library.
16+
17+
// Set the envoy.NewHTTPFilter function to create a new http filter.
18+
func init() {
19+
sdk.NewHTTPFilterConfig = newHTTPFilterConfig
20+
ec := newEnvConfig()
21+
22+
// TODO: use a writer implemented with the Logger ABI of Envoy.
23+
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
24+
Level: slog.LevelDebug, // Adjust log level from environment variable if needed.
25+
}))
26+
if err := filterapi.StartConfigWatcher(context.Background(), ec.filterConfigPath, filterConfigReceiver, logger, time.Second*5); err != nil {
27+
panic("failed to start filter config watcher: " + err.Error())
28+
}
29+
}
30+
31+
// newHTTPFilter creates a new http filter based on the config.
32+
//
33+
// `config` is the configuration string that is specified in the Envoy configuration.
34+
func newHTTPFilterConfig(name string, _ []byte) sdk.HTTPFilterConfig {
35+
switch name {
36+
case "ai_gateway.router":
37+
return newRouterFilterConfig(filterConfigReceiver)
38+
case "ai_gateway.upstream":
39+
return &upstreamFilterConfig{}
40+
default:
41+
panic("unknown filter: " + name)
42+
}
43+
}
44+
45+
var filterConfigReceiver *filterConfigReceiverImpl
46+
47+
// filterConfigReceiverImpl implements [filterapi.ConfigReceiver] to load filter configuration.
48+
type filterConfigReceiverImpl struct {
49+
fc *filterapi.RuntimeConfig
50+
}
51+
52+
// LoadConfig implements [filterapi.ConfigReceiver.LoadConfig].
53+
func (l *filterConfigReceiverImpl) LoadConfig(ctx context.Context, config *filterapi.Config) error {
54+
newConfig, err := filterapi.NewRuntimeConfig(ctx, config, backendauth.NewHandler)
55+
if err != nil {
56+
return fmt.Errorf("cannot create runtime filter config: %w", err)
57+
}
58+
l.fc = newConfig // This is racy but we don't care.
59+
return nil
60+
}
61+
62+
type envConfig struct {
63+
filterConfigPath string
64+
// TODO: log level.
65+
}
66+
67+
func newEnvConfig() *envConfig {
68+
return &envConfig{
69+
filterConfigPath: os.Getenv("AI_GATEWAY_DYNAMIC_MODULE_FILTER_CONFIG_PATH"),
70+
}
71+
}
72+
73+
// endpoint represents the type of the endpoint that the request is targeting.
74+
type endpoint int
75+
76+
const (
77+
// chatCompletionsEndpoint represents the /v1/chat/completions endpoint.
78+
chatCompletionsEndpoint endpoint = iota
79+
// completionsEndpoint represents the /v1/completions endpoint.
80+
completionsEndpoint
81+
// embeddingsEndpoint represents the /v1/embeddings endpoint.
82+
embeddingsEndpoint
83+
// imagesGenerationsEndpoint represents the /v1/images/generations endpoint.
84+
imagesGenerationsEndpoint
85+
// rerankEndpoint represents the /v2/rerank endpoint of cohere.
86+
rerankEndpoint
87+
// messagesEndpoint represents the /v1/messages endpoint of anthropic.
88+
messagesEndpoint
89+
// modelsEndpoint represents the /v1/models endpoint.
90+
modelsEndpoint
91+
)
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
package main
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"io"
7+
"strings"
8+
"unsafe"
9+
10+
"github.com/envoyproxy/ai-gateway/internal/apischema/anthropic"
11+
cohereschema "github.com/envoyproxy/ai-gateway/internal/apischema/cohere"
12+
"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
13+
"github.com/envoyproxy/ai-gateway/internal/dynamic_module/sdk"
14+
"github.com/envoyproxy/ai-gateway/internal/internalapi"
15+
openaisdk "github.com/openai/openai-go/v2"
16+
)
17+
18+
const routerFilterPointerDynamicMetadataKey = "router_filter_pointer"
19+
20+
type (
21+
// routerFilterConfig implements [sdk.HTTPFilterConfig].
22+
//
23+
// This is mostly for debugging purposes, it does not do anything except
24+
// setting a response header with the version of the dynamic module.
25+
routerFilterConfig struct {
26+
fcr *filterConfigReceiverImpl
27+
models openai.ModelList
28+
}
29+
// routerFilter implements [sdk.HTTPFilter].
30+
routerFilter struct {
31+
fc *routerFilterConfig
32+
endpoint endpoint
33+
originalRequestBody interface{}
34+
originalRequestBodyRaw []byte
35+
}
36+
37+
requestBodyParserFn func(body []byte) (parsed interface{}, modelName string, err error)
38+
)
39+
40+
func newRouterFilterConfig(fcr *filterConfigReceiverImpl) *routerFilterConfig {
41+
config := fcr.fc
42+
models := openai.ModelList{
43+
Object: "list",
44+
Data: make([]openai.Model, 0, len(config.DeclaredModels)),
45+
}
46+
for _, m := range config.DeclaredModels {
47+
models.Data = append(models.Data, openai.Model{
48+
ID: m.Name,
49+
Object: "model",
50+
OwnedBy: m.OwnedBy,
51+
Created: openai.JSONUNIXTime(m.CreatedAt),
52+
})
53+
}
54+
return &routerFilterConfig{fcr: fcr, models: models}
55+
}
56+
57+
// NewFilter implements [sdk.HTTPFilterConfig].
58+
func (f *routerFilterConfig) NewFilter() sdk.HTTPFilter {
59+
return &routerFilter{fc: f}
60+
}
61+
62+
// RequestHeaders implements [sdk.HTTPFilter].
63+
func (f *routerFilter) RequestHeaders(e sdk.EnvoyHTTPFilter, _ bool) sdk.RequestHeadersStatus {
64+
p, _ := e.GetRequestHeader(":path") // The :path pseudo header is always present.
65+
// Strip query parameters for processor lookup.
66+
if queryIndex := strings.Index(p, "?"); queryIndex != -1 {
67+
p = p[:queryIndex]
68+
}
69+
// TODO: prefix config.
70+
switch p {
71+
case "/v1/chat/completions":
72+
f.endpoint = chatCompletionsEndpoint
73+
return sdk.RequestHeadersStatusContinue
74+
case "/v1/completions":
75+
f.endpoint = completionsEndpoint
76+
return sdk.RequestHeadersStatusContinue
77+
case "/v1/embeddings":
78+
f.endpoint = embeddingsEndpoint
79+
return sdk.RequestHeadersStatusContinue
80+
case "/v1/images/generations":
81+
f.endpoint = imagesGenerationsEndpoint
82+
return sdk.RequestHeadersStatusContinue
83+
case "/cohere/v2/rerank":
84+
f.endpoint = rerankEndpoint
85+
return sdk.RequestHeadersStatusContinue
86+
case "/anthropic/v1/messages":
87+
f.endpoint = messagesEndpoint
88+
return sdk.RequestHeadersStatusContinue
89+
case "/v1/models":
90+
return f.handleModelsEndpoint(e)
91+
default:
92+
e.SendLocalReply(404, nil, []byte(fmt.Sprintf("unsupported path: %s", p)))
93+
return sdk.RequestHeadersStatusStopIteration
94+
}
95+
}
96+
97+
// RequestBody implements [sdk.HTTPFilter].
98+
func (f *routerFilter) RequestBody(e sdk.EnvoyHTTPFilter, endOfStream bool) sdk.RequestBodyStatus {
99+
if !endOfStream {
100+
return sdk.RequestBodyStatusStopIterationAndBuffer
101+
}
102+
b, ok := e.GetRequestBody()
103+
if !ok {
104+
e.SendLocalReply(400, nil, []byte("failed to read request body"))
105+
return sdk.RequestBodyStatusStopIterationAndBuffer
106+
}
107+
raw, err := io.ReadAll(b)
108+
if err != nil {
109+
e.SendLocalReply(400, nil, []byte("failed to read request body: "+err.Error()))
110+
return sdk.RequestBodyStatusStopIterationAndBuffer
111+
}
112+
f.originalRequestBodyRaw = raw
113+
var parserFn requestBodyParserFn
114+
switch f.endpoint {
115+
case chatCompletionsEndpoint:
116+
parserFn = chatCompletionsBodyParser
117+
case completionsEndpoint:
118+
parserFn = completionsBodyParser
119+
case embeddingsEndpoint:
120+
parserFn = embeddingsBodyParser
121+
case imagesGenerationsEndpoint:
122+
parserFn = imagesGenerationsBodyParser
123+
case rerankEndpoint:
124+
parserFn = rerankBodyParser
125+
case messagesEndpoint:
126+
parserFn = messagesBodyParser
127+
default:
128+
e.SendLocalReply(500, nil, []byte("BUG: unsupported endpoint at body parsing: "+fmt.Sprintf("%d", f.endpoint)))
129+
}
130+
parsed, modelName, err := parserFn(raw)
131+
if err != nil {
132+
e.SendLocalReply(400, nil, []byte("failed to parse request body: "+err.Error()))
133+
return sdk.RequestBodyStatusStopIterationAndBuffer
134+
}
135+
f.originalRequestBody = parsed
136+
if !e.SetRequestHeader(internalapi.ModelNameHeaderKeyDefault, []byte(modelName)) {
137+
e.SendLocalReply(500, nil, []byte("failed to set model name header"))
138+
return sdk.RequestBodyStatusStopIterationAndBuffer
139+
}
140+
// Store the pointer to the filter in dynamic metadata for later retrieval in the upstream filter.
141+
e.SetDynamicMetadataString(internalapi.AIGatewayFilterMetadataNamespace, routerFilterPointerDynamicMetadataKey,
142+
fmt.Sprintf("%d", uintptr(unsafe.Pointer(f))))
143+
return sdk.RequestBodyStatusContinue
144+
}
145+
146+
// ResponseHeaders implements [sdk.HTTPFilter].
147+
func (f *routerFilter) ResponseHeaders(sdk.EnvoyHTTPFilter, bool) sdk.ResponseHeadersStatus {
148+
return sdk.ResponseHeadersStatusContinue
149+
}
150+
151+
// ResponseBody implements [sdk.HTTPFilter].
152+
func (f *routerFilter) ResponseBody(sdk.EnvoyHTTPFilter, bool) sdk.ResponseBodyStatus {
153+
return sdk.ResponseBodyStatusContinue
154+
}
155+
156+
// handleModelsEndpoint handles the /v1/models endpoint by returning the list of declared models in the filter configuration.
157+
//
158+
// This is called on request headers phase.
159+
func (f *routerFilter) handleModelsEndpoint(e sdk.EnvoyHTTPFilter) sdk.RequestHeadersStatus {
160+
body, _ := json.Marshal(f.fc.models)
161+
e.SendLocalReply(200, [][2]string{
162+
{"content-type", "application/json"},
163+
}, body)
164+
return sdk.RequestHeadersStatusStopIteration
165+
}
166+
167+
func chatCompletionsBodyParser(body []byte) (interface{}, string, error) {
168+
var req openai.ChatCompletionRequest
169+
if err := json.Unmarshal(body, &req); err != nil {
170+
return nil, "", fmt.Errorf("failed to unmarshal body: %w", err)
171+
}
172+
return req, req.Model, nil
173+
}
174+
175+
func completionsBodyParser(body []byte) (interface{}, string, error) {
176+
var req openai.CompletionRequest
177+
if err := json.Unmarshal(body, &req); err != nil {
178+
return nil, "", fmt.Errorf("failed to unmarshal body: %w", err)
179+
}
180+
return req, req.Model, nil
181+
}
182+
183+
func embeddingsBodyParser(body []byte) (interface{}, string, error) {
184+
var req openai.EmbeddingRequest
185+
if err := json.Unmarshal(body, &req); err != nil {
186+
return nil, "", fmt.Errorf("failed to unmarshal body: %w", err)
187+
}
188+
return req, req.Model, nil
189+
}
190+
191+
func imagesGenerationsBodyParser(body []byte) (interface{}, string, error) {
192+
var req openaisdk.ImageGenerateParams
193+
if err := json.Unmarshal(body, &req); err != nil {
194+
return nil, "", fmt.Errorf("failed to unmarshal body: %w", err)
195+
}
196+
return req, req.Model, nil
197+
}
198+
199+
func rerankBodyParser(body []byte) (interface{}, string, error) {
200+
var req cohereschema.RerankV2Request
201+
if err := json.Unmarshal(body, &req); err != nil {
202+
return nil, "", fmt.Errorf("failed to unmarshal body: %w", err)
203+
}
204+
return req, req.Model, nil
205+
}
206+
207+
func messagesBodyParser(body []byte) (interface{}, string, error) {
208+
var anthropicReq anthropic.MessagesRequest
209+
if err := json.Unmarshal(body, &anthropicReq); err != nil {
210+
return nil, "", fmt.Errorf("failed to unmarshal body: %w", err)
211+
}
212+
return anthropicReq, anthropicReq.GetModel(), nil
213+
}

0 commit comments

Comments
 (0)