diff --git a/examples/server/sse/main.go b/examples/server/sse/main.go index 27f9caed..0507dd60 100644 --- a/examples/server/sse/main.go +++ b/examples/server/sse/main.go @@ -65,6 +65,6 @@ func main() { default: return nil } - }) + }, nil) log.Fatal(http.ListenAndServe(addr, handler)) } diff --git a/mcp/logging.go b/mcp/logging.go index b3186a96..a1c031ac 100644 --- a/mcp/logging.go +++ b/mcp/logging.go @@ -88,6 +88,23 @@ type LoggingHandler struct { handler slog.Handler } +// discardHandler is a slog.Handler that drops all logs. +// TODO: use slog.NewNopHandler when we require Go 1.24+. +type discardHandler struct{} + +func (discardHandler) Enabled(context.Context, slog.Level) bool { return false } +func (discardHandler) Handle(context.Context, slog.Record) error { return nil } +func (discardHandler) WithAttrs([]slog.Attr) slog.Handler { return discardHandler{} } +func (discardHandler) WithGroup(string) slog.Handler { return discardHandler{} } + +// ensureLogger returns l if non-nil, otherwise a discard logger. +func ensureLogger(l *slog.Logger) *slog.Logger { + if l != nil { + return l + } + return slog.New(discardHandler{}) +} + // NewLoggingHandler creates a [LoggingHandler] that logs to the given [ServerSession] using a // [slog.JSONHandler]. func NewLoggingHandler(ss *ServerSession, opts *LoggingHandlerOptions) *LoggingHandler { diff --git a/mcp/sse.go b/mcp/sse.go index f39a0397..f3b8cf34 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -9,6 +9,7 @@ import ( "context" "fmt" "io" + "log/slog" "net/http" "net/url" "sync" @@ -43,12 +44,21 @@ import ( // [2024-11-05 version]: https://modelcontextprotocol.io/specification/2024-11-05/basic/transports type SSEHandler struct { getServer func(request *http.Request) *Server + opts SSEOptions onConnection func(*ServerSession) // for testing; must not block + logger *slog.Logger mu sync.Mutex sessions map[string]*SSEServerTransport } +// SSEOptions specifies options for an [SSEHandler]. +type SSEOptions struct { + // Logger specifies the logger to use. + // If nil, do not log. + Logger *slog.Logger +} + // NewSSEHandler returns a new [SSEHandler] that creates and manages MCP // sessions created via incoming HTTP requests. // @@ -62,13 +72,22 @@ type SSEHandler struct { // The getServer function may return a distinct [Server] for each new // request, or reuse an existing server. If it returns nil, the handler // will return a 400 Bad Request. -// -// TODO(rfindley): add options. -func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler { - return &SSEHandler{ +func NewSSEHandler(getServer func(request *http.Request) *Server, opts *SSEOptions) *SSEHandler { + s := &SSEHandler{ getServer: getServer, sessions: make(map[string]*SSEServerTransport), } + + if opts != nil { + s.opts = *opts + } + + if s.opts.Logger == nil { // ensure we have a logger + s.opts.Logger = ensureLogger(nil) + } + s.logger = s.opts.Logger + + return s } // A SSEServerTransport is a logical SSE session created through a hanging GET @@ -100,6 +119,10 @@ type SSEServerTransport struct { // Response is the hanging response body to the incoming GET request. Response http.ResponseWriter + // logger is used for per-POST diagnostics and transport-level logs. + // If nil, logging is disabled. + logger *slog.Logger + // incoming is the queue of incoming messages. // It is never closed, and by convention, incoming is non-nil if and only if // the transport is connected. @@ -124,6 +147,7 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) // Read and parse the message. data, err := io.ReadAll(req.Body) if err != nil { + t.logger.Error("sse: failed to read body", "error", err) http.Error(w, "failed to read body", http.StatusBadRequest) return } @@ -132,11 +156,13 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) // useful msg, err := jsonrpc2.DecodeMessage(data) if err != nil { + t.logger.Error("sse: failed to parse body", "error", err) http.Error(w, "failed to parse body", http.StatusBadRequest) return } if req, ok := msg.(*jsonrpc.Request); ok { if _, err := checkRequest(req, serverMethodInfos); err != nil { + t.logger.Warn("sse: request validation failed", "error", err) http.Error(w, err.Error(), http.StatusBadRequest) return } @@ -145,6 +171,7 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) case t.incoming <- msg: w.WriteHeader(http.StatusAccepted) case <-t.done: + t.logger.Warn("sse: session closed while posting message") http.Error(w, "session closed", http.StatusBadRequest) } } @@ -208,11 +235,12 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { sessionID = randText() endpoint, err := req.URL.Parse("?sessionid=" + sessionID) if err != nil { + h.logger.Error("sse: failed to create endpoint", "error", err) http.Error(w, "internal error: failed to create endpoint", http.StatusInternalServerError) return } - transport := &SSEServerTransport{Endpoint: endpoint.RequestURI(), Response: w} + transport := &SSEServerTransport{Endpoint: endpoint.RequestURI(), Response: w, logger: h.logger} // The session is terminated when the request exits. h.mu.Lock() @@ -232,6 +260,7 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } ss, err := server.Connect(req.Context(), transport, nil) if err != nil { + h.logger.Error("sse: server connect failed", "error", err) http.Error(w, "connection failed", http.StatusInternalServerError) return } diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index d06ea62b..6132d31e 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -31,7 +31,7 @@ func ExampleSSEHandler() { server := mcp.NewServer(&mcp.Implementation{Name: "adder", Version: "v0.0.1"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "add", Description: "add two numbers"}, Add) - handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { return server }) + handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { return server }, nil) httpServer := httptest.NewServer(handler) defer httpServer.Close() diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 32a20bf3..b8662f71 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -24,7 +24,11 @@ func TestSSEServer(t *testing.T) { server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet"}, sayHi) - sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }) + sseOptions := &SSEOptions{ + Logger: ensureLogger(nil), + } + + sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }, sseOptions) serverSessions := make(chan *ServerSession, 1) sseHandler.onConnection = func(ss *ServerSession) {