Skip to content

Commit f289aef

Browse files
two fixes
* add context aware stdio transport (ctrl-c works again) * add ServersSessions to subscribe/unsubscribe handlers
1 parent 00e34f9 commit f289aef

File tree

18 files changed

+325
-78
lines changed

18 files changed

+325
-78
lines changed

cmd/docker-mcp/internal/gateway/capabilitites.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func (g *Gateway) listCapabilities(ctx context.Context, configuration Configurat
7777
}
7878
capabilities.Tools = append(capabilities.Tools, ToolRegistration{
7979
Tool: tool,
80-
Handler: g.mcpServerToolHandler(*serverConfig, tool.Annotations),
80+
Handler: g.mcpServerToolHandler(*serverConfig, g.mcpServer, tool.Annotations),
8181
})
8282
}
8383
}
@@ -87,7 +87,7 @@ func (g *Gateway) listCapabilities(ctx context.Context, configuration Configurat
8787
for _, prompt := range prompts.Prompts {
8888
capabilities.Prompts = append(capabilities.Prompts, PromptRegistration{
8989
Prompt: prompt,
90-
Handler: g.mcpServerPromptHandler(*serverConfig),
90+
Handler: g.mcpServerPromptHandler(*serverConfig, g.mcpServer),
9191
})
9292
}
9393
}
@@ -97,7 +97,7 @@ func (g *Gateway) listCapabilities(ctx context.Context, configuration Configurat
9797
for _, resource := range resources.Resources {
9898
capabilities.Resources = append(capabilities.Resources, ResourceRegistration{
9999
Resource: resource,
100-
Handler: g.mcpServerResourceHandler(*serverConfig),
100+
Handler: g.mcpServerResourceHandler(*serverConfig, g.mcpServer),
101101
})
102102
}
103103
}
@@ -107,7 +107,7 @@ func (g *Gateway) listCapabilities(ctx context.Context, configuration Configurat
107107
for _, resourceTemplate := range resourceTemplates.ResourceTemplates {
108108
capabilities.ResourceTemplates = append(capabilities.ResourceTemplates, ResourceTemplateRegistration{
109109
ResourceTemplate: *resourceTemplate,
110-
Handler: g.mcpServerResourceHandler(*serverConfig),
110+
Handler: g.mcpServerResourceHandler(*serverConfig, g.mcpServer),
111111
})
112112
}
113113
}

