@@ -2,12 +2,17 @@ package main
22
33import (
44 "context"
5+ "errors"
56 "flag"
67 "fmt"
78 "log/slog"
9+ "net/http"
810 "os"
11+ "os/signal"
912 "slices"
1013 "strings"
14+ "syscall"
15+ "time"
1116
1217 "github.com/mark3labs/mcp-go/server"
1318
@@ -124,25 +129,99 @@ func (tc *tlsConfig) addFlags() {
124129 flag .StringVar (& tc .keyFile , "server.tls-key-file" , "" , "Path to TLS private key file for server HTTPS (required for TLS)" )
125130}
126131
132+ // httpServer represents a server with Start and Shutdown methods
133+ type httpServer interface {
134+ Start (addr string ) error
135+ Shutdown (ctx context.Context ) error
136+ }
137+
138+ // runHTTPServer handles the common logic for running HTTP-based servers
139+ func runHTTPServer (ctx context.Context , srv httpServer , addr , transportName string ) error {
140+ // Start server in a goroutine
141+ serverErr := make (chan error , 1 )
142+ go func () {
143+ if err := srv .Start (addr ); err != nil {
144+ serverErr <- err
145+ }
146+ close (serverErr )
147+ }()
148+
149+ // Wait for either server error or shutdown signal
150+ select {
151+ case err := <- serverErr :
152+ return err
153+ case <- ctx .Done ():
154+ slog .Info (fmt .Sprintf ("%s server shutting down..." , transportName ))
155+
156+ // Create a timeout context for shutdown
157+ shutdownCtx , shutdownCancel := context .WithTimeout (context .Background (), 5 * time .Second )
158+ defer shutdownCancel ()
159+
160+ if err := srv .Shutdown (shutdownCtx ); err != nil {
161+ return fmt .Errorf ("shutdown error: %v" , err )
162+ }
163+
164+ // Wait for server to finish
165+ select {
166+ case err := <- serverErr :
167+ // http.ErrServerClosed is expected when shutting down
168+ if err != nil && ! errors .Is (err , http .ErrServerClosed ) {
169+ return fmt .Errorf ("server error during shutdown: %v" , err )
170+ }
171+ case <- shutdownCtx .Done ():
172+ slog .Warn (fmt .Sprintf ("%s server did not stop gracefully within timeout" , transportName ))
173+ }
174+ }
175+
176+ return nil
177+ }
178+
127179func run (transport , addr , basePath , endpointPath string , logLevel slog.Level , dt disabledTools , gc mcpgrafana.GrafanaConfig , tls tlsConfig ) error {
128180 slog .SetDefault (slog .New (slog .NewTextHandler (os .Stderr , & slog.HandlerOptions {Level : logLevel })))
129181 s := newServer (dt )
130182
183+ // Create a context that will be cancelled on shutdown
184+ ctx , cancel := context .WithCancel (context .Background ())
185+ defer cancel ()
186+
187+ // Set up signal handling for graceful shutdown
188+ sigChan := make (chan os.Signal , 1 )
189+ signal .Notify (sigChan , os .Interrupt , syscall .SIGTERM )
190+ defer signal .Stop (sigChan )
191+
192+ // Handle shutdown signals
193+ go func () {
194+ <- sigChan
195+ slog .Info ("Received shutdown signal" )
196+ cancel ()
197+
198+ // For stdio, close stdin to unblock the Listen call
199+ if transport == "stdio" {
200+ _ = os .Stdin .Close ()
201+ }
202+ }()
203+
204+ // Start the appropriate server based on transport
131205 switch transport {
132206 case "stdio" :
133207 srv := server .NewStdioServer (s )
134208 srv .SetContextFunc (mcpgrafana .ComposedStdioContextFunc (gc ))
135209 slog .Info ("Starting Grafana MCP server using stdio transport" , "version" , mcpgrafana .Version ())
136- return srv .Listen (context .Background (), os .Stdin , os .Stdout )
210+
211+ err := srv .Listen (ctx , os .Stdin , os .Stdout )
212+ if err != nil && err != context .Canceled {
213+ return fmt .Errorf ("server error: %v" , err )
214+ }
215+ return nil
216+
137217 case "sse" :
138218 srv := server .NewSSEServer (s ,
139219 server .WithSSEContextFunc (mcpgrafana .ComposedSSEContextFunc (gc )),
140220 server .WithStaticBasePath (basePath ),
141221 )
142- slog .Info ("Starting Grafana MCP server using SSE transport" , "version" , mcpgrafana .Version (), "address" , addr , "basePath" , basePath )
143- if err := srv .Start (addr ); err != nil {
144- return fmt .Errorf ("server error: %v" , err )
145- }
222+ slog .Info ("Starting Grafana MCP server using SSE transport" ,
223+ "version" , mcpgrafana .Version (), "address" , addr , "basePath" , basePath )
224+ return runHTTPServer (ctx , srv , addr , "SSE" )
146225 case "streamable-http" :
147226 opts := []server.StreamableHTTPOption {
148227 server .WithHTTPContextFunc (mcpgrafana .ComposedHTTPContextFunc (gc )),
@@ -153,17 +232,15 @@ func run(transport, addr, basePath, endpointPath string, logLevel slog.Level, dt
153232 opts = append (opts , server .WithTLSCert (tls .certFile , tls .keyFile ))
154233 }
155234 srv := server .NewStreamableHTTPServer (s , opts ... )
156- slog .Info ("Starting Grafana MCP server using StreamableHTTP transport" , "version" , mcpgrafana .Version (), "address" , addr , "endpointPath" , endpointPath )
157- if err := srv .Start (addr ); err != nil {
158- return fmt .Errorf ("server error: %v" , err )
159- }
235+ slog .Info ("Starting Grafana MCP server using StreamableHTTP transport" ,
236+ "version" , mcpgrafana .Version (), "address" , addr , "endpointPath" , endpointPath )
237+ return runHTTPServer (ctx , srv , addr , "StreamableHTTP" )
160238 default :
161239 return fmt .Errorf (
162240 "invalid transport type: %s. Must be 'stdio', 'sse' or 'streamable-http'" ,
163241 transport ,
164242 )
165243 }
166- return nil
167244}
168245
169246func main () {
0 commit comments