Skip to content

Commit 923fa3a

Browse files
sd2kclaude
andauthored
fix: improve graceful shutdown handling for all transport modes (#271)
Co-authored-by: Claude <[email protected]>
1 parent cc35912 commit 923fa3a

File tree

1 file changed

+87
-10
lines changed

1 file changed

+87
-10
lines changed

cmd/mcp-grafana/main.go

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@ package main
22

33
import (
44
"context"
5+
"errors"
56
"flag"
67
"fmt"
78
"log/slog"
9+
"net/http"
810
"os"
11+
"os/signal"
912
"slices"
1013
"strings"
14+
"syscall"
15+
"time"
1116

1217
"github.com/mark3labs/mcp-go/server"
1318

@@ -124,25 +129,99 @@ func (tc *tlsConfig) addFlags() {
124129
flag.StringVar(&tc.keyFile, "server.tls-key-file", "", "Path to TLS private key file for server HTTPS (required for TLS)")
125130
}
126131

132+
// httpServer represents a server with Start and Shutdown methods
133+
type httpServer interface {
134+
Start(addr string) error
135+
Shutdown(ctx context.Context) error
136+
}
137+
138+
// runHTTPServer handles the common logic for running HTTP-based servers
139+
func runHTTPServer(ctx context.Context, srv httpServer, addr, transportName string) error {
140+
// Start server in a goroutine
141+
serverErr := make(chan error, 1)
142+
go func() {
143+
if err := srv.Start(addr); err != nil {
144+
serverErr <- err
145+
}
146+
close(serverErr)
147+
}()
148+
149+
// Wait for either server error or shutdown signal
150+
select {
151+
case err := <-serverErr:
152+
return err
153+
case <-ctx.Done():
154+
slog.Info(fmt.Sprintf("%s server shutting down...", transportName))
155+
156+
// Create a timeout context for shutdown
157+
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
158+
defer shutdownCancel()
159+
160+
if err := srv.Shutdown(shutdownCtx); err != nil {
161+
return fmt.Errorf("shutdown error: %v", err)
162+
}
163+
164+
// Wait for server to finish
165+
select {
166+
case err := <-serverErr:
167+
// http.ErrServerClosed is expected when shutting down
168+
if err != nil && !errors.Is(err, http.ErrServerClosed) {
169+
return fmt.Errorf("server error during shutdown: %v", err)
170+
}
171+
case <-shutdownCtx.Done():
172+
slog.Warn(fmt.Sprintf("%s server did not stop gracefully within timeout", transportName))
173+
}
174+
}
175+
176+
return nil
177+
}
178+
127179
func run(transport, addr, basePath, endpointPath string, logLevel slog.Level, dt disabledTools, gc mcpgrafana.GrafanaConfig, tls tlsConfig) error {
128180
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: logLevel})))
129181
s := newServer(dt)
130182

183+
// Create a context that will be cancelled on shutdown
184+
ctx, cancel := context.WithCancel(context.Background())
185+
defer cancel()
186+
187+
// Set up signal handling for graceful shutdown
188+
sigChan := make(chan os.Signal, 1)
189+
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
190+
defer signal.Stop(sigChan)
191+
192+
// Handle shutdown signals
193+
go func() {
194+
<-sigChan
195+
slog.Info("Received shutdown signal")
196+
cancel()
197+
198+
// For stdio, close stdin to unblock the Listen call
199+
if transport == "stdio" {
200+
_ = os.Stdin.Close()
201+
}
202+
}()
203+
204+
// Start the appropriate server based on transport
131205
switch transport {
132206
case "stdio":
133207
srv := server.NewStdioServer(s)
134208
srv.SetContextFunc(mcpgrafana.ComposedStdioContextFunc(gc))
135209
slog.Info("Starting Grafana MCP server using stdio transport", "version", mcpgrafana.Version())
136-
return srv.Listen(context.Background(), os.Stdin, os.Stdout)
210+
211+
err := srv.Listen(ctx, os.Stdin, os.Stdout)
212+
if err != nil && err != context.Canceled {
213+
return fmt.Errorf("server error: %v", err)
214+
}
215+
return nil
216+
137217
case "sse":
138218
srv := server.NewSSEServer(s,
139219
server.WithSSEContextFunc(mcpgrafana.ComposedSSEContextFunc(gc)),
140220
server.WithStaticBasePath(basePath),
141221
)
142-
slog.Info("Starting Grafana MCP server using SSE transport", "version", mcpgrafana.Version(), "address", addr, "basePath", basePath)
143-
if err := srv.Start(addr); err != nil {
144-
return fmt.Errorf("server error: %v", err)
145-
}
222+
slog.Info("Starting Grafana MCP server using SSE transport",
223+
"version", mcpgrafana.Version(), "address", addr, "basePath", basePath)
224+
return runHTTPServer(ctx, srv, addr, "SSE")
146225
case "streamable-http":
147226
opts := []server.StreamableHTTPOption{
148227
server.WithHTTPContextFunc(mcpgrafana.ComposedHTTPContextFunc(gc)),
@@ -153,17 +232,15 @@ func run(transport, addr, basePath, endpointPath string, logLevel slog.Level, dt
153232
opts = append(opts, server.WithTLSCert(tls.certFile, tls.keyFile))
154233
}
155234
srv := server.NewStreamableHTTPServer(s, opts...)
156-
slog.Info("Starting Grafana MCP server using StreamableHTTP transport", "version", mcpgrafana.Version(), "address", addr, "endpointPath", endpointPath)
157-
if err := srv.Start(addr); err != nil {
158-
return fmt.Errorf("server error: %v", err)
159-
}
235+
slog.Info("Starting Grafana MCP server using StreamableHTTP transport",
236+
"version", mcpgrafana.Version(), "address", addr, "endpointPath", endpointPath)
237+
return runHTTPServer(ctx, srv, addr, "StreamableHTTP")
160238
default:
161239
return fmt.Errorf(
162240
"invalid transport type: %s. Must be 'stdio', 'sse' or 'streamable-http'",
163241
transport,
164242
)
165243
}
166-
return nil
167244
}
168245

169246
func main() {

0 commit comments

Comments
 (0)