diff --git a/internal/mcpproxy/config.go b/internal/mcpproxy/config.go
new file mode 100644
index 0000000000..762a6538e0
--- /dev/null
+++ b/internal/mcpproxy/config.go
@@ -0,0 +1,215 @@
+// Copyright Envoy AI Gateway Authors
+// SPDX-License-Identifier: Apache-2.0
+// The full text of the Apache license is available in the LICENSE file at
+// the root of the repo.
+
+package mcpproxy
+
+import (
+	"context"
+	"fmt"
+	"maps"
+	"regexp"
+	"slices"
+	"strings"
+	"sync"
+
+	"github.com/envoyproxy/ai-gateway/internal/filterapi"
+)
+
+type (
+	// ProxyConfig holds the main MCP proxy configuration.
+	ProxyConfig struct {
+		*mcpProxyConfig
+		toolChangeSignaler changeSignaler // signals tool changes to active sessions.
+	}
+
+	// mcpProxyConfig is the configuration snapshot built by LoadConfig.
+	mcpProxyConfig struct {
+		backendListenerAddr string
+		routes              map[filterapi.MCPRouteName]*mcpProxyConfigRoute // route name -> backends of that route.
+	}
+
+	// mcpProxyConfigRoute holds the backends of one route and their optional tool selectors.
+	mcpProxyConfigRoute struct {
+		backends      map[filterapi.MCPBackendName]filterapi.MCPBackend
+		toolSelectors map[filterapi.MCPBackendName]*toolSelector
+	}
+
+	// toolSelector filters tools using include patterns with exact matches or regular expressions.
+	toolSelector struct {
+		include        map[string]struct{}
+		includeRegexps []*regexp.Regexp
+	}
+
+	// changeSignaler is an interface for signaling configuration changes to multiple
+	// watchers.
+	changeSignaler interface {
+		// Watch returns a channel that is closed when the configuration changes.
+		// The channel should be obtained by calling this method every time when used in a loop,
+		// because it will be closed and recreated after each signal is sent.
+		Watch() <-chan struct{}
+		// Signal all watchers that the configuration has changed.
+		Signal()
+	}
+
+	// multiWatcherSignaler is a changeSignaler that broadcasts by closing a shared channel.
+	multiWatcherSignaler struct {
+		mu     sync.Mutex
+		notify chan struct{}
+	}
+)
+
+// sameTools reports whether both configurations expose the same routes, backend names and
+// tool selectors. A nil configuration is only equal to another nil configuration.
+func (m *mcpProxyConfig) sameTools(other *mcpProxyConfig) bool {
+	if m == nil || other == nil {
+		return m == other
+	}
+	return maps.EqualFunc(m.routes, other.routes, func(a, b *mcpProxyConfigRoute) bool {
+		return a.sameTools(b)
+	})
+}
+
+// sameTools reports whether both routes have the same backend names and equivalent tool
+// selectors. Only the backend names are compared, not the backend definitions themselves.
+func (m *mcpProxyConfigRoute) sameTools(other *mcpProxyConfigRoute) bool {
+	if m == nil || other == nil {
+		return m == other
+	}
+	if !equalKeys(m.backends, other.backends) {
+		return false
+	}
+	return maps.EqualFunc(m.toolSelectors, other.toolSelectors, func(a, b *toolSelector) bool {
+		return a.sameTools(b)
+	})
+}
+
+// sortRegexpAsString orders regular expressions by their source text.
+var sortRegexpAsString = func(a, b *regexp.Regexp) int { return strings.Compare(a.String(), b.String()) }
+
+// equalKeys reports whether both maps contain exactly the same keys, ignoring the values.
+func equalKeys[K comparable, V any](m1, m2 map[K]V) bool {
+	return maps.EqualFunc(m1, m2, func(_, _ V) bool { return true })
+}
+
+// sameTools reports whether both selectors include the same exact tool names and the same
+// regular expressions, regardless of the order in which the expressions were configured.
+func (t *toolSelector) sameTools(other *toolSelector) bool {
+	if t == nil || other == nil {
+		return t == other
+	}
+	if !equalKeys(t.include, other.include) {
+		return false
+	}
+	// Sort copies: either selector may belong to the live configuration, whose regexps are
+	// iterated by allows while requests are served, so they must not be reordered in place.
+	mine := slices.Clone(t.includeRegexps)
+	slices.SortFunc(mine, sortRegexpAsString)
+	theirs := slices.Clone(other.includeRegexps)
+	slices.SortFunc(theirs, sortRegexpAsString)
+	return slices.EqualFunc(mine, theirs, func(a, b *regexp.Regexp) bool {
+		return a.String() == b.String()
+	})
+}
+
+// allows reports whether the selector permits the given tool name. When exact names are
+// configured only those are consulted; otherwise the regular expressions are tried, and a
+// selector without any filter allows every tool.
+func (t *toolSelector) allows(tool string) bool {
+	// Check include filters - if no filter, allow all; if filter exists, allow only matches
+	if len(t.include) > 0 {
+		_, ok := t.include[tool]
+		return ok
+	}
+	if len(t.includeRegexps) > 0 {
+		for _, re := range t.includeRegexps {
+			if re.MatchString(tool) {
+				return true
+			}
+		}
+		return false
+	}
+	// No filters, allow all
+	return true
+}
+
+// LoadConfig implements [extproc.ConfigReceiver.LoadConfig] which will be called
+// when the configuration is updated on the file system.
+//
+// A configuration without an MCP section is ignored and keeps the current state. When the
+// new configuration changes the routes, backends or tool selectors, all watchers are signaled.
+func (p *ProxyConfig) LoadConfig(_ context.Context, config *filterapi.Config) error {
+	if config == nil || config.MCPConfig == nil {
+		return nil
+	}
+	mcpConfig := config.MCPConfig
+
+	// Talk to the backend MCP listener on the local Envoy instance.
+	newConfig := &mcpProxyConfig{backendListenerAddr: mcpConfig.BackendListenerAddr}
+
+	// Build a map of routes to backends.
+	// Each route has its own set of backends. For a given downstream request,
+	// the MCP proxy initializes sessions only with the backends tied to that route.
+	newConfig.routes = make(map[filterapi.MCPRouteName]*mcpProxyConfigRoute, len(mcpConfig.Routes))
+
+	for _, route := range mcpConfig.Routes {
+		r := &mcpProxyConfigRoute{
+			backends:      make(map[filterapi.MCPBackendName]filterapi.MCPBackend, len(route.Backends)),
+			toolSelectors: make(map[filterapi.MCPBackendName]*toolSelector, len(route.Backends)),
+		}
+		for _, backend := range route.Backends {
+			r.backends[backend.Name] = backend
+			if s := backend.ToolSelector; s != nil {
+				ts := &toolSelector{
+					include: make(map[string]struct{}),
+				}
+				for _, tool := range s.Include {
+					ts.include[tool] = struct{}{}
+				}
+				for _, expr := range s.IncludeRegex {
+					re, err := regexp.Compile(expr)
+					if err != nil {
+						return fmt.Errorf("failed to compile include regex %q for backend %q in route %q: %w", expr, backend.Name, route.Name, err)
+					}
+					ts.includeRegexps = append(ts.includeRegexps, re)
+				}
+				r.toolSelectors[backend.Name] = ts
+			}
+		}
+		newConfig.routes[route.Name] = r
+	}
+
+	toolsChanged := !p.sameTools(newConfig)
+	p.mcpProxyConfig = newConfig // This is racy, but we don't care.
+	// A ProxyConfig built without a signaler has nobody to notify.
+	if toolsChanged && p.toolChangeSignaler != nil {
+		p.toolChangeSignaler.Signal()
+	}
+
+	return nil
+}
+
+// newMultiWatcherSignaler creates a new multi-watcher signaler.
+func newMultiWatcherSignaler() *multiWatcherSignaler {
+	return &multiWatcherSignaler{
+		notify: make(chan struct{}),
+	}
+}
+
+// Watch returns a channel that is closed when the configuration changes.
+// A fired channel stays closed, so callers that wait in a loop must call Watch again after
+// each signal to obtain the channel that will be closed by the next one.
+func (m *multiWatcherSignaler) Watch() <-chan struct{} {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	return m.notify
+}
+
+// Signal notifies all watchers of a configuration change.
+func (m *multiWatcherSignaler) Signal() {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	close(m.notify)                // Wake everyone
+	m.notify = make(chan struct{}) // Create a new channel for future updates
+}
diff --git a/internal/mcpproxy/config_test.go b/internal/mcpproxy/config_test.go
new file mode 100644
index 0000000000..bbb1dbd0db
--- /dev/null
+++ b/internal/mcpproxy/config_test.go
@@ -0,0 +1,395 @@
+// Copyright Envoy AI Gateway Authors
+// SPDX-License-Identifier: Apache-2.0
+// The full text of the Apache license is available in the LICENSE file at
+// the root of the repo.
+
+package mcpproxy
+
+import (
+	"log/slog"
+	"regexp"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/envoyproxy/ai-gateway/internal/filterapi"
+)
+
+// Test_toolSelector_Allows verifies exact-name and regexp include filters, and that no filter allows all.
+func Test_toolSelector_Allows(t *testing.T) {
+	reBa := regexp.MustCompile("^ba.*")
+	tests := []struct {
+		name     string
+		selector toolSelector
+		tools    []string
+		want     []bool
+	}{
+		{
+			name:     "no rules allows all",
+			selector: toolSelector{},
+			tools:    []string{"foo", "bar"},
+			want:     []bool{true, true},
+		},
+		{
+			name:     "include specific tool",
+			selector: toolSelector{include: map[string]struct{}{"foo": {}}},
+			tools:    []string{"foo", "bar"},
+			want:     []bool{true, false},
+		},
+		{
+			name:     "include regexp",
+			selector: toolSelector{includeRegexps: []*regexp.Regexp{reBa}},
+			tools:    []string{"bar", "foo"},
+			want:     []bool{true, false},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			for i, tool := range tt.tools {
+				got := tt.selector.allows(tool)
+				require.Equalf(t, tt.want[i], got, "tool: %s", tool)
+			}
+		})
+	}
+}
+
+// TestLoadConfig_NilMCPConfig verifies that a configuration without an MCP section is accepted.
+func TestLoadConfig_NilMCPConfig(t *testing.T) {
+	proxy, _, err := NewMCPProxy(slog.Default(), stubMetrics{}, noopTracer, NewPBKDF2AesGcmSessionCrypto("test", 100))
+	require.NoError(t, err)
+
+	config := &filterapi.Config{MCPConfig: nil}
+
+	err = proxy.LoadConfig(t.Context(), config)
+	require.NoError(t, err)
+}
+
+// TestLoadConfig_BasicConfiguration verifies that routes, backends and tool selectors are loaded.
+func TestLoadConfig_BasicConfiguration(t *testing.T) {
+	proxy := &ProxyConfig{
+		mcpProxyConfig:     &mcpProxyConfig{},
+		toolChangeSignaler: newMultiWatcherSignaler(),
+	}
+
+	config := &filterapi.Config{
+		MCPConfig: &filterapi.MCPConfig{
+			BackendListenerAddr: "http://localhost:8080",
+			Routes: []filterapi.MCPRoute{
+				{
+					Name: "route1",
+					Backends: []filterapi.MCPBackend{
+						{Name: "backend1", Path: "/mcp1"},
+						{
+							Name: "backend2", Path: "/mcp2",
+							ToolSelector: &filterapi.MCPToolSelector{
+								Include:      []string{"tool1", "tool2"},
+								IncludeRegex: []string{"^test.*"},
+							},
+						},
+					},
+				},
+				{
+					Name: "route2",
+					Backends: []filterapi.MCPBackend{
+						{Name: "backend3", Path: "/mcp3"},
+						{Name: "backend4", Path: "/mcp4"},
+					},
+				},
+			},
+		},
+	}
+
+	err := proxy.LoadConfig(t.Context(), config)
+	require.NoError(t, err)
+	require.Equal(t, "http://localhost:8080", proxy.backendListenerAddr)
+	require.Len(t, proxy.routes, 2)
+	require.Contains(t, proxy.routes, filterapi.MCPRouteName("route1"))
+	require.Contains(t, proxy.routes, filterapi.MCPRouteName("route2"))
+	require.Len(t, proxy.routes["route1"].backends, 2)
+	require.Len(t, proxy.routes["route2"].backends, 2)
+	require.Contains(t, proxy.routes["route1"].backends, filterapi.MCPBackendName("backend1"))
+	require.Contains(t, proxy.routes["route1"].backends, filterapi.MCPBackendName("backend2"))
+	require.Contains(t, proxy.routes["route2"].backends, filterapi.MCPBackendName("backend3"))
+	require.Contains(t, proxy.routes["route2"].backends, filterapi.MCPBackendName("backend4"))
+	selector := proxy.routes["route1"].toolSelectors["backend2"]
+	require.NotNil(t, selector)
+	require.Contains(t, selector.include, "tool1")
+	require.Contains(t, selector.include, "tool2")
+	require.Len(t, selector.includeRegexps, 1)
+	require.True(t, selector.includeRegexps[0].MatchString("test123"))
+	require.False(t, selector.includeRegexps[0].MatchString("other"))
+}
+
+// TestLoadConfig_ToolsChangedNotification verifies that adding a backend signals the watchers.
+func TestLoadConfig_ToolsChangedNotification(t *testing.T) {
+	toolChangeSignaler := newMultiWatcherSignaler()
+	watcher := toolChangeSignaler.Watch()
+
+	// Initialize proxy with initial configuration directly
+	proxy := &ProxyConfig{
+		mcpProxyConfig: &mcpProxyConfig{
+			backendListenerAddr: "http://localhost:8080",
+			routes: map[filterapi.MCPRouteName]*mcpProxyConfigRoute{
+				"route1": {
+					backends: map[filterapi.MCPBackendName]filterapi.MCPBackend{
+						"backend1": {Name: "backend1", Path: "/mcp1"},
+					},
+					toolSelectors: map[filterapi.MCPBackendName]*toolSelector{},
+				},
+			},
+		},
+		toolChangeSignaler: toolChangeSignaler,
+	}
+
+	// Update with a different backend (tools changed)
+	config := &filterapi.Config{
+		MCPConfig: &filterapi.MCPConfig{
+			BackendListenerAddr: "http://localhost:8080",
+			Routes: []filterapi.MCPRoute{
+				{
+					Name: "route1",
+					Backends: []filterapi.MCPBackend{
+						{Name: "backend1", Path: "/mcp1"},
+						{Name: "backend2", Path: "/mcp2"}, // Added backend
+					},
+				},
+			},
+		},
+	}
+
+	err := proxy.LoadConfig(t.Context(), config)
+	require.NoError(t, err)
+
+	// Should receive tools changed notification
+	select {
+	case <-watcher:
+		// Expected
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("expected tools changed notification but didn't receive one")
+	}
+}
+
+// TestLoadConfig_NoToolsChangedNotification verifies that changing only the listener address does not signal.
+func TestLoadConfig_NoToolsChangedNotification(t *testing.T) {
+	toolChangeSignaler := newMultiWatcherSignaler()
+	watcher := toolChangeSignaler.Watch()
+
+	// Initialize proxy with initial configuration directly
+	proxy := &ProxyConfig{
+		mcpProxyConfig: &mcpProxyConfig{
+			backendListenerAddr: "http://localhost:8080",
+			routes: map[filterapi.MCPRouteName]*mcpProxyConfigRoute{
+				"route1": {
+					backends: map[filterapi.MCPBackendName]filterapi.MCPBackend{
+						"backend1": {Name: "backend1", Path: "/mcp1"},
+					},
+					toolSelectors: map[filterapi.MCPBackendName]*toolSelector{},
+				},
+			},
+		},
+		toolChangeSignaler: toolChangeSignaler,
+	}
+
+	// Update with same backends but different BackendListenerAddr (tools NOT changed)
+	config := &filterapi.Config{
+		MCPConfig: &filterapi.MCPConfig{
+			BackendListenerAddr: "http://localhost:9090", // Different address
+			Routes: []filterapi.MCPRoute{
+				{
+					Name: "route1",
+					Backends: []filterapi.MCPBackend{
+						{Name: "backend1", Path: "/mcp1"}, // Same backend
+					},
+				},
+			},
+		},
+	}
+
+	err := proxy.LoadConfig(t.Context(), config)
+	require.NoError(t, err)
+
+	// Should NOT receive tools changed notification
+	select {
+	case <-watcher:
+		t.Fatal("unexpected tools changed notification")
+	case <-time.After(100 * time.Millisecond):
+		// Expected - no notification
+	}
+}
+
+// TestLoadConfig_InvalidRegex verifies that an include regex that does not compile is rejected.
+func TestLoadConfig_InvalidRegex(t *testing.T) {
+	proxy := &ProxyConfig{
+		mcpProxyConfig:     &mcpProxyConfig{},
+		toolChangeSignaler: newMultiWatcherSignaler(),
+	}
+
+	config := &filterapi.Config{
+		MCPConfig: &filterapi.MCPConfig{
+			BackendListenerAddr: "http://localhost:8080",
+			Routes: []filterapi.MCPRoute{
+				{
+					Name: "route1",
+					Backends: []filterapi.MCPBackend{
+						{
+							Name: "backend1",
+							Path: "/mcp1",
+							ToolSelector: &filterapi.MCPToolSelector{
+								IncludeRegex: []string{"[invalid"}, // Invalid regex
+							},
+						},
+					},
+				},
+			},
+		},
+	}
+
+	err := proxy.LoadConfig(t.Context(), config)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "failed to compile include regex")
+}
+
+// TestLoadConfig_ToolSelectorChange verifies that a tool selector change wakes up every watcher.
+func TestLoadConfig_ToolSelectorChange(t *testing.T) {
+	toolChangeSignaler := newMultiWatcherSignaler()
+	watcher := toolChangeSignaler.Watch()
+
+	// Initialize proxy with initial configuration directly
+	proxy := &ProxyConfig{
+		mcpProxyConfig: &mcpProxyConfig{
+			backendListenerAddr: "http://localhost:8080",
+			routes: map[filterapi.MCPRouteName]*mcpProxyConfigRoute{
+				"route1": {
+					backends: map[filterapi.MCPBackendName]filterapi.MCPBackend{
+						"backend1": {Name: "backend1", Path: "/mcp1"},
+					},
+					toolSelectors: map[filterapi.MCPBackendName]*toolSelector{
+						"backend1": {
+							include: map[string]struct{}{"tool1": {}},
+						},
+					},
+				},
+			},
+		},
+		toolChangeSignaler: toolChangeSignaler,
+	}
+
+	// Update with different tool selector (tools changed)
+	config := &filterapi.Config{
+		MCPConfig: &filterapi.MCPConfig{
+			BackendListenerAddr: "http://localhost:8080",
+			Routes: []filterapi.MCPRoute{
+				{
+					Name: "route1",
+					Backends: []filterapi.MCPBackend{
+						{
+							Name: "backend1",
+							Path: "/mcp1",
+							ToolSelector: &filterapi.MCPToolSelector{
+								Include: []string{"tool1", "tool2"}, // Different tools
+							},
+						},
+					},
+				},
+			},
+		},
+	}
+
+	// Start watcher goroutines to make sure all of them are notified
+	var wg sync.WaitGroup
+	for range 3 {
+		wg.Go(func() {
+			select {
+			case <-watcher: // Expected
+			case <-time.After(100 * time.Millisecond):
+				// t.Error instead of t.Fatal: FailNow must only be called from the test goroutine.
+				t.Error("expected tools changed notification but didn't receive one")
+			}
+		})
+	}
+
+	err := proxy.LoadConfig(t.Context(), config)
+	require.NoError(t, err)
+
+	wg.Wait()
+}
+
+// TestLoadConfig_ToolOrderDoesNotMatter verifies that reordering tools and regexps does not signal.
+func TestLoadConfig_ToolOrderDoesNotMatter(t *testing.T) {
+	toolChangeSignaler := newMultiWatcherSignaler()
+	watcher := toolChangeSignaler.Watch()
+
+	// Initialize proxy with initial configuration directly
+	proxy := &ProxyConfig{
+		mcpProxyConfig: &mcpProxyConfig{
+			backendListenerAddr: "http://localhost:8080",
+			routes: map[filterapi.MCPRouteName]*mcpProxyConfigRoute{
+				"route1": {
+					backends: map[filterapi.MCPBackendName]filterapi.MCPBackend{
+						"backend1": {Name: "backend1", Path: "/mcp1"},
+					},
+					toolSelectors: map[filterapi.MCPBackendName]*toolSelector{
+						"backend1": {
+							include: map[string]struct{}{
+								"tool-a": {},
+								"tool-b": {},
+								"tool-c": {},
+							},
+							includeRegexps: []*regexp.Regexp{
+								regexp.MustCompile("^prefix.*"),
+								regexp.MustCompile(".*suffix$"),
+								regexp.MustCompile("^exact$"),
+							},
+						},
+					},
+				},
+			},
+		},
+		toolChangeSignaler: toolChangeSignaler,
+	}
+
+	// Update with same tools and regexps but in different order
+	config := &filterapi.Config{
+		MCPConfig: &filterapi.MCPConfig{
+			BackendListenerAddr: "http://localhost:8080",
+			Routes: []filterapi.MCPRoute{
+				{
+					Name: "route1",
+					Backends: []filterapi.MCPBackend{
+						{
+							Name: "backend1",
+							Path: "/mcp1",
+							ToolSelector: &filterapi.MCPToolSelector{
+								Include:      []string{"tool-c", "tool-a", "tool-b"},        // Different order
+								IncludeRegex: []string{"^exact$", ".*suffix$", "^prefix.*"}, // Different order
+							},
+						},
+					},
+				},
+			},
+		},
+	}
+
+	err := proxy.LoadConfig(t.Context(), config)
+	require.NoError(t, err)
+
+	// Should NOT receive tools changed notification since same tools, just different order
+	select {
+	case <-watcher:
+		t.Fatal("unexpected tools changed notification when only order changed")
+	case <-time.After(100 * time.Millisecond):
+		// Expected - no notification
+	}
+
+	// Verify the tool selector still works correctly regardless of order
+	route := proxy.routes["route1"]
+	require.NotNil(t, route)
+	selector := route.toolSelectors["backend1"]
+	require.NotNil(t, selector)
+	require.Contains(t, selector.include, "tool-a")
+	require.Contains(t, selector.include, "tool-b")
+	require.Contains(t, selector.include, "tool-c")
+	require.Len(t, selector.includeRegexps, 3)
+}
diff --git a/internal/mcpproxy/handlers.go b/internal/mcpproxy/handlers.go
index 34504cf1ab..828cbe3a38 100644
--- a/internal/mcpproxy/handlers.go
+++ b/internal/mcpproxy/handlers.go
@@ -67,7 +67,7 @@ func (m *MCPProxy) serveGET(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Connection", "keep-alive")
 	w.Header().Set("transfer-encoding", "chunked")
 	w.WriteHeader(http.StatusAccepted)
-	if err := s.streamNotifications(r.Context(), w); err != nil && !errors.Is(err, context.Canceled) {
+	if err := s.streamNotifications(r.Context(), w, m.toolChangeSignaler); err != nil && !errors.Is(err, context.Canceled) {
 		m.l.Error("failed to collect notifications", slog.String("session_id", sessionID), slog.String("error", err.Error()))
 		http.Error(w, "failed to collect notifications", http.StatusInternalServerError)
 		return
@@ -98,6 +98,15 @@ func onErrorResponse(w http.ResponseWriter, status int, msg string) {
 	_, _ = w.Write([]byte(msg))
 }
 
+// doNotForwardResponseToBackends checks whether the given response doesn't need to be forwarded to the backends.
+// This is mostly because those are replies to Ping or ToolChange notifications initiated by the gateway itself
+// (not the backends).
+func doNotForwardResponseToBackends(msg *jsonrpc.Response) bool {
+	str, ok := msg.ID.Raw().(string)
+	return ok && (strings.HasPrefix(str, envoyAIGatewayServerToClientPingRequestIDPrefix) ||
+		strings.HasPrefix(str, envoyAIGatewayServerToClientToolsChangedRequestIDPrefix))
+}
+
 func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) {
 	var (
 		ctx = r.Context()
@@ -158,7 +167,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) {
 
 	switch msg := rawMsg.(type) {
 	case *jsonrpc.Response:
-		if str, ok := msg.ID.Raw().(string); ok && strings.HasPrefix(str, envoyAIGatewayServerToClientPingRequestIDPrefix) {
+		if doNotForwardResponseToBackends(msg) {
 			w.Header().Set(sessionIDHeader, string(s.clientGatewaySessionID()))
 			w.WriteHeader(http.StatusAccepted)
 		} else {
diff --git a/internal/mcpproxy/handlers_test.go b/internal/mcpproxy/handlers_test.go
index f7f362dce2..f207181189 100644
--- a/internal/mcpproxy/handlers_test.go
+++ b/internal/mcpproxy/handlers_test.go
@@ -47,7 +47,8 @@ func newTestMCPProxyWithTracer(t tracingapi.MCPTracer) *MCPProxy {
 	sessionCrypto := NewPBKDF2AesGcmSessionCrypto("test", 100)
 
 	return &MCPProxy{
-		sessionCrypto: sessionCrypto,
+		sessionCrypto:      sessionCrypto,
+		toolChangeSignaler: newMultiWatcherSignaler(),
 		mcpProxyConfig: &mcpProxyConfig{
 			backendListenerAddr: "http://test-backend",
 			routes: map[filterapi.MCPRouteName]*mcpProxyConfigRoute{
diff --git a/internal/mcpproxy/mcpproxy.go b/internal/mcpproxy/mcpproxy.go
index 818d3918e6..35fdc6c07d 100644
--- a/internal/mcpproxy/mcpproxy.go
+++ b/internal/mcpproxy/mcpproxy.go
@@ -14,7 +14,6 @@ import (
 	"io"
 	"log/slog"
 	"net/http"
-	"regexp"
 	"strings"
 	"sync"
 	"time"
@@ -28,62 +27,22 @@ import (
 	tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api"
 )
 
-type (
-	// ProxyConfig holds the main MCP proxy configuration.
-	ProxyConfig struct {
-		*mcpProxyConfig
-	}
-
-	// MCPProxy serves /mcp endpoint.
-	//
-	// This implements [extproc.ConfigReceiver] to gets the up-to-date configuration.
-	MCPProxy struct {
-		*mcpProxyConfig
-		metrics       metrics.MCPMetrics
-		l             *slog.Logger
-		sessionCrypto SessionCrypto
-		tracer        tracing.MCPTracer
-	}
-
-	mcpProxyConfig struct {
-		backendListenerAddr string
-		routes              map[filterapi.MCPRouteName]*mcpProxyConfigRoute // route name -> backends of that route.
-	}
-
-	mcpProxyConfigRoute struct {
-		backends      map[filterapi.MCPBackendName]filterapi.MCPBackend
-		toolSelectors map[filterapi.MCPBackendName]*toolSelector
-	}
-
-	// toolSelector filters tools using include patterns with exact matches or regular expressions.
-	toolSelector struct {
-		include        map[string]struct{}
-		includeRegexps []*regexp.Regexp
-	}
-)
-
-func (f *toolSelector) allows(tool string) bool {
-	// Check include filters - if no filter, allow all; if filter exists, allow only matches
-	if len(f.include) > 0 {
-		_, ok := f.include[tool]
-		return ok
-	}
-	if len(f.includeRegexps) > 0 {
-		for _, re := range f.includeRegexps {
-			if re.MatchString(tool) {
-				return true
-			}
-		}
-		return false
-	}
-
-	// No filters, allow all
-	return true
+// MCPProxy serves /mcp endpoint.
+//
+// This implements [extproc.ConfigReceiver] to gets the up-to-date configuration.
+type MCPProxy struct {
+	*mcpProxyConfig
+	metrics            metrics.MCPMetrics
+	l                  *slog.Logger
+	sessionCrypto      SessionCrypto
+	tracer             tracing.MCPTracer
+	toolChangeSignaler changeSignaler // signals tool changes to active sessions.
 }
 
 // NewMCPProxy creates a new MCPProxy instance.
 func NewMCPProxy(l *slog.Logger, mcpMetrics metrics.MCPMetrics, tracer tracing.MCPTracer, sessionCrypto SessionCrypto) (*ProxyConfig, *http.ServeMux, error) {
-	cfg := &ProxyConfig{}
+	toolChangeSignaler := newMultiWatcherSignaler() // one signaler shared by the config and every per-request proxy.
+	cfg := &ProxyConfig{toolChangeSignaler: toolChangeSignaler}
 	mux := http.NewServeMux()
 	mux.HandleFunc(
 		// Must match all paths since the route selection happens at Envoy level and the "route" header is already
@@ -93,11 +52,12 @@ func NewMCPProxy(l *slog.Logger, mcpMetrics metrics.MCPMetrics, tracer tracing.M
 		// with different prefixes will not be matched, which is not what we want.
 		"/", func(w http.ResponseWriter, r *http.Request) {
 			proxy := &MCPProxy{
-				mcpProxyConfig: cfg.mcpProxyConfig,
-				l:              l,
-				metrics:        mcpMetrics.WithRequestAttributes(r),
-				tracer:         tracer,
-				sessionCrypto:  sessionCrypto,
+				mcpProxyConfig:     cfg.mcpProxyConfig,
+				l:                  l,
+				metrics:            mcpMetrics.WithRequestAttributes(r),
+				tracer:             tracer,
+				sessionCrypto:      sessionCrypto,
+				toolChangeSignaler: toolChangeSignaler, // same instance as cfg, so config reloads reach this request.
 			}
 
 			switch r.Method {
@@ -114,54 +74,6 @@ func NewMCPProxy(l *slog.Logger, mcpMetrics metrics.MCPMetrics, tracer tracing.M
 	return cfg, mux, nil
 }
 
-// LoadConfig implements [extproc.ConfigReceiver.LoadConfig] which will be called
-// when the configuration is updated on the file system.
-func (p *ProxyConfig) LoadConfig(_ context.Context, config *filterapi.Config) error {
-	newConfig := &mcpProxyConfig{}
-	mcpConfig := config.MCPConfig
-	if config.MCPConfig == nil {
-		return nil
-	}
-
-	// Talk to the backend MCP listener on the local Envoy instance.
-	newConfig.backendListenerAddr = mcpConfig.BackendListenerAddr
-
-	// Build a map of routes to backends.
-	// Each route has its own set of backends. For a given downstream request,
-	// the MCP proxy initializes sessions only with the backends tied to that route.
-	newConfig.routes = make(map[filterapi.MCPRouteName]*mcpProxyConfigRoute, len(mcpConfig.Routes))
-
-	for _, route := range mcpConfig.Routes {
-		r := &mcpProxyConfigRoute{
-			backends:      make(map[filterapi.MCPBackendName]filterapi.MCPBackend, len(route.Backends)),
-			toolSelectors: make(map[filterapi.MCPBackendName]*toolSelector, len(route.Backends)),
-		}
-		for _, backend := range route.Backends {
-			r.backends[backend.Name] = backend
-			if s := backend.ToolSelector; s != nil {
-				ts := &toolSelector{
-					include: make(map[string]struct{}),
-				}
-				for _, tool := range s.Include {
-					ts.include[tool] = struct{}{}
-				}
-				for _, expr := range s.IncludeRegex {
-					re, err := regexp.Compile(expr)
-					if err != nil {
-						return fmt.Errorf("failed to compile include regex %q for backend %q in route %q: %w", expr, backend.Name, route.Name, err)
-					}
-					ts.includeRegexps = append(ts.includeRegexps, re)
-				}
-				r.toolSelectors[backend.Name] = ts
-			}
-		}
-		newConfig.routes[route.Name] = r
-	}
-
-	p.mcpProxyConfig = newConfig // This is racy, but we don't care.
-	return nil
-}
-
 // newSession creates a new session for a downstream client.
 // It multiplexes the initialize request to all backends defined in the MCPRoute associated with the downstream request.
func (m *MCPProxy) newSession(ctx context.Context, p *mcp.InitializeParams, routeName filterapi.MCPRouteName, subject string, span tracing.MCPSpan) (*session, error) { diff --git a/internal/mcpproxy/mcpproxy_test.go b/internal/mcpproxy/mcpproxy_test.go index 724b02e8fd..1274abdf99 100644 --- a/internal/mcpproxy/mcpproxy_test.go +++ b/internal/mcpproxy/mcpproxy_test.go @@ -11,7 +11,6 @@ import ( "log/slog" "net/http" "net/http/httptest" - "regexp" "sync" "testing" @@ -83,16 +82,6 @@ func TestMCPProxy_HTTPMethods(t *testing.T) { require.Contains(t, rr.Body.String(), "method not allowed") } -func TestLoadConfig_NilMCPConfig(t *testing.T) { - proxy, _, err := NewMCPProxy(slog.Default(), stubMetrics{}, noopTracer, NewPBKDF2AesGcmSessionCrypto("test", 100)) - require.NoError(t, err) - - config := &filterapi.Config{MCPConfig: nil} - - err = proxy.LoadConfig(t.Context(), config) - require.NoError(t, err) -} - const ( validInitializeResponse = `{ "jsonrpc": "2.0", @@ -352,40 +341,3 @@ func TestInvokeJSONRPCRequest_NoSessionID(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) require.NoError(t, resp.Body.Close()) } - -func Test_toolSelector_Allows(t *testing.T) { - reBa := regexp.MustCompile("^ba.*") - tests := []struct { - name string - selector toolSelector - tools []string - want []bool - }{ - { - name: "no rules allows all", - selector: toolSelector{}, - tools: []string{"foo", "bar"}, - want: []bool{true, true}, - }, - { - name: "include specific tool", - selector: toolSelector{include: map[string]struct{}{"foo": {}}}, - tools: []string{"foo", "bar"}, - want: []bool{true, false}, - }, - { - name: "include regexp", - selector: toolSelector{includeRegexps: []*regexp.Regexp{reBa}}, - tools: []string{"bar", "foo"}, - want: []bool{true, false}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - for i, tool := range tt.tools { - got := tt.selector.allows(tool) - require.Equalf(t, tt.want[i], got, "tool: %s", tool) - } - }) - } -} diff 
--git a/internal/mcpproxy/session.go b/internal/mcpproxy/session.go index c37ef9fdee..6335e6ca04 100644 --- a/internal/mcpproxy/session.go +++ b/internal/mcpproxy/session.go @@ -157,8 +157,10 @@ func (s *session) lastEventID() string { } var ( - envoyAIGatewayServerToClientPingRequestIDPrefix = "aigw-server-to-client-ping" - pingParam, _ = json.Marshal(&mcpsdk.PingParams{}) + envoyAIGatewayServerToClientPingRequestIDPrefix = "aigw-server-to-client-ping" + envoyAIGatewayServerToClientToolsChangedRequestIDPrefix = "aigw-server-to-client-tools-changed" + pingParam, _ = json.Marshal(&mcpsdk.PingParams{}) + toolsChangedParam, _ = json.Marshal(&mcpsdk.ToolListChangedParams{}) ) func newHeartBeatPingMessage() *jsonrpc.Request { @@ -171,12 +173,21 @@ func newHeartBeatPingMessage() *jsonrpc.Request { } } +func newToolListChangedMessage() *jsonrpc.Request { + id, _ := jsonrpc.MakeID(envoyAIGatewayServerToClientToolsChangedRequestIDPrefix + uuid.NewString()) + return &jsonrpc.Request{ + ID: id, + Method: "notifications/tools/list_changed", + Params: toolsChangedParam, + } +} + // streamNotifications streams notifications from all backends in this session to the given writer. -func (s *session) streamNotifications(ctx context.Context, w http.ResponseWriter) error { +func (s *session) streamNotifications(ctx context.Context, w http.ResponseWriter, toolChangeSignaler changeSignaler) error { backendMsgs := s.sendToAllBackends(ctx, http.MethodGet, nil, nil) // Create a ticker for periodic heartbeat events to avoid HTTP timeouts. - // This also helps unblock Goose at startup — it looks like Goose is waiting for the first SSE event before proceeding. + // This also helps unblock Goose at startup - it looks like Goose is waiting for the first SSE event before proceeding. // // TODO: no idea exactly why this is necessary. Goose shouldn't block on the first event. 
@@ -198,6 +209,7 @@ func (s *session) streamNotifications(ctx context.Context, w http.ResponseWriter for { select { + // events received from the upstream MCP backends case event, ok := <-backendMsgs: if !ok { // Channel closed, all backends have finished. @@ -228,6 +240,16 @@ func (s *session) streamNotifications(ctx context.Context, w http.ResponseWriter if heartbeatTicker != nil { heartbeatTicker.Reset(heartbeatInterval) } + // toolChangeSignaler will trigger when the tools configured in the MCP routes change. + // This is not related to upstream MCP server changes, but to the tools configured in the gateway. + case <-toolChangeSignaler.Watch(): + toolChangeEvent := &sseEvent{event: "message", messages: []jsonrpc.Message{newToolListChangedMessage()}} + toolChangeEvent.writeAndMaybeFlush(w) + // Reset the heartbeat ticker so that the next heartbeat will be sent after the full interval. + // This avoids sending heartbeats too frequently when there are events. + if heartbeatTicker != nil { + heartbeatTicker.Reset(heartbeatInterval) + } case <-heartbeats: heartBeatEvent := &sseEvent{event: "message", messages: []jsonrpc.Message{newHeartBeatPingMessage()}} heartBeatEvent.writeAndMaybeFlush(w) diff --git a/internal/mcpproxy/session_test.go b/internal/mcpproxy/session_test.go index 337e77968f..2e850100a1 100644 --- a/internal/mcpproxy/session_test.go +++ b/internal/mcpproxy/session_test.go @@ -228,7 +228,7 @@ func TestSession_StreamNotifications(t *testing.T) { rr := httptest.NewRecorder() ctx, cancel := context.WithTimeout(t.Context(), tc.deadline) defer cancel() - err2 := s.streamNotifications(ctx, rr) + err2 := s.streamNotifications(ctx, rr, proxy.toolChangeSignaler) require.NoError(t, err2) out := rr.Body.String() require.Contains(t, out, "event: a1") @@ -244,6 +244,62 @@ func TestSession_StreamNotifications(t *testing.T) { } } +func TestNotifyToolsChanged(t *testing.T) { + var ( + reloadConfig atomic.Bool + proxy = newTestMCPProxy() + cfg = ProxyConfig{ + 
toolChangeSignaler: proxy.toolChangeSignaler, + mcpProxyConfig: proxy.mcpProxyConfig, + } + s = &session{ + proxy: proxy, + route: "test-route", + perBackendSessions: map[filterapi.MCPBackendName]*compositeSessionEntry{ + "backend1": {sessionID: "s1"}, + }, + } + ) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + // if the test wants to reload config, trigger it once the stream is open, to better simulate + // changes when there is an active streaming session. + // wait a bit and trigger the config change. + if reloadConfig.Load() { + time.Sleep(50 * time.Millisecond) + require.NoError(t, cfg.LoadConfig(t.Context(), + // Clear all the routes -> should trigger a tools changed notification. + &filterapi.Config{MCPConfig: &filterapi.MCPConfig{}}), + ) + } + })) + proxy.backendListenerAddr = srv.URL + + t.Run("no tool changes by default", func(t *testing.T) { + rr := httptest.NewRecorder() + ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond) + t.Cleanup(cancel) + err := s.streamNotifications(ctx, rr, proxy.toolChangeSignaler) + require.NoError(t, err) + out := rr.Body.String() + require.NotContains(t, out, `"id":"`+envoyAIGatewayServerToClientToolsChangedRequestIDPrefix) + require.NotContains(t, out, `"method":"notifications/tools/list_changed"`) + }) + + t.Run("notify tools changed", func(t *testing.T) { + reloadConfig.Store(true) + rr := httptest.NewRecorder() + ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond) + t.Cleanup(cancel) + err := s.streamNotifications(ctx, rr, proxy.toolChangeSignaler) + require.NoError(t, err) + out := rr.Body.String() + require.Contains(t, out, `"id":"`+envoyAIGatewayServerToClientToolsChangedRequestIDPrefix) + require.Contains(t, out, `"method":"notifications/tools/list_changed"`) + }) +} + func TestSendRequestPerBackend_ErrorStatus(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusInternalServerError) diff --git a/tests/bench/bench_test.go b/tests/bench/bench_test.go index 61b45dfde3..fb31ac6908 100644 --- a/tests/bench/bench_test.go +++ b/tests/bench/bench_test.go @@ -49,7 +49,7 @@ func setupBenchmark(b *testing.B) []MCPBenchCase { b.Helper() // Treat this as a helper function // setup MCP server - mcpServer := testmcp.NewServer(&testmcp.Options{ + mcpServer, _ := testmcp.NewServer(&testmcp.Options{ Port: mcpServerPort, ForceJSONResponse: false, DumbEchoServer: true, diff --git a/tests/extproc/mcp/env.go b/tests/extproc/mcp/env.go index e45d867f94..306cadf32f 100644 --- a/tests/extproc/mcp/env.go +++ b/tests/extproc/mcp/env.go @@ -12,6 +12,7 @@ import ( "fmt" "io" "maps" + "net/http" "sync" "testing" "time" @@ -37,6 +38,7 @@ var envoyConfig string // TODO: move this to testmcp package so that we could reuse the same tests in the end-to-end tests. type mcpEnv struct { client *mcp.Client + mcp1, mcp2 *mcp.Server mux sync.Mutex extProcMetricsURL string baseURL string @@ -55,6 +57,7 @@ type mcpEnv struct { type mcpSession struct { session *mcp.ClientSession // TODO: merge them into one chan for simplicity? 
+ toolListChangedNotifications chan *mcp.ToolListChangedRequest progressNotifications chan *mcp.ProgressNotificationClientRequest promptListChangedNotifications chan *mcp.PromptListChangedRequest resourceUpdatedNotifications chan *mcp.ResourceUpdatedNotificationRequest @@ -109,15 +112,17 @@ func requireNewMCPEnv(t *testing.T, forceJSONResponse bool, writeTimeout time.Du config, err := json.Marshal(filterapi.Config{MCPConfig: mcpConfig}) require.NoError(t, err) + var mcp1, mcp2 *mcp.Server env := testenvironment.StartTestEnvironment(t, func(_ testing.TB, _ io.Writer, ports map[string]int) { - srv1 := testmcp.NewServer(&testmcp.Options{ + var srv1, srv2 *http.Server + srv1, mcp1 = testmcp.NewServer(&testmcp.Options{ Port: ports["ts1"], ForceJSONResponse: forceJSONResponse, DumbEchoServer: false, WriteTimeout: writeTimeout, }) - srv2 := testmcp.NewServer(&testmcp.Options{ + srv2, mcp2 = testmcp.NewServer(&testmcp.Options{ Port: ports["ts2"], ForceJSONResponse: forceJSONResponse, DumbEchoServer: true, @@ -133,12 +138,23 @@ func requireNewMCPEnv(t *testing.T, forceJSONResponse bool, writeTimeout time.Du ) m := new(mcpEnv) + m.mcp1, m.mcp2 = mcp1, mcp2 m.collector = collector m.writeTimeout = writeTimeout m.extProcMetricsURL = fmt.Sprintf("http://localhost:%d/metrics", env.ExtProcAdminPort()) m.baseURL = fmt.Sprintf("http://localhost:%d%s", env.EnvoyListenerPort(), path) m.client = mcp.NewClient(&mcp.Implementation{Name: "demo-http-client", Version: "0.1.0"}, &mcp.ClientOptions{ + ToolListChangedHandler: func(_ context.Context, request *mcp.ToolListChangedRequest) { + m.mux.Lock() + defer m.mux.Unlock() + if sess, ok := m.sessions[request.GetSession().ID()]; ok { + t.Log("received tool list change for session ", request.GetSession().ID(), ": ", request.Params) + sess.toolListChangedNotifications <- request + } else { + t.Fatalf("received tool list change for unknown session ID %q", request.GetSession().ID()) + } + }, // TODO: this is due to how the official go-sdk is 
designed. Notification is a per-session concept but // they force the handler to be per-client, which resulted in forcing us to do this multiplexing here. ProgressNotificationHandler: func(_ context.Context, request *mcp.ProgressNotificationClientRequest) { @@ -236,6 +252,7 @@ func requireNewMCPEnv(t *testing.T, forceJSONResponse bool, writeTimeout time.Du // newSession creates a new MCP client session and registers it for progress notifications. func (m *mcpEnv) newSession(t *testing.T) *mcpSession { ret := &mcpSession{ + toolListChangedNotifications: make(chan *mcp.ToolListChangedRequest, 100), progressNotifications: make(chan *mcp.ProgressNotificationClientRequest, 100), promptListChangedNotifications: make(chan *mcp.PromptListChangedRequest, 100), resourceUpdatedNotifications: make(chan *mcp.ResourceUpdatedNotificationRequest, 100), @@ -245,6 +262,7 @@ func (m *mcpEnv) newSession(t *testing.T) *mcpSession { elicitRequests: make(chan *mcp.ElicitRequest, 100), } var err error + ret.session, err = m.client.Connect(t.Context(), &mcp.StreamableClientTransport{Endpoint: m.baseURL}, nil) require.NoError(t, err) span := m.collector.TakeSpan() @@ -263,6 +281,7 @@ func (m *mcpEnv) newSession(t *testing.T) *mcpSession { m.mux.Lock() defer m.mux.Unlock() delete(m.sessions, ret.session.ID()) + close(ret.toolListChangedNotifications) close(ret.progressNotifications) close(ret.promptListChangedNotifications) close(ret.resourceUpdatedNotifications) diff --git a/tests/extproc/mcp/mcp_test.go b/tests/extproc/mcp/mcp_test.go index c280d3ed96..a207a9d8e0 100644 --- a/tests/extproc/mcp/mcp_test.go +++ b/tests/extproc/mcp/mcp_test.go @@ -45,6 +45,7 @@ var tests = []struct { {name: "ToolCallDumbEcho", testFn: testToolCallDumbEcho}, {name: "ToolCallError", testFn: testToolCallError}, {name: "ToolCountDown", testFn: testToolCountDown}, + {name: "ToolChangeNotifications", testFn: testToolChangeNotifications}, {name: "Ping", testFn: testPing}, {name: "LoggingSetLevel", testFn: 
testLoggingSetLevel}, {name: "ListPrompts", testFn: testListPrompts}, @@ -135,6 +136,32 @@ func testListToolsRequireOnlyDumb(t *testing.T, m *mcpEnv) { }) } +func testToolChangeNotifications(t *testing.T, m *mcpEnv) { + s := m.newSession(t) + + requireToolChangedNotification := func(t *testing.T) { + var req *mcp.ToolListChangedRequest + select { + case req = <-s.toolListChangedNotifications: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for tool list change notification") + } + require.NotNil(t, req) + require.NotNil(t, req.Params) + require.IsTypef(t, &mcp.ToolListChangedParams{}, req.Params, "expected ToolListChangedParams, got %T", req.Params) + } + + t.Run("tool add", func(t *testing.T) { + mcp.AddTool(m.mcp1, testmcp.ToolDumbEcho.Tool, testmcp.ToolDumbEcho.Handler) + requireToolChangedNotification(t) + }) + + t.Run("tool remove", func(t *testing.T) { + m.mcp1.RemoveTools(testmcp.ToolDumbEcho.Tool.Name) + requireToolChangedNotification(t) + }) +} + func testToolCall(t *testing.T, m *mcpEnv) { s := m.newSession(t) diff --git a/tests/internal/testmcp/server.go b/tests/internal/testmcp/server.go index c8f811205b..c767639861 100644 --- a/tests/internal/testmcp/server.go +++ b/tests/internal/testmcp/server.go @@ -38,7 +38,7 @@ type Options struct { // When dumbEchoServer is true, the server will only implement the echo tool, // and will not implement any prompts or resources. This is useful for testing // basic routing. 
-func NewServer(opts *Options) *http.Server { +func NewServer(opts *Options) (*http.Server, *mcp.Server) { if opts.DumbEchoServer { return newDumbServer(opts.Port) } @@ -143,10 +143,10 @@ func NewServer(opts *Options) *http.Server { log.Fatalf("server error: %v", err) } }() - return server + return server, s } -func newDumbServer(port int) *http.Server { +func newDumbServer(port int) (*http.Server, *mcp.Server) { s := mcp.NewServer( &mcp.Implementation{Name: "dumb-echo-server", Version: "0.1.0"}, &mcp.ServerOptions{HasTools: true}, @@ -161,5 +161,5 @@ func newDumbServer(port int) *http.Server { log.Fatalf("server error: %v", err) } }() - return server + return server, s } diff --git a/tests/internal/testmcp/testmcpserver/main.go b/tests/internal/testmcp/testmcpserver/main.go index 8f51951faf..150473c0b5 100644 --- a/tests/internal/testmcp/testmcpserver/main.go +++ b/tests/internal/testmcp/testmcpserver/main.go @@ -40,8 +40,9 @@ func doMain() *http.Server { if err != nil { logger.Fatalf("invalid port: %v", err) } - return testmcp.NewServer(&testmcp.Options{ + server, _ := testmcp.NewServer(&testmcp.Options{ Port: port, WriteTimeout: 1200 * time.Second, }) + return server }