Skip to content

Commit 1249077

Browse files
authored
Merge branch 'main' into gcp-anthropic
2 parents 30c3e4d + 6ae3615 commit 1249077

26 files changed

+216
-75
lines changed

cmd/aigw/main.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,15 @@ type (
4141
}
4242
// cmdRun corresponds to `aigw run` command.
4343
cmdRun struct {
44-
Debug bool `help:"Enable debug logging emitted to stderr."`
45-
Path string `arg:"" name:"path" optional:"" help:"Path to the AI Gateway configuration yaml file. Defaults to $AIGW_CONFIG_HOME/config.yaml if exists, otherwise optional when at least OPENAI_API_KEY, AZURE_OPENAI_API_KEY or ANTHROPIC_API_KEY is set." type:"path"`
46-
AdminPort int `help:"HTTP port for the admin server (serves /metrics and /health endpoints)." default:"1064"`
47-
McpConfig string `name:"mcp-config" help:"Path to MCP servers configuration file." type:"path"`
48-
McpJSON string `name:"mcp-json" help:"JSON string of MCP servers configuration."`
49-
RunID string `name:"run-id" env:"AIGW_RUN_ID" help:"Run identifier for this invocation. Defaults to timestamp-based ID or $AIGW_RUN_ID. Use '0' for Docker/Kubernetes."`
44+
Debug bool `help:"Enable debug logging emitted to stderr."`
45+
Path string `arg:"" name:"path" optional:"" help:"Path to the AI Gateway configuration yaml file. Defaults to $AIGW_CONFIG_HOME/config.yaml if exists, otherwise optional when at least OPENAI_API_KEY, AZURE_OPENAI_API_KEY or ANTHROPIC_API_KEY is set." type:"path"`
46+
AdminPort int `help:"HTTP port for the admin server (serves /metrics and /health endpoints)." default:"1064"`
47+
McpConfig string `name:"mcp-config" help:"Path to MCP servers configuration file." type:"path"`
48+
McpJSON string `name:"mcp-json" help:"JSON string of MCP servers configuration."`
49+
RunID string `name:"run-id" env:"AIGW_RUN_ID" help:"Run identifier for this invocation. Defaults to timestamp-based ID or $AIGW_RUN_ID. Use '0' for Docker/Kubernetes."`
50+
51+
MCPSessionEncryptionIterations int `name:"mcp-session-encryption-iterations" help:"Number of iterations for MCP session encryption key derivation." default:"100000"`
52+
5053
mcpConfig *autoconfig.MCPServers `kong:"-"` // Internal field: normalized MCP JSON data
5154
dirs *xdg.Directories `kong:"-"` // Internal field: XDG directories, set by BeforeApply
5255
runOpts *runOpts `kong:"-"` // Internal field: run options, set by Validate

cmd/aigw/main_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ Flags:
136136
--run-id=STRING Run identifier for this invocation. Defaults to
137137
timestamp-based ID or $AIGW_RUN_ID. Use '0' for
138138
Docker/Kubernetes ($AIGW_RUN_ID).
139+
--mcp-session-encryption-iterations=100000
140+
Number of iterations for MCP session encryption
141+
key derivation.
139142
`,
140143
expPanicCode: ptr.To(0),
141144
},

cmd/aigw/run.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ type runCmdContext struct {
7878
// fakeClientSet is the fake client set for the k8s resources. The objects are written to this client set and updated
7979
// during the translation.
8080
fakeClientSet *fake.Clientset
81+
// mcpSessionEncryptionIterations is the number of iterations for MCP session encryption key derivation.
82+
mcpSessionEncryptionIterations int
8183
}
8284

8385
// run starts the AI Gateway locally for a given configuration.
@@ -361,6 +363,7 @@ func (runCtx *runCmdContext) mustStartExtProc(
361363
"--extProcAddr", fmt.Sprintf("unix://%s", runCtx.udsPath),
362364
"--adminPort", fmt.Sprintf("%d", runCtx.adminPort),
363365
"--mcpAddr", ":" + strconv.Itoa(internalapi.MCPProxyPort),
366+
"--mcpSessionEncryptionIterations", strconv.Itoa(runCtx.mcpSessionEncryptionIterations),
364367
}
365368
if runCtx.isDebug {
366369
args = append(args, "--logLevel", "debug")

cmd/extproc/mainlib/main.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -250,17 +250,17 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) {
250250
return err
251251
}
252252

253-
server, err := extproc.NewServer(l, tracing)
253+
server, err := extproc.NewServer(l)
254254
if err != nil {
255255
return fmt.Errorf("failed to create external processor server: %w", err)
256256
}
257-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/chat/completions"), extproc.ChatCompletionProcessorFactory(chatCompletionMetricsFactory))
258-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/completions"), extproc.CompletionsProcessorFactory(completionMetricsFactory))
259-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/embeddings"), extproc.EmbeddingsProcessorFactory(embeddingsMetricsFactory))
260-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/images/generations"), extproc.ImageGenerationProcessorFactory(imageGenerationMetricsFactory))
261-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.Cohere, "/v2/rerank"), extproc.RerankProcessorFactory(rerankMetricsFactory))
257+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/chat/completions"), extproc.ChatCompletionProcessorFactory(chatCompletionMetricsFactory, tracing.ChatCompletionTracer()))
258+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/completions"), extproc.CompletionsProcessorFactory(completionMetricsFactory, tracing.CompletionTracer()))
259+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/embeddings"), extproc.EmbeddingsProcessorFactory(embeddingsMetricsFactory, tracing.EmbeddingsTracer()))
260+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/images/generations"), extproc.ImageGenerationProcessorFactory(imageGenerationMetricsFactory, tracing.ImageGenerationTracer()))
261+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.Cohere, "/v2/rerank"), extproc.RerankProcessorFactory(rerankMetricsFactory, tracing.RerankTracer()))
262262
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/models"), extproc.NewModelsProcessor)
263-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.Anthropic, "/v1/messages"), extproc.MessagesProcessorFactory(messagesMetricsFactory))
263+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.Anthropic, "/v1/messages"), extproc.MessagesProcessorFactory(messagesMetricsFactory, tracing.MessageTracer()))
264264

265265
if watchErr := filterapi.StartConfigWatcher(ctx, flags.configPath, server, l, time.Second*5); watchErr != nil {
266266
return fmt.Errorf("failed to start config watcher: %w", watchErr)

internal/extproc/chatcompletion_processor.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ import (
3131
)
3232

3333
// ChatCompletionProcessorFactory returns a factory method to instantiate the chat completion processor.
34-
func ChatCompletionProcessorFactory(f metrics.Factory) ProcessorFactory {
35-
return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) {
34+
func ChatCompletionProcessorFactory(f metrics.Factory, tracer tracing.ChatCompletionTracer) ProcessorFactory {
35+
return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, isUpstreamFilter bool) (Processor, error) {
3636
logger = logger.With("processor", "chat-completion", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter))
3737
if !isUpstreamFilter {
3838
return &chatCompletionProcessorRouterFilter{
3939
config: config,
40-
tracer: tracing.ChatCompletionTracer(),
40+
tracer: tracer,
4141
requestHeaders: requestHeaders,
4242
logger: logger,
4343
}, nil

internal/extproc/chatcompletion_processor_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,14 @@ import (
3333
func TestChatCompletion_Schema(t *testing.T) {
3434
t.Run("supported openai / on route", func(t *testing.T) {
3535
cfg := &filterapi.RuntimeConfig{}
36-
routeFilter, err := ChatCompletionProcessorFactory(nil)(cfg, nil, slog.Default(), tracing.NoopTracing{}, false)
36+
routeFilter, err := ChatCompletionProcessorFactory(nil, tracing.NoopChatCompletionTracer{})(cfg, nil, slog.Default(), false)
3737
require.NoError(t, err)
3838
require.NotNil(t, routeFilter)
3939
require.IsType(t, &chatCompletionProcessorRouterFilter{}, routeFilter)
4040
})
4141
t.Run("supported openai / on upstream", func(t *testing.T) {
4242
cfg := &filterapi.RuntimeConfig{}
43-
routeFilter, err := ChatCompletionProcessorFactory(&mockMetricsFactory{})(cfg, nil, slog.Default(), tracing.NoopTracing{}, true)
43+
routeFilter, err := ChatCompletionProcessorFactory(&mockMetricsFactory{}, tracing.NoopChatCompletionTracer{})(cfg, nil, slog.Default(), true)
4444
require.NoError(t, err)
4545
require.NotNil(t, routeFilter)
4646
require.IsType(t, &chatCompletionProcessorUpstreamFilter{}, routeFilter)

internal/extproc/completions_processor.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ import (
2929
)
3030

3131
// CompletionsProcessorFactory returns a factory method to instantiate the completions processor.
32-
func CompletionsProcessorFactory(f metrics.Factory) ProcessorFactory {
33-
return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) {
32+
func CompletionsProcessorFactory(f metrics.Factory, tracer tracing.CompletionTracer) ProcessorFactory {
33+
return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, isUpstreamFilter bool) (Processor, error) {
3434
logger = logger.With("processor", "completions", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter))
3535
if !isUpstreamFilter {
3636
return &completionsProcessorRouterFilter{
3737
config: config,
38-
tracer: tracing.CompletionTracer(),
38+
tracer: tracer,
3939
requestHeaders: requestHeaders,
4040
logger: logger,
4141
}, nil

internal/extproc/completions_processor_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func TestCompletions_Schema(t *testing.T) {
4949
for _, tt := range tests {
5050
t.Run(tt.name, func(t *testing.T) {
5151
cfg := &filterapi.RuntimeConfig{}
52-
filter, err := CompletionsProcessorFactory(&mockMetricsFactory{})(cfg, nil, slog.Default(), tracing.NoopTracing{}, tt.onUpstream)
52+
filter, err := CompletionsProcessorFactory(&mockMetricsFactory{}, tracing.NoopCompletionTracer{})(cfg, nil, slog.Default(), tt.onUpstream)
5353
require.NoError(t, err)
5454
require.NotNil(t, filter)
5555
require.IsType(t, tt.expectedType, filter)

internal/extproc/embeddings_processor.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ import (
2828
)
2929

3030
// EmbeddingsProcessorFactory returns a factory method to instantiate the embeddings processor.
31-
func EmbeddingsProcessorFactory(f metrics.Factory) ProcessorFactory {
32-
return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) {
31+
func EmbeddingsProcessorFactory(f metrics.Factory, tracer tracing.EmbeddingsTracer) ProcessorFactory {
32+
return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, isUpstreamFilter bool) (Processor, error) {
3333
logger = logger.With("processor", "embeddings", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter))
3434
if !isUpstreamFilter {
3535
return &embeddingsProcessorRouterFilter{
3636
config: config,
37-
tracer: tracing.EmbeddingsTracer(),
37+
tracer: tracer,
3838
requestHeaders: requestHeaders,
3939
logger: logger,
4040
}, nil

internal/extproc/embeddings_processor_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ import (
2929
func TestEmbeddings_Schema(t *testing.T) {
3030
t.Run("supported openai / on route", func(t *testing.T) {
3131
cfg := &filterapi.RuntimeConfig{}
32-
routeFilter, err := EmbeddingsProcessorFactory(nil)(cfg, nil, slog.Default(), tracing.NoopTracing{}, false)
32+
routeFilter, err := EmbeddingsProcessorFactory(nil, tracing.NoopEmbeddingsTracer{})(cfg, nil, slog.Default(), false)
3333
require.NoError(t, err)
3434
require.NotNil(t, routeFilter)
3535
require.IsType(t, &embeddingsProcessorRouterFilter{}, routeFilter)
3636
})
3737
t.Run("supported openai / on upstream", func(t *testing.T) {
3838
cfg := &filterapi.RuntimeConfig{}
39-
routeFilter, err := EmbeddingsProcessorFactory(&mockMetricsFactory{})(cfg, nil, slog.Default(), tracing.NoopTracing{}, true)
39+
routeFilter, err := EmbeddingsProcessorFactory(&mockMetricsFactory{}, tracing.NoopEmbeddingsTracer{})(cfg, nil, slog.Default(), true)
4040
require.NoError(t, err)
4141
require.NotNil(t, routeFilter)
4242
require.IsType(t, &embeddingsProcessorUpstreamFilter{}, routeFilter)

0 commit comments

Comments
 (0)