cmd/docker-mcp/internal/gateway/clientpool.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type clientPool struct {
3434
type clientConfig struct {
3535
readOnly *bool
3636
serverSession *mcp.ServerSession
37+
server *mcp.Server
3738
}
3839

3940
func newClientPool(options Options, docker docker.Client) *clientPool {
@@ -383,12 +384,14 @@ func (cg *clientGetter) GetClient(ctx context.Context) (mcpclient.Client, error)
383384
// Use the original context instead of creating a timeout context
384385
// to avoid cancellation issues
385386
var ss *mcp.ServerSession
387+
var server *mcp.Server
386388
if cg.clientConfig != nil {
387389
ss = cg.clientConfig.serverSession
390+
server = cg.clientConfig.server
388391
} else {
389392
ss = nil
390393
}
391-
if err := client.Initialize(ctx, initParams, cg.cp.Verbose, ss); err != nil {
394+
if err := client.Initialize(ctx, initParams, cg.cp.Verbose, ss, server); err != nil {
392395
return nil, err
393396
}
394397

cmd/docker-mcp/internal/gateway/handlers.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ import (
88
"github.com/docker/mcp-gateway/cmd/docker-mcp/internal/catalog"
99
)
1010

11-
func getClientConfig(readOnlyHint *bool, ss *mcp.ServerSession) *clientConfig {
12-
return &clientConfig{readOnly: readOnlyHint, serverSession: ss}
11+
func getClientConfig(readOnlyHint *bool, ss *mcp.ServerSession, server *mcp.Server) *clientConfig {
12+
return &clientConfig{readOnly: readOnlyHint, serverSession: ss, server: server}
1313
}
1414

1515
func (g *Gateway) mcpToolHandler(tool catalog.Tool) mcp.ToolHandler {
@@ -24,7 +24,7 @@ func (g *Gateway) mcpToolHandler(tool catalog.Tool) mcp.ToolHandler {
2424
}
2525
}
2626

27-
func (g *Gateway) mcpServerToolHandler(serverConfig catalog.ServerConfig, annotations *mcp.ToolAnnotations) mcp.ToolHandler {
27+
func (g *Gateway) mcpServerToolHandler(serverConfig catalog.ServerConfig, server *mcp.Server, annotations *mcp.ToolAnnotations) mcp.ToolHandler {
2828
return func(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[map[string]any]) (*mcp.CallToolResultFor[any], error) {
2929
var readOnlyHint *bool
3030
if annotations != nil && annotations.ReadOnlyHint {
@@ -38,7 +38,7 @@ func (g *Gateway) mcpServerToolHandler(serverConfig catalog.ServerConfig, annota
3838
Arguments: params.Arguments,
3939
}
4040

41-
client, err := g.clientPool.AcquireClient(ctx, serverConfig, getClientConfig(readOnlyHint, ss))
41+
client, err := g.clientPool.AcquireClient(ctx, serverConfig, getClientConfig(readOnlyHint, ss, server))
4242
if err != nil {
4343
return nil, err
4444
}
@@ -48,9 +48,9 @@ func (g *Gateway) mcpServerToolHandler(serverConfig catalog.ServerConfig, annota
4848
}
4949
}
5050

51-
func (g *Gateway) mcpServerPromptHandler(serverConfig catalog.ServerConfig) mcp.PromptHandler {
51+
func (g *Gateway) mcpServerPromptHandler(serverConfig catalog.ServerConfig, server *mcp.Server) mcp.PromptHandler {
5252
return func(ctx context.Context, ss *mcp.ServerSession, params *mcp.GetPromptParams) (*mcp.GetPromptResult, error) {
53-
client, err := g.clientPool.AcquireClient(ctx, serverConfig, getClientConfig(nil, ss))
53+
client, err := g.clientPool.AcquireClient(ctx, serverConfig, getClientConfig(nil, ss, server))
5454
if err != nil {
5555
return nil, err
5656
}
@@ -60,9 +60,9 @@ func (g *Gateway) mcpServerPromptHandler(serverConfig catalog.ServerConfig) mcp.
6060
}
6161
}
6262

63-
func (g *Gateway) mcpServerResourceHandler(serverConfig catalog.ServerConfig) mcp.ResourceHandler {
63+
func (g *Gateway) mcpServerResourceHandler(serverConfig catalog.ServerConfig, server *mcp.Server) mcp.ResourceHandler {
6464
return func(ctx context.Context, ss *mcp.ServerSession, params *mcp.ReadResourceParams) (*mcp.ReadResourceResult, error) {
65-
client, err := g.clientPool.AcquireClient(ctx, serverConfig, getClientConfig(nil, ss))
65+
client, err := g.clientPool.AcquireClient(ctx, serverConfig, getClientConfig(nil, ss, server))
6666
if err != nil {
6767
return nil, err
6868
}

cmd/docker-mcp/internal/gateway/run.go

Lines changed: 99 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"net"
77
"os"
88
"strings"
9+
"sync"
910
"time"
1011

1112
"github.com/modelcontextprotocol/go-sdk/mcp"
@@ -15,13 +16,34 @@ import (
1516
"github.com/docker/mcp-gateway/cmd/docker-mcp/internal/interceptors"
1617
)
1718

19+
type ServerSessionCache struct {
20+
Roots []*mcp.Root
21+
}
22+
23+
type SubsAction int
24+
25+
const (
26+
subscribe SubsAction = 0
27+
unsubscribe SubsAction = 1
28+
)
29+
30+
type SubsMessage struct {
31+
uri string
32+
action SubsAction
33+
ss *mcp.ServerSession
34+
}
35+
1836
type Gateway struct {
1937
Options
2038
docker docker.Client
2139
configurator Configurator
2240
clientPool *clientPool
2341
mcpServer *mcp.Server
2442
health health.State
43+
subsChannel chan SubsMessage
44+
45+
sessionCacheMu sync.RWMutex
46+
sessionCache map[*mcp.ServerSession]*ServerSessionCache
2547
}
2648

2749
func NewGateway(config Config, docker docker.Client) *Gateway {
@@ -39,12 +61,20 @@ func NewGateway(config Config, docker docker.Client) *Gateway {
3961
Central: config.Central,
4062
docker: docker,
4163
},
42-
clientPool: newClientPool(config.Options, docker),
64+
clientPool: newClientPool(config.Options, docker),
65+
sessionCache: make(map[*mcp.ServerSession]*ServerSessionCache),
66+
subsChannel: make(chan SubsMessage, 10),
4367
}
4468
}
4569

4670
func (g *Gateway) Run(ctx context.Context) error {
4771
defer g.clientPool.Close()
72+
defer func() {
73+
// Clean up all session cache entries
74+
g.sessionCacheMu.Lock()
75+
g.sessionCache = make(map[*mcp.ServerSession]*ServerSessionCache)
76+
g.sessionCacheMu.Unlock()
77+
}()
4878

4979
start := time.Now()
5080

@@ -83,14 +113,33 @@ func (g *Gateway) Run(ctx context.Context) error {
83113
Name: "Docker AI MCP Gateway",
84114
Version: "2.0.1",
85115
}, &mcp.ServerOptions{
86-
SubscribeHandler: nil,
87-
UnsubscribeHandler: nil,
88-
RootsListChangedHandler: nil,
89-
CompletionHandler: nil,
90-
InitializedHandler: nil,
91-
HasPrompts: true,
92-
HasResources: true,
93-
HasTools: true,
116+
SubscribeHandler: func(ctx context.Context, ss *mcp.ServerSession, params *mcp.SubscribeParams) error {
117+
log("- Client subscribed to URI:", params.URI)
118+
// The MCP SDK doesn't provide ServerSession in SubscribeHandler because it already
119+
// keeps track of the mapping between ServerSession and subscribed resources in the Server
120+
// g.subsChannel <- SubsMessage{uri: params.URI, action: subscribe , ss: ss}
121+
return nil
122+
},
123+
UnsubscribeHandler: func(ctx context.Context, ss *mcp.ServerSession, params *mcp.UnsubscribeParams) error {
124+
log("- Client unsubscribed from URI:", params.URI)
125+
// The MCP SDK doesn't provide ServerSession in UnsubscribeHandler because it already
126+
// keeps track of the mapping ServerSession and subscribed resources in the Server
127+
// g.subsChannel <- SubsMessage{uri: params.URI, action: unsubscribe , ss: ss}
128+
return nil
129+
},
130+
RootsListChangedHandler: func(ctx context.Context, ss *mcp.ServerSession, params *mcp.RootsListChangedParams) {
131+
log("- Client roots list changed: ", ss.ID())
132+
g.ListRoots(ctx, ss)
133+
},
134+
CompletionHandler: nil,
135+
InitializedHandler: func(ctx context.Context, ss *mcp.ServerSession, params *mcp.InitializedParams) {
136+
log("- Client initialized: ", ss.ID())
137+
g.ListRoots(ctx, ss)
138+
},
139+
PageSize: 100,
140+
HasPrompts: true,
141+
HasResources: true,
142+
HasTools: true,
94143
})
95144

96145
// Add interceptor middleware to the server
@@ -235,3 +284,44 @@ func (g *Gateway) reloadConfiguration(ctx context.Context, configuration Configu
235284

236285
return nil
237286
}
287+
288+
// GetSessionCache returns the cached information for a server session
289+
func (g *Gateway) GetSessionCache(ss *mcp.ServerSession) *ServerSessionCache {
290+
g.sessionCacheMu.RLock()
291+
defer g.sessionCacheMu.RUnlock()
292+
return g.sessionCache[ss]
293+
}
294+
295+
// RemoveSessionCache removes the cached information for a server session
296+
func (g *Gateway) RemoveSessionCache(ss *mcp.ServerSession) {
297+
g.sessionCacheMu.Lock()
298+
defer g.sessionCacheMu.Unlock()
299+
delete(g.sessionCache, ss)
300+
}
301+
302+
// ListRoots checks if client supports Roots, gets them, and caches the result
303+
func (g *Gateway) ListRoots(ctx context.Context, ss *mcp.ServerSession) {
304+
// Check if client supports Roots and get them if available
305+
rootsResult, err := ss.ListRoots(ctx, nil)
306+
307+
g.sessionCacheMu.Lock()
308+
defer g.sessionCacheMu.Unlock()
309+
310+
// Get existing cache or create new one
311+
cache, exists := g.sessionCache[ss]
312+
if !exists {
313+
cache = &ServerSessionCache{}
314+
g.sessionCache[ss] = cache
315+
}
316+
317+
if err != nil {
318+
log("- Client does not support roots or error listing roots:", err)
319+
cache.Roots = nil
320+
} else {
321+
log("- Client supports roots, found", len(rootsResult.Roots), "roots")
322+
for _, root := range rootsResult.Roots {
323+
log(" - Root:", root.URI)
324+
}
325+
cache.Roots = rootsResult.Roots
326+
}
327+
}

cmd/docker-mcp/internal/gateway/transport.go

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@ import (
88
"strings"
99
"sync"
1010

11+
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
1112
"github.com/modelcontextprotocol/go-sdk/mcp"
1213

1314
"github.com/docker/mcp-gateway/cmd/docker-mcp/internal/health"
1415
)
1516

1617
func (g *Gateway) startStdioServer(ctx context.Context, _ io.Reader, _ io.Writer) error {
17-
transport := mcp.NewStdioTransport()
18+
transport := newContextAwareStdioTransport(ctx)
1819
return g.mcpServer.Run(ctx, transport)
1920
}
2021

@@ -129,3 +130,81 @@ func healthHandler(state *health.State) http.HandlerFunc {
129130
}
130131
}
131132
}
133+
134+
// contextAwareStdioTransport is a custom stdio transport that handles context cancellation properly
135+
type contextAwareStdioTransport struct {
136+
ctx context.Context
137+
}
138+
139+
func newContextAwareStdioTransport(ctx context.Context) *contextAwareStdioTransport {
140+
return &contextAwareStdioTransport{ctx: ctx}
141+
}
142+
143+
func (t *contextAwareStdioTransport) Connect(ctx context.Context) (mcp.Connection, error) {
144+
// Create the original connection once
145+
transport := mcp.NewStdioTransport()
146+
originalConn, err := transport.Connect(ctx)
147+
if err != nil {
148+
return nil, err
149+
}
150+
151+
return newContextAwareStdioConn(t.ctx, originalConn), nil
152+
}
153+
154+
// contextAwareStdioConn wraps the original connection with context-aware reading
155+
type contextAwareStdioConn struct {
156+
ctx context.Context
157+
originalConn mcp.Connection
158+
}
159+
160+
func newContextAwareStdioConn(ctx context.Context, originalConn mcp.Connection) *contextAwareStdioConn {
161+
return &contextAwareStdioConn{
162+
ctx: ctx,
163+
originalConn: originalConn,
164+
}
165+
}
166+
167+
func (c *contextAwareStdioConn) SessionID() string {
168+
return c.originalConn.SessionID()
169+
}
170+
171+
func (c *contextAwareStdioConn) Read(ctx context.Context) (jsonrpc.Message, error) {
172+
// Create a channel to read from the original connection in a separate goroutine
173+
type result struct {
174+
msg jsonrpc.Message
175+
err error
176+
}
177+
178+
ch := make(chan result, 1)
179+
go func() {
180+
msg, err := c.originalConn.Read(context.Background())
181+
ch <- result{msg, err}
182+
}()
183+
184+
// Wait for either context cancellation or read completion
185+
select {
186+
case <-ctx.Done():
187+
return nil, ctx.Err()
188+
case <-c.ctx.Done():
189+
return nil, c.ctx.Err()
190+
case res := <-ch:
191+
return res.msg, res.err
192+
}
193+
}
194+
195+
func (c *contextAwareStdioConn) Write(ctx context.Context, msg jsonrpc.Message) error {
196+
// Check context first
197+
select {
198+
case <-ctx.Done():
199+
return ctx.Err()
200+
case <-c.ctx.Done():
201+
return c.ctx.Err()
202+
default:
203+
}
204+
205+
return c.originalConn.Write(ctx, msg)
206+
}
207+
208+
func (c *contextAwareStdioConn) Close() error {
209+
return c.originalConn.Close()
210+
}

cmd/docker-mcp/internal/mcp/mcp_client.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,17 @@ import (
99

1010
// Client interface wraps the official MCP SDK client with our legacy interface
1111
type Client interface {
12-
Initialize(ctx context.Context, params *mcp.InitializeParams, debug bool, serverSession *mcp.ServerSession) error
12+
Initialize(ctx context.Context, params *mcp.InitializeParams, debug bool, serverSession *mcp.ServerSession, server *mcp.Server) error
1313
Session() *mcp.ClientSession
1414
}
1515

16-
func stdioNotifications(serverSession *mcp.ServerSession) *mcp.ClientOptions {
16+
func notifications(serverSession *mcp.ServerSession, server *mcp.Server) *mcp.ClientOptions {
1717
return &mcp.ClientOptions{
18+
ResourceUpdatedHandler: func(ctx context.Context, _ *mcp.ClientSession, params *mcp.ResourceUpdatedNotificationParams) {
19+
if server != nil {
20+
_ = server.ResourceUpdated(ctx, params)
21+
}
22+
},
1823
CreateMessageHandler: func(_ context.Context, _ *mcp.ClientSession, _ *mcp.CreateMessageParams) (*mcp.CreateMessageResult, error) {
1924
// Handle create messages if needed
2025
return nil, fmt.Errorf("create messages not supported")

cmd/docker-mcp/internal/mcp/remote.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func NewRemoteMCPClient(config catalog.ServerConfig) Client {
2525
}
2626
}
2727

28-
func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeParams, _ bool, _ *mcp.ServerSession) error {
28+
func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeParams, _ bool, _ *mcp.ServerSession, _ *mcp.Server) error {
2929
if c.initialized.Load() {
3030
return fmt.Errorf("client already initialized")
3131
}

0 commit comments

Comments
 (0)