Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions mcp/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,23 @@ type LoggingHandler struct {
handler slog.Handler
}

// discardHandler is a slog.Handler that drops all logs.
// TODO: use slog.DiscardHandler when we require Go 1.24+.
type discardHandler struct{}

func (discardHandler) Enabled(context.Context, slog.Level) bool { return false }
func (discardHandler) Handle(context.Context, slog.Record) error { return nil }
func (discardHandler) WithAttrs([]slog.Attr) slog.Handler { return discardHandler{} }
func (discardHandler) WithGroup(string) slog.Handler { return discardHandler{} }

// ensureLogger returns l if non-nil, otherwise a discard logger.
func ensureLogger(l *slog.Logger) *slog.Logger {
if l != nil {
return l
}
return slog.New(discardHandler{})
}

// NewLoggingHandler creates a [LoggingHandler] that logs to the given [ServerSession] using a
// [slog.JSONHandler].
func NewLoggingHandler(ss *ServerSession, opts *LoggingHandlerOptions) *LoggingHandler {
Expand Down
34 changes: 33 additions & 1 deletion mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"encoding/json"
"fmt"
"iter"
"log/slog"
"maps"
"net/url"
"path/filepath"
Expand Down Expand Up @@ -53,6 +54,8 @@ type Server struct {
type ServerOptions struct {
// Optional instructions for connected clients.
Instructions string
// If non-nil, log server activity.
Logger *slog.Logger
// If non-nil, called when "notifications/initialized" is received.
InitializedHandler func(context.Context, *InitializedRequest)
// PageSize is the maximum number of items to return in a single page for
Expand Down Expand Up @@ -129,6 +132,10 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server {
opts.GetSessionID = randText
}

if opts.Logger == nil { // ensure we have a logger
opts.Logger = ensureLogger(nil)
}

return &Server{
impl: impl,
opts: opts,
Expand Down Expand Up @@ -659,6 +666,7 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot
sessions := slices.Collect(maps.Keys(subscribedSessions))
s.mu.Unlock()
notifySessions(sessions, notificationResourceUpdated, params)
s.opts.Logger.Info("resource updated notification sent", "uri", params.URI, "subscriber_count", len(sessions))
return nil
}

Expand All @@ -676,6 +684,7 @@ func (s *Server) subscribe(ctx context.Context, req *SubscribeRequest) (*emptyRe
s.resourceSubscriptions[req.Params.URI] = make(map[*ServerSession]bool)
}
s.resourceSubscriptions[req.Params.URI][req.Session] = true
s.opts.Logger.Info("resource subscribed", "uri", req.Params.URI, "session_id", req.Session.ID())

return &emptyResult{}, nil
}
Expand All @@ -697,6 +706,7 @@ func (s *Server) unsubscribe(ctx context.Context, req *UnsubscribeRequest) (*emp
delete(s.resourceSubscriptions, req.Params.URI)
}
}
s.opts.Logger.Info("resource unsubscribed", "uri", req.Params.URI, "session_id", req.Session.ID())

return &emptyResult{}, nil
}
Expand All @@ -715,8 +725,10 @@ func (s *Server) unsubscribe(ctx context.Context, req *UnsubscribeRequest) (*emp
// It need not be called on servers that are used for multiple concurrent connections,
// as with [StreamableHTTPHandler].
func (s *Server) Run(ctx context.Context, t Transport) error {
s.opts.Logger.Info("server run start")
ss, err := s.Connect(ctx, t, nil)
if err != nil {
s.opts.Logger.Error("server connect failed", "error", err)
return err
}

Expand All @@ -728,8 +740,14 @@ func (s *Server) Run(ctx context.Context, t Transport) error {
select {
case <-ctx.Done():
ss.Close()
s.opts.Logger.Error("server run cancelled", "error", ctx.Err())
return ctx.Err()
case err := <-ssClosed:
if err != nil {
s.opts.Logger.Error("server session ended with error", "error", err)
} else {
s.opts.Logger.Info("server session ended")
}
return err
}
}
Expand All @@ -745,6 +763,7 @@ func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *Serv
s.mu.Lock()
s.sessions = append(s.sessions, ss)
s.mu.Unlock()
s.opts.Logger.Info("server session connected", "session_id", ss.ID())
return ss
}

Expand All @@ -760,6 +779,7 @@ func (s *Server) disconnect(cc *ServerSession) {
for _, subscribedSessions := range s.resourceSubscriptions {
delete(subscribedSessions, cc)
}
s.opts.Logger.Info("server session disconnected", "session_id", cc.ID())
}

// ServerSessionOptions configures the server session.
Expand All @@ -784,7 +804,14 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOp
state = opts.State
onClose = opts.onClose
}
return connect(ctx, t, s, state, onClose)

