Skip to content

Commit 8d040bc

Browse files
committed
mcp: send a tools changed notification when mcp routes change tools
Signed-off-by: Ignasi Barrera <[email protected]>
1 parent d33eec2 commit 8d040bc

File tree

11 files changed

+533
-28
lines changed

11 files changed

+533
-28
lines changed

internal/mcpproxy/handlers.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func (m *MCPProxy) serveGET(w http.ResponseWriter, r *http.Request) {
6767
w.Header().Set("Connection", "keep-alive")
6868
w.Header().Set("transfer-encoding", "chunked")
6969
w.WriteHeader(http.StatusAccepted)
70-
if err := s.streamNotifications(r.Context(), w); err != nil && !errors.Is(err, context.Canceled) {
70+
if err := s.streamNotifications(r.Context(), w, m.toolsChangedChan); err != nil && !errors.Is(err, context.Canceled) {
7171
m.l.Error("failed to collect notifications", slog.String("session_id", sessionID), slog.String("error", err.Error()))
7272
http.Error(w, "failed to collect notifications", http.StatusInternalServerError)
7373
return
@@ -98,6 +98,15 @@ func onErrorResponse(w http.ResponseWriter, status int, msg string) {
9898
_, _ = w.Write([]byte(msg))
9999
}
100100

101+
// doNotForwardResponseToBackends checks whether the given response doesn't need to be forwarded to the backends.
102+
// This is mostly because those are replies to Ping or ToolChange notifications initiated by hte gateway itself
103+
// (not the backends).
104+
func doNotForwardResponseToBackends(msg *jsonrpc.Response) bool {
105+
str, ok := msg.ID.Raw().(string)
106+
return ok && (strings.HasPrefix(str, envoyAIGatewayServerToClientPingRequestIDPrefix) ||
107+
strings.HasPrefix(str, envoyAIGatewayServerToClientToolsChangedRequestIDPrefix))
108+
}
109+
101110
func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) {
102111
var (
103112
ctx = r.Context()
@@ -158,7 +167,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) {
158167

159168
switch msg := rawMsg.(type) {
160169
case *jsonrpc.Response:
161-
if str, ok := msg.ID.Raw().(string); ok && strings.HasPrefix(str, envoyAIGatewayServerToClientPingRequestIDPrefix) {
170+
if doNotForwardResponseToBackends(msg) {
162171
w.Header().Set(sessionIDHeader, string(s.clientGatewaySessionID()))
163172
w.WriteHeader(http.StatusAccepted)
164173
} else {

internal/mcpproxy/handlers_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ func newTestMCPProxyWithTracer(t tracingapi.MCPTracer) *MCPProxy {
4747
sessionCrypto := NewPBKDF2AesGcmSessionCrypto("test", 100)
4848

4949
return &MCPProxy{
50-
sessionCrypto: sessionCrypto,
50+
sessionCrypto: sessionCrypto,
51+
toolsChangedChan: make(chan struct{}, 1),
5152
mcpProxyConfig: &mcpProxyConfig{
5253
backendListenerAddr: "http://test-backend",
5354
routes: map[filterapi.MCPRouteName]*mcpProxyConfigRoute{

internal/mcpproxy/mcpproxy.go

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ import (
1313
"fmt"
1414
"io"
1515
"log/slog"
16+
"maps"
1617
"net/http"
1718
"regexp"
19+
"slices"
1820
"strings"
1921
"sync"
2022
"time"
@@ -32,6 +34,7 @@ type (
3234
// ProxyConfig holds the main MCP proxy configuration.
3335
ProxyConfig struct {
3436
*mcpProxyConfig
37+
toolsChangedChan chan<- struct{} // channel to notify tool changes to clients
3538
}
3639

3740
// MCPProxy serves /mcp endpoint.
@@ -43,6 +46,8 @@ type (
4346
l *slog.Logger
4447
sessionCrypto SessionCrypto
4548
tracer tracing.MCPTracer
49+
50+
toolsChangedChan chan struct{}
4651
}
4752

4853
mcpProxyConfig struct {
@@ -62,28 +67,70 @@ type (
6267
}
6368
)
6469

65-
func (f *toolSelector) allows(tool string) bool {
70+
func (m *mcpProxyConfig) sameTools(other *mcpProxyConfig) bool {
71+
if m == nil || other == nil {
72+
return m == other
73+
}
74+
return maps.EqualFunc(m.routes, other.routes, func(a, b *mcpProxyConfigRoute) bool {
75+
return a.sameTools(b)
76+
})
77+
}
78+
79+
func (m *mcpProxyConfigRoute) sameTools(other *mcpProxyConfigRoute) bool {
80+
if m == nil || other == nil {
81+
return m == other
82+
}
83+
if !equalKeys(m.backends, other.backends) {
84+
return false
85+
}
86+
return maps.EqualFunc(m.toolSelectors, other.toolSelectors, func(a, b *toolSelector) bool {
87+
return a.sameTools(b)
88+
})
89+
}
90+
91+
var sortRegexpAsString = func(a, b *regexp.Regexp) int { return strings.Compare(a.String(), b.String()) }
92+
93+
func equalKeys[K comparable, V any](m1, m2 map[K]V) bool {
94+
return maps.EqualFunc(m1, m2, func(_, _ V) bool { return true })
95+
}
96+
97+
func (t *toolSelector) sameTools(other *toolSelector) bool {
98+
if t == nil || other == nil {
99+
return t == other
100+
}
101+
if !equalKeys(t.include, other.include) {
102+
return false
103+
}
104+
slices.SortFunc(t.includeRegexps, sortRegexpAsString)
105+
slices.SortFunc(other.includeRegexps, sortRegexpAsString)
106+
return slices.EqualFunc(t.includeRegexps, other.includeRegexps,
107+
func(a, b *regexp.Regexp) bool {
108+
return a.String() == b.String()
109+
})
110+
}
111+
112+
func (t *toolSelector) allows(tool string) bool {
66113
// Check include filters - if no filter, allow all; if filter exists, allow only matches
67-
if len(f.include) > 0 {
68-
_, ok := f.include[tool]
114+
if len(t.include) > 0 {
115+
_, ok := t.include[tool]
69116
return ok
70117
}
71-
if len(f.includeRegexps) > 0 {
72-
for _, re := range f.includeRegexps {
118+
if len(t.includeRegexps) > 0 {
119+
for _, re := range t.includeRegexps {
73120
if re.MatchString(tool) {
74121
return true
75122
}
76123
}
77124
return false
78125
}
79-
80126
// No filters, allow all
81127
return true
82128
}
83129

84130
// NewMCPProxy creates a new MCPProxy instance.
85131
func NewMCPProxy(l *slog.Logger, mcpMetrics metrics.MCPMetrics, tracer tracing.MCPTracer, sessionCrypto SessionCrypto) (*ProxyConfig, *http.ServeMux, error) {
86-
cfg := &ProxyConfig{}
132+
toolsChangedChan := make(chan struct{}, 1)
133+
cfg := &ProxyConfig{toolsChangedChan: toolsChangedChan}
87134
mux := http.NewServeMux()
88135
mux.HandleFunc(
89136
// Must match all paths since the route selection happens at Envoy level and the "route" header is already
@@ -93,11 +140,12 @@ func NewMCPProxy(l *slog.Logger, mcpMetrics metrics.MCPMetrics, tracer tracing.M
93140
// with different prefixes will not be matched, which is not what we want.
94141
"/", func(w http.ResponseWriter, r *http.Request) {
95142
proxy := &MCPProxy{
96-
mcpProxyConfig: cfg.mcpProxyConfig,
97-
l: l,
98-
metrics: mcpMetrics.WithRequestAttributes(r),
99-
tracer: tracer,
100-
sessionCrypto: sessionCrypto,
143+
mcpProxyConfig: cfg.mcpProxyConfig,
144+
l: l,
145+
metrics: mcpMetrics.WithRequestAttributes(r),
146+
tracer: tracer,
147+
sessionCrypto: sessionCrypto,
148+
toolsChangedChan: toolsChangedChan,
101149
}
102150

103151
switch r.Method {
@@ -158,7 +206,16 @@ func (p *ProxyConfig) LoadConfig(_ context.Context, config *filterapi.Config) er
158206
newConfig.routes[route.Name] = r
159207
}
160208

209+
toolsChanged := !p.sameTools(newConfig)
161210
p.mcpProxyConfig = newConfig // This is racy, but we don't care.
211+
212+
if toolsChanged {
213+
select {
214+
case p.toolsChangedChan <- struct{}{}:
215+
default: // Ignore if the channel is full.
216+
}
217+
}
218+
162219
return nil
163220
}
164221

0 commit comments

Comments
 (0)