Skip to content

Commit 214dd16

Browse files
committed
Add central mode
Signed-off-by: David Gageot <[email protected]>
1 parent a12db88 commit 214dd16

File tree

5 files changed

+146
-52
lines changed

5 files changed

+146
-52
lines changed

cmd/docker-mcp/commands/gateway.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,15 @@ func gatewayCommand(docker docker.Client) *cobra.Command {
6060
Short: "Run the gateway",
6161
Args: cobra.NoArgs,
6262
RunE: func(cmd *cobra.Command, _ []string) error {
63+
if options.Static {
64+
options.Watch = false
65+
}
66+
67+
if options.Central {
68+
options.Watch = false
69+
options.Transport = "streaming"
70+
}
71+
6372
if options.Transport == "stdio" {
6473
if options.Port != 0 {
6574
return errors.New("cannot use --port with --transport=stdio")
@@ -68,10 +77,6 @@ func gatewayCommand(docker docker.Client) *cobra.Command {
6877
options.Port = 8811
6978
}
7079

71-
if options.Static && options.Watch {
72-
return errors.New("cannot use --static with --watch")
73-
}
74-
7580
// Append additional catalogs to the main catalog path
7681
options.CatalogPath = append(options.CatalogPath, additionalCatalogs...)
7782
options.RegistryPath = append(options.RegistryPath, additionalRegistries...)
@@ -106,6 +111,10 @@ func gatewayCommand(docker docker.Client) *cobra.Command {
106111
runCmd.Flags().StringVar(&options.Memory, "memory", options.Memory, "Memory allocated to each MCP Server (default is 2Gb)")
107112
runCmd.Flags().BoolVar(&options.Static, "static", options.Static, "Enable static mode (aka pre-started servers)")
108113

114+
// Very experimental features
115+
runCmd.Flags().BoolVar(&options.Central, "central", options.Central, "In central mode, clients tell us which servers to enable")
116+
_ = runCmd.Flags().MarkHidden("central")
117+
109118
cmd.AddCommand(runCmd)
110119

111120
return cmd

cmd/docker-mcp/internal/gateway/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@ type Options struct {
2626
Cpus int
2727
Memory string
2828
Static bool
29+
Central bool
2930
}

cmd/docker-mcp/internal/gateway/configuration.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ type FileBasedConfiguration struct {
9292
ConfigPath []string
9393
SecretsPath string // Optional, if not set, use Docker Desktop's secrets API
9494
Watch bool
95+
Central bool
9596

9697
docker docker.Client
9798
}
@@ -189,16 +190,17 @@ func (c *FileBasedConfiguration) readOnce(ctx context.Context) (Configuration, e
189190
log("- Reading configuration...")
190191

191192
var serverNames []string
193+
if !c.Central {
194+
if len(c.ServerNames) > 0 {
195+
serverNames = c.ServerNames
196+
} else {
197+
registryConfig, err := c.readRegistry(ctx)
198+
if err != nil {
199+
return Configuration{}, fmt.Errorf("reading registry: %w", err)
200+
}
192201

193-
if len(c.ServerNames) > 0 {
194-
serverNames = c.ServerNames
195-
} else {
196-
registryConfig, err := c.readRegistry(ctx)
197-
if err != nil {
198-
return Configuration{}, fmt.Errorf("reading registry: %w", err)
202+
serverNames = registryConfig.ServerNames()
199203
}
200-
201-
serverNames = registryConfig.ServerNames()
202204
}
203205

204206
mcpCatalog, err := c.readCatalog(ctx)
@@ -207,11 +209,13 @@ func (c *FileBasedConfiguration) readOnce(ctx context.Context) (Configuration, e
207209
}
208210
servers := mcpCatalog.Servers
209211

212+
// TODO(dga): Do we expect every server to have a config, in Central mode?
210213
serversConfig, err := c.readConfig(ctx)
211214
if err != nil {
212215
return Configuration{}, fmt.Errorf("reading config: %w", err)
213216
}
214217

218+
// TODO(dga): How do we know which secrets to read, in Central mode?
215219
var secrets map[string]string
216220
if c.SecretsPath == "docker-desktop" {
217221
secrets, err = c.readDockerDesktopSecrets(ctx, servers, serverNames)

cmd/docker-mcp/internal/gateway/run.go

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ type Gateway struct {
2121
docker docker.Client
2222
configurator Configurator
2323
clientPool *clientPool
24-
mcpServer *server.MCPServer
2524
health health.State
2625
}
2726

@@ -36,6 +35,7 @@ func NewGateway(config Config, docker docker.Client) *Gateway {
3635
ConfigPath: config.ConfigPath,
3736
SecretsPath: config.SecretsPath,
3837
Watch: config.Watch,
38+
Central: config.Central,
3939
docker: docker,
4040
},
4141
clientPool: newClientPool(config.Options, docker),
@@ -60,13 +60,48 @@ func (g *Gateway) Run(ctx context.Context) error {
6060
}
6161
}
6262

63+
// Build a list of interceptors.
64+
customInterceptors, err := interceptors.Parse(g.Interceptors)
65+
if err != nil {
66+
return fmt.Errorf("parsing interceptors: %w", err)
67+
}
68+
toolCallbacks := interceptors.Callbacks(g.LogCalls, g.BlockSecrets, customInterceptors)
69+
70+
// Create the MCP server.
71+
newMCPServer := func() *server.MCPServer {
72+
return server.NewMCPServer(
73+
"Docker AI MCP Gateway",
74+
"2.0.1",
75+
server.WithToolHandlerMiddleware(toolCallbacks),
76+
server.WithHooks(&server.Hooks{
77+
OnBeforeInitialize: []server.OnBeforeInitializeFunc{
78+
func(_ context.Context, id any, _ *mcp.InitializeRequest) {
79+
log("> Initializing MCP server with ID:", id)
80+
},
81+
},
82+
}),
83+
)
84+
}
85+
6386
// Read the configuration.
6487
configuration, configurationUpdates, stopConfigWatcher, err := g.configurator.Read(ctx)
6588
if err != nil {
6689
return err
6790
}
6891
defer func() { _ = stopConfigWatcher() }()
6992

93+
// Central mode.
94+
if g.Central {
95+
log("> Initialized in", time.Since(start))
96+
if g.DryRun {
97+
log("Dry run mode enabled, not starting the server.")
98+
return nil
99+
}
100+
101+
return g.startCentralStreamingServer(ctx, newMCPServer, ln, configuration)
102+
}
103+
mcpServer := newMCPServer()
104+
70105
// Which docker images are used?
71106
// Pull them and verify them if possible.
72107
if !g.Static {
@@ -84,27 +119,7 @@ func (g *Gateway) Run(ctx context.Context) error {
84119
}
85120
}
86121

87-
// Build a list of interceptors.
88-
customInterceptors, err := interceptors.Parse(g.Interceptors)
89-
if err != nil {
90-
return fmt.Errorf("parsing interceptors: %w", err)
91-
}
92-
toolCallbacks := interceptors.Callbacks(g.LogCalls, g.BlockSecrets, customInterceptors)
93-
94-
g.mcpServer = server.NewMCPServer(
95-
"Docker AI MCP Gateway",
96-
"2.0.1",
97-
server.WithToolHandlerMiddleware(toolCallbacks),
98-
server.WithHooks(&server.Hooks{
99-
OnBeforeInitialize: []server.OnBeforeInitializeFunc{
100-
func(_ context.Context, id any, _ *mcp.InitializeRequest) {
101-
log("> Initializing MCP server with ID:", id)
102-
},
103-
},
104-
}),
105-
)
106-
107-
if err := g.reloadConfiguration(ctx, configuration); err != nil {
122+
if err := g.reloadConfiguration(ctx, mcpServer, configuration, nil); err != nil {
108123
return fmt.Errorf("loading configuration: %w", err)
109124
}
110125

@@ -125,7 +140,7 @@ func (g *Gateway) Run(ctx context.Context) error {
125140
continue
126141
}
127142

128-
if err := g.reloadConfiguration(ctx, configuration); err != nil {
143+
if err := g.reloadConfiguration(ctx, mcpServer, configuration, nil); err != nil {
129144
logf("> Unable to list capabilities: %s", err)
130145
continue
131146
}
@@ -144,24 +159,26 @@ func (g *Gateway) Run(ctx context.Context) error {
144159
switch strings.ToLower(g.Transport) {
145160
case "stdio":
146161
log("> Start stdio server")
147-
return g.startStdioServer(ctx, os.Stdin, os.Stdout)
162+
return g.startStdioServer(ctx, mcpServer, os.Stdin, os.Stdout)
148163

149164
case "sse":
150165
log("> Start sse server on port", g.Port)
151-
return g.startSseServer(ctx, ln)
166+
return g.startSseServer(ctx, mcpServer, ln)
152167

153168
case "streaming":
154169
log("> Start streaming server on port", g.Port)
155-
return g.startStreamingServer(ctx, ln)
170+
return g.startStreamingServer(ctx, mcpServer, ln)
156171

157172
default:
158173
return fmt.Errorf("unknown transport %q, expected 'stdio', 'sse' or 'streaming", g.Transport)
159174
}
160175
}
161176

162-
func (g *Gateway) reloadConfiguration(ctx context.Context, configuration Configuration) error {
177+
func (g *Gateway) reloadConfiguration(ctx context.Context, mcpServer *server.MCPServer, configuration Configuration, serverNames []string) error {
163178
// Which servers are enabled in the registry.yaml?
164-
serverNames := configuration.ServerNames()
179+
if len(serverNames) == 0 {
180+
serverNames = configuration.ServerNames()
181+
}
165182
if len(serverNames) == 0 {
166183
log("- No server is enabled")
167184
} else {
@@ -179,12 +196,12 @@ func (g *Gateway) reloadConfiguration(ctx context.Context, configuration Configu
179196

180197
// Update the server's capabilities.
181198
g.health.SetUnhealthy()
182-
g.mcpServer.SetTools(capabilities.Tools...)
183-
g.mcpServer.SetPrompts(capabilities.Prompts...)
184-
g.mcpServer.SetResources(capabilities.Resources...)
185-
g.mcpServer.RemoveAllResourceTemplates()
199+
mcpServer.SetTools(capabilities.Tools...)
200+
mcpServer.SetPrompts(capabilities.Prompts...)
201+
mcpServer.SetResources(capabilities.Resources...)
202+
mcpServer.RemoveAllResourceTemplates()
186203
for _, v := range capabilities.ResourceTemplates {
187-
g.mcpServer.AddResourceTemplate(v.ResourceTemplate, v.Handler)
204+
mcpServer.AddResourceTemplate(v.ResourceTemplate, v.Handler)
188205
}
189206
g.health.SetHealthy()
190207

cmd/docker-mcp/internal/gateway/transport.go

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,47 @@ import (
55
"io"
66
"net"
77
"net/http"
8+
"strings"
9+
"sync"
810

911
"github.com/mark3labs/mcp-go/server"
1012
)
1113

12-
func (g *Gateway) startStdioServer(ctx context.Context, stdin io.Reader, stdout io.Writer) error {
13-
return server.NewStdioServer(g.mcpServer).Listen(ctx, stdin, stdout)
14+
func (g *Gateway) startStdioServer(ctx context.Context, mcpServer *server.MCPServer, stdin io.Reader, stdout io.Writer) error {
15+
return server.NewStdioServer(mcpServer).Listen(ctx, stdin, stdout)
1416
}
1517

16-
func (g *Gateway) startSseServer(ctx context.Context, ln net.Listener) error {
18+
func (g *Gateway) startSseServer(ctx context.Context, mcpServer *server.MCPServer, ln net.Listener) error {
1719
mux := http.NewServeMux()
18-
sseServer := server.NewSSEServer(g.mcpServer)
20+
sseServer := server.NewSSEServer(mcpServer)
21+
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
22+
http.Redirect(w, r, "/sse", http.StatusTemporaryRedirect)
23+
})
1924
mux.Handle("/sse", sseServer.SSEHandler())
2025
mux.Handle("/message", sseServer.MessageHandler())
26+
mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
27+
if g.health.IsHealthy() {
28+
w.WriteHeader(http.StatusOK)
29+
} else {
30+
w.WriteHeader(http.StatusServiceUnavailable)
31+
}
32+
})
33+
httpServer := &http.Server{
34+
Handler: mux,
35+
}
36+
go func() {
37+
<-ctx.Done()
38+
ln.Close()
39+
}()
40+
return httpServer.Serve(ln)
41+
}
42+
43+
func (g *Gateway) startStreamingServer(ctx context.Context, mcpServer *server.MCPServer, ln net.Listener) error {
44+
mux := http.NewServeMux()
2145
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
22-
http.Redirect(w, r, "/sse", http.StatusTemporaryRedirect)
46+
http.Redirect(w, r, "/mcp", http.StatusTemporaryRedirect)
2347
})
48+
mux.Handle("/mcp", server.NewStreamableHTTPServer(mcpServer))
2449
mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
2550
if g.health.IsHealthy() {
2651
w.WriteHeader(http.StatusOK)
@@ -31,19 +56,48 @@ func (g *Gateway) startSseServer(ctx context.Context, ln net.Listener) error {
3156
httpServer := &http.Server{
3257
Handler: mux,
3358
}
59+
3460
go func() {
3561
<-ctx.Done()
3662
ln.Close()
3763
}()
3864
return httpServer.Serve(ln)
3965
}
4066

41-
func (g *Gateway) startStreamingServer(ctx context.Context, ln net.Listener) error {
67+
func (g *Gateway) startCentralStreamingServer(ctx context.Context, newMCPServer func() *server.MCPServer, ln net.Listener, configuration Configuration) error {
68+
var lock sync.Mutex
69+
handlersPerSelectionOfServers := map[string]*server.StreamableHTTPServer{}
70+
4271
mux := http.NewServeMux()
43-
mux.Handle("/mcp", server.NewStreamableHTTPServer(g.mcpServer))
4472
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
4573
http.Redirect(w, r, "/mcp", http.StatusTemporaryRedirect)
4674
})
75+
mux.HandleFunc("/mcp", func(w http.ResponseWriter, r *http.Request) {
76+
serverNames := r.Header.Get("x-mcp-servers")
77+
if len(serverNames) == 0 {
78+
_, _ = w.Write([]byte("No server names provided in the request header 'x-mcp-servers'"))
79+
w.WriteHeader(http.StatusBadRequest)
80+
return
81+
}
82+
83+
lock.Lock()
84+
handler := handlersPerSelectionOfServers[serverNames]
85+
if handler == nil {
86+
mcpServer := newMCPServer()
87+
if err := g.reloadConfiguration(ctx, mcpServer, configuration, parseServerNames(serverNames)); err != nil {
88+
lock.Unlock()
89+
_, _ = w.Write([]byte("Failed to reload configuration"))
90+
w.WriteHeader(http.StatusInternalServerError)
91+
return
92+
}
93+
handler = server.NewStreamableHTTPServer(mcpServer)
94+
handlersPerSelectionOfServers[serverNames] = handler
95+
}
96+
lock.Unlock()
97+
98+
handler.ServeHTTP(w, r)
99+
})
100+
47101
mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
48102
if g.health.IsHealthy() {
49103
w.WriteHeader(http.StatusOK)
@@ -59,5 +113,14 @@ func (g *Gateway) startStreamingServer(ctx context.Context, ln net.Listener) err
59113
<-ctx.Done()
60114
ln.Close()
61115
}()
116+
g.health.SetHealthy()
62117
return httpServer.Serve(ln)
63118
}
119+
120+
func parseServerNames(serverNames string) []string {
121+
var names []string
122+
for _, name := range strings.Split(serverNames, ",") {
123+
names = append(names, strings.TrimSpace(name))
124+
}
125+
return names
126+
}

0 commit comments

Comments
 (0)