Skip to content

Commit afdad5c

Browse files
Add Notify handler pass throughs
1 parent 589c69e commit afdad5c

33 files changed

+1506
-624
lines changed

cmd/docker-mcp/internal/gateway/clientpool.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ type clientPool struct {
3131
docker docker.Client
3232
}
3333

34+
type clientConfig struct {
35+
readOnly *bool
36+
serverSession *mcp.ServerSession
37+
}
38+
3439
func newClientPool(options Options, docker docker.Client) *clientPool {
3540
return &clientPool{
3641
Options: options,
@@ -39,7 +44,7 @@ func newClientPool(options Options, docker docker.Client) *clientPool {
3944
}
4045
}
4146

42-
func (cp *clientPool) AcquireClient(ctx context.Context, serverConfig catalog.ServerConfig, readOnly *bool) (mcpclient.Client, error) {
47+
func (cp *clientPool) AcquireClient(ctx context.Context, serverConfig catalog.ServerConfig, config *clientConfig) (mcpclient.Client, error) {
4348
var getter *clientGetter
4449

4550
// Check if client is kept, can be returned immediately
@@ -54,7 +59,7 @@ func (cp *clientPool) AcquireClient(ctx context.Context, serverConfig catalog.Se
5459

5560
// No client found, create a new one
5661
if getter == nil {
57-
getter = newClientGetter(serverConfig, cp, readOnly)
62+
getter = newClientGetter(serverConfig, cp, config)
5863

5964
// If the client is long running, save it for later
6065
if serverConfig.Spec.LongLived || cp.LongLived {
@@ -304,14 +309,15 @@ type clientGetter struct {
304309

305310
serverConfig catalog.ServerConfig
306311
cp *clientPool
307-
readOnly *bool
312+
313+
clientConfig *clientConfig
308314
}
309315

310-
func newClientGetter(serverConfig catalog.ServerConfig, cp *clientPool, readOnly *bool) *clientGetter {
316+
func newClientGetter(serverConfig catalog.ServerConfig, cp *clientPool, config *clientConfig) *clientGetter {
311317
return &clientGetter{
312318
serverConfig: serverConfig,
313319
cp: cp,
314-
readOnly: readOnly,
320+
clientConfig: config,
315321
}
316322
}
317323

@@ -343,7 +349,9 @@ func (cg *clientGetter) GetClient(ctx context.Context) (mcpclient.Client, error)
343349
}
344350

345351
image := cg.serverConfig.Spec.Image
346-
args, env := cg.cp.argsAndEnv(cg.serverConfig, cg.readOnly, targetConfig)
352+
var readOnly *bool
353+
if cg.clientConfig != nil { readOnly = cg.clientConfig.readOnly} else { readOnly = nil }
354+
args, env := cg.cp.argsAndEnv(cg.serverConfig, readOnly, targetConfig)
347355

348356
command := expandEnvList(eval.EvaluateList(cg.serverConfig.Spec.Command, cg.serverConfig.Config), env)
349357
if len(command) == 0 {
@@ -370,7 +378,9 @@ func (cg *clientGetter) GetClient(ctx context.Context) (mcpclient.Client, error)
370378

371379
// Use the original context instead of creating a timeout context
372380
// to avoid cancellation issues
373-
if _, err := client.Initialize(ctx, initParams, cg.cp.Verbose); err != nil {
381+
var ss *mcp.ServerSession
382+
if cg.clientConfig != nil { ss = cg.clientConfig.serverSession} else { ss = nil }
383+
if _, err := client.Initialize(ctx, initParams, cg.cp.Verbose, ss); err != nil {
374384
return nil, err
375385
}
376386

cmd/docker-mcp/internal/gateway/clientpool_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ func TestStdioClientInitialization(t *testing.T) {
216216
defer cancel()
217217

218218
// Test client acquisition and initialization
219-
client, err := clientPool.AcquireClient(ctx, serverConfig, boolPtr(false))
219+
client, err := clientPool.AcquireClient(ctx, serverConfig, &clientConfig{readOnly: boolPtr(false)})
220220
if err != nil {
221221
t.Fatalf("Failed to acquire client: %v", err)
222222
}

cmd/docker-mcp/internal/gateway/handlers.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ import (
88
"github.com/docker/mcp-gateway/cmd/docker-mcp/internal/catalog"
99
)
1010

11+
func getClientConfig(readOnlyHint *bool, ss *mcp.ServerSession) *clientConfig {
12+
return &clientConfig{readOnly: readOnlyHint, serverSession: ss}
13+
}
14+
1115
func (g *Gateway) mcpToolHandler(tool catalog.Tool) mcp.ToolHandler {
1216
return func(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[map[string]any]) (*mcp.CallToolResultFor[any], error) {
1317
// Convert to the generic version for our internal methods
@@ -34,7 +38,7 @@ func (g *Gateway) mcpServerToolHandler(serverConfig catalog.ServerConfig, annota
3438
Arguments: params.Arguments,
3539
}
3640

37-
client, err := g.clientPool.AcquireClient(ctx, serverConfig, readOnlyHint)
41+
client, err := g.clientPool.AcquireClient(ctx, serverConfig, getClientConfig( readOnlyHint, ss))
3842
if err != nil {
3943
return nil, err
4044
}
@@ -46,7 +50,7 @@ func (g *Gateway) mcpServerToolHandler(serverConfig catalog.ServerConfig, annota
4650

4751
func (g *Gateway) mcpServerPromptHandler(serverConfig catalog.ServerConfig) mcp.PromptHandler {
4852
return func(ctx context.Context, ss *mcp.ServerSession, params *mcp.GetPromptParams) (*mcp.GetPromptResult, error) {
49-
client, err := g.clientPool.AcquireClient(ctx, serverConfig, nil)
53+
client, err := g.clientPool.AcquireClient(ctx, serverConfig, getClientConfig(nil, ss))
5054
if err != nil {
5155
return nil, err
5256
}
@@ -58,7 +62,7 @@ func (g *Gateway) mcpServerPromptHandler(serverConfig catalog.ServerConfig) mcp.
5862

5963
func (g *Gateway) mcpServerResourceHandler(serverConfig catalog.ServerConfig) mcp.ResourceHandler {
6064
return func(ctx context.Context, ss *mcp.ServerSession, params *mcp.ReadResourceParams) (*mcp.ReadResourceResult, error) {
61-
client, err := g.clientPool.AcquireClient(ctx, serverConfig, nil)
65+
client, err := g.clientPool.AcquireClient(ctx, serverConfig, getClientConfig(nil, ss))
6266
if err != nil {
6367
return nil, err
6468
}
@@ -70,7 +74,7 @@ func (g *Gateway) mcpServerResourceHandler(serverConfig catalog.ServerConfig) mc
7074

7175
func (g *Gateway) mcpServerResourceTemplateHandler(serverConfig catalog.ServerConfig) mcp.ResourceHandler {
7276
return func(ctx context.Context, ss *mcp.ServerSession, params *mcp.ReadResourceParams) (*mcp.ReadResourceResult, error) {
73-
client, err := g.clientPool.AcquireClient(ctx, serverConfig, nil)
77+
client, err := g.clientPool.AcquireClient(ctx, serverConfig, getClientConfig(nil, ss))
7478
if err != nil {
7579
return nil, err
7680
}

cmd/docker-mcp/internal/gateway/run.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,16 @@ func (g *Gateway) Run(ctx context.Context) error {
8282
g.mcpServer = mcp.NewServer(&mcp.Implementation{
8383
Name: "Docker AI MCP Gateway",
8484
Version: "2.0.1",
85-
}, nil)
85+
}, &mcp.ServerOptions{
86+
SubscribeHandler: nil,
87+
UnsubscribeHandler: nil,
88+
RootsListChangedHandler: nil,
89+
CompletionHandler: nil,
90+
InitializedHandler: nil,
91+
HasPrompts: true,
92+
HasResources: true,
93+
HasTools: true,
94+
})
8695

8796
// Add interceptor middleware to the server
8897
middlewares := interceptors.Callbacks(g.LogCalls, g.BlockSecrets, parsedInterceptors)
@@ -194,10 +203,12 @@ func (g *Gateway) reloadConfiguration(ctx context.Context, configuration Configu
194203
}
195204
log(">", len(capabilities.Tools), "tools listed in", time.Since(startList))
196205

206+
// Update capabilities
197207
// Clear existing capabilities and register new ones
198208
// Note: The new SDK doesn't have bulk set methods, so we register individually
209+
199210
for _, tool := range capabilities.Tools {
200-
mcp.AddTool(g.mcpServer, tool.Tool, tool.Handler)
211+
g.mcpServer.AddTool(tool.Tool, tool.Handler)
201212
}
202213

203214
for _, prompt := range capabilities.Prompts {

cmd/docker-mcp/internal/mcp/mcp_client.go

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ package mcp
22

33
import (
44
"context"
5+
"fmt"
6+
"slices"
57

68
"github.com/modelcontextprotocol/go-sdk/mcp"
79
)
810

911
// Client interface wraps the official MCP SDK client with our legacy interface
1012
type Client interface {
11-
Initialize(ctx context.Context, params *mcp.InitializeParams, debug bool) (*mcp.InitializeResult, error)
13+
Initialize(ctx context.Context, params *mcp.InitializeParams, debug bool, serverSession *mcp.ServerSession) (*mcp.InitializeResult, error)
1214
ListTools(ctx context.Context, params *mcp.ListToolsParams) (*mcp.ListToolsResult, error)
1315
ListPrompts(ctx context.Context, params *mcp.ListPromptsParams) (*mcp.ListPromptsResult, error)
1416
ListResources(ctx context.Context, params *mcp.ListResourcesParams) (*mcp.ListResourcesResult, error)
@@ -18,3 +20,68 @@ type Client interface {
1820
ReadResource(ctx context.Context, params *mcp.ReadResourceParams) (*mcp.ReadResourceResult, error)
1921
Close() error
2022
}
23+
24+
func allSessions(server *mcp.Server, callback func(session *mcp.ServerSession)) {
25+
for session := range server.Sessions() {
26+
callback(session)
27+
}
28+
}
29+
30+
func serverNotifications(server *mcp.Server) *mcp.ClientOptions {
31+
return &mcp.ClientOptions{
32+
CreateMessageHandler: func(ctx context.Context, session *mcp.ClientSession, params *mcp.CreateMessageParams) (*mcp.CreateMessageResult, error) {
33+
// Handle create messages if needed
34+
return nil, fmt.Errorf("create messages not supported")
35+
},
36+
ToolListChangedHandler: func(ctx context.Context, session *mcp.ClientSession, params *mcp.ToolListChangedParams) {
37+
// Handle tool list changes if needed
38+
if server != nil {
39+
sessions := slices.Collect(server.Sessions())
40+
mcp.NotifySessions(sessions, "notifications/tools/list_changed", params)
41+
}
42+
},
43+
ResourceListChangedHandler: func(ctx context.Context, session *mcp.ClientSession, params *mcp.ResourceListChangedParams) {
44+
if server != nil {
45+
sessions := slices.Collect(server.Sessions())
46+
mcp.NotifySessions(sessions, "notifications/resources/list_changed", params)
47+
}
48+
},
49+
PromptListChangedHandler: func(ctx context.Context, session *mcp.ClientSession, params *mcp.PromptListChangedParams) {
50+
if server != nil {
51+
sessions := slices.Collect(server.Sessions())
52+
mcp.NotifySessions(sessions, "notifications/prompts/list_changed", params)
53+
}
54+
},
55+
ProgressNotificationHandler: func(ctx context.Context, session *mcp.ClientSession, params *mcp.ProgressNotificationParams) {
56+
allSessions(server, func (session *mcp.ServerSession) {session.NotifyProgress(ctx, params)})
57+
},
58+
LoggingMessageHandler: func(ctx context.Context, session *mcp.ClientSession, params *mcp.LoggingMessageParams) {
59+
allSessions(server, func (session *mcp.ServerSession) {session.Log(ctx, params)})
60+
},
61+
}
62+
}
63+
64+
func stdioNotifications(serverSession *mcp.ServerSession) *mcp.ClientOptions {
65+
return &mcp.ClientOptions{
66+
CreateMessageHandler: func(ctx context.Context, session *mcp.ClientSession, params *mcp.CreateMessageParams) (*mcp.CreateMessageResult, error) {
67+
// Handle create messages if needed
68+
return nil, fmt.Errorf("create messages not supported")
69+
},
70+
ToolListChangedHandler: func(ctx context.Context, session *mcp.ClientSession, params *mcp.ToolListChangedParams) {
71+
mcp.HandleNotify(ctx, serverSession, "notifications/tools/list_changed", params)
72+
},
73+
ResourceListChangedHandler: func(ctx context.Context, session *mcp.ClientSession, params *mcp.ResourceListChangedParams) {
74+
mcp.HandleNotify(ctx, serverSession, "notifications/resources/list_changed", params)
75+
},
76+
PromptListChangedHandler: func(ctx context.Context, session *mcp.ClientSession, params *mcp.PromptListChangedParams) {
77+
mcp.HandleNotify(ctx, serverSession, "notifications/prompts/list_changed", params)
78+
},
79+
ProgressNotificationHandler: func(ctx context.Context, session *mcp.ClientSession, params *mcp.ProgressNotificationParams) {
80+
serverSession.NotifyProgress(ctx, params)
81+
},
82+
LoggingMessageHandler: func(ctx context.Context, session *mcp.ClientSession, params *mcp.LoggingMessageParams) {
83+
serverSession.Log(ctx, params)
84+
},
85+
}
86+
}
87+

cmd/docker-mcp/internal/mcp/remote.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func NewRemoteMCPClient(config catalog.ServerConfig) Client {
2525
}
2626
}
2727

28-
func (c *remoteMCPClient) Initialize(ctx context.Context, params *mcp.InitializeParams, _ bool) (*mcp.InitializeResult, error) {
28+
func (c *remoteMCPClient) Initialize(ctx context.Context, params *mcp.InitializeParams, _ bool, _ *mcp.ServerSession) (*mcp.InitializeResult, error) {
2929
if c.initialized.Load() {
3030
return nil, fmt.Errorf("client already initialized")
3131
}

cmd/docker-mcp/internal/mcp/stdio.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func NewStdioCmdClient(name string, command string, env []string, args ...string
3131
}
3232
}
3333

34-
func (c *stdioMCPClient) Initialize(ctx context.Context, params *mcp.InitializeParams, debug bool) (*mcp.InitializeResult, error) {
34+
func (c *stdioMCPClient) Initialize(ctx context.Context, params *mcp.InitializeParams, debug bool, s *mcp.ServerSession) (*mcp.InitializeResult, error) {
3535
if c.initialized.Load() {
3636
return nil, fmt.Errorf("client already initialized")
3737
}
@@ -47,7 +47,7 @@ func (c *stdioMCPClient) Initialize(ctx context.Context, params *mcp.InitializeP
4747
c.client = mcp.NewClient(&mcp.Implementation{
4848
Name: "docker-mcp-gateway",
4949
Version: "1.0.0",
50-
}, nil)
50+
}, stdioNotifications(s))
5151

5252
session, err := c.client.Connect(ctx, transport)
5353
if err != nil {

cmd/docker-mcp/internal/mcp/stdio_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func TestStdioClientInitializeAndListTools(t *testing.T) {
4040
},
4141
}
4242

43-
result, err := client.Initialize(ctx, initParams, true) // verbose = true for debugging
43+
result, err := client.Initialize(ctx, initParams, true, nil) // verbose = true for debugging
4444
require.NoError(t, err, "Failed to initialize stdio client")
4545
require.NotNil(t, result, "Initialize result should not be nil")
4646

cmd/docker-mcp/long_lived_integration_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func newTestGatewayClient(t *testing.T, args []string) mcpclient.Client {
6767
},
6868
}
6969

70-
_, err := c.Initialize(t.Context(), initParams, false)
70+
_, err := c.Initialize(t.Context(), initParams, false, nil)
7171
require.NoError(t, err)
7272

7373
return c

cmd/docker-mcp/tools/start.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func start(ctx context.Context, version string, gatewayArgs []string, debug bool
3232
ctx, cancel := context.WithTimeout(ctx, 60*time.Second)
3333
defer cancel()
3434

35-
if _, err := c.Initialize(ctx, initParams, debug); err != nil {
35+
if _, err := c.Initialize(ctx, initParams, debug, nil); err != nil {
3636
return nil, fmt.Errorf("initializing: %w", err)
3737
}
3838

0 commit comments

Comments
 (0)