s.opts.Logger.Info("server connecting")
ss, err := connect(ctx, t, s, state, onClose)
if err != nil {
s.opts.Logger.Error("server connect error", "error", err)
return nil, err
}
return ss, nil
}

// TODO: (nit) move all ServerSession methods below the ServerSession declaration.
Expand All @@ -804,9 +831,11 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar
})

if !wasInit {
ss.server.opts.Logger.Error("initialized before initialize")
return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize)
}
if wasInitd {
ss.server.opts.Logger.Error("duplicate initialized notification")
return nil, fmt.Errorf("duplicate %q received", notificationInitialized)
}
if ss.server.opts.KeepAlive > 0 {
Expand All @@ -815,6 +844,7 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar
if h := ss.server.opts.InitializedHandler; h != nil {
h(ctx, serverRequestFor(ss, params))
}
ss.server.opts.Logger.Info("session initialized")
return nil, nil
}

Expand Down Expand Up @@ -1052,6 +1082,7 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any,
case methodInitialize, methodPing, notificationInitialized:
default:
if !initialized {
ss.server.opts.Logger.Error("method invalid during initialization", "method", req.Method)
return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method)
}
}
Expand Down Expand Up @@ -1108,6 +1139,7 @@ func (ss *ServerSession) setLevel(_ context.Context, params *SetLoggingLevelPara
ss.updateState(func(state *ServerSessionState) {
state.LogLevel = params.Level
})
ss.server.opts.Logger.Info("client log level set", "level", params.Level)
return &emptyResult{}, nil
}

Expand Down
25 changes: 23 additions & 2 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"fmt"
"io"
"iter"
"log/slog"
"math"
"math/rand/v2"
"net/http"
Expand Down Expand Up @@ -67,6 +68,10 @@ type StreamableHTTPOptions struct {
//
// [§2.1.5]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
JSONResponse bool

// Logger specifies the logger to use.
// If nil, do not log.
Logger *slog.Logger
}

// NewStreamableHTTPHandler returns a new [StreamableHTTPHandler].
Expand All @@ -82,6 +87,11 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea
if opts != nil {
h.opts = *opts
}

if h.opts.Logger == nil { // ensure we have a logger
h.opts.Logger = ensureLogger(nil)
}

return h
}

Expand Down Expand Up @@ -367,6 +377,8 @@ type StreamableServerTransport struct {
// StreamableHTTPOptions.JSONResponse is exported.
jsonResponse bool

logger *slog.Logger

// connection is non-nil if and only if the transport has been connected.
connection *streamableServerConn
}
Expand All @@ -381,6 +393,7 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er
stateless: t.Stateless,
eventStore: t.EventStore,
jsonResponse: t.jsonResponse,
logger: t.logger,
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
streams: make(map[string]*stream),
Expand All @@ -407,6 +420,8 @@ type streamableServerConn struct {
jsonResponse bool
eventStore EventStore

logger *slog.Logger

incoming chan jsonrpc.Message // messages from the client to the server

mu sync.Mutex // guards all fields below
Expand Down Expand Up @@ -754,7 +769,7 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,
}
if _, err := writeEvent(w, e); err != nil {
// Connection closed or broken.
// TODO(#170): log when we add server-side logging.
c.logger.Warn("error writing event", "error", err)
return false
}
writes++
Expand All @@ -773,7 +788,13 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,
// simplify.
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
} else {
// TODO(#170): log when we add server-side logging
if ctx.Err() != nil {
// Client disconnected or cancelled the request.
c.logger.Error("stream context done", "error", ctx.Err())
} else {
// Some other error.
c.logger.Error("error receiving message", "error", err)
}
}
return
}
Expand Down