Skip to content

Commit adc6f94

Browse files
committed
mcp/server.go: implemented server-side logging throughout codebase
1 parent 87f2224 commit adc6f94

File tree

4 files changed

+90
-13
lines changed

4 files changed

+90
-13
lines changed

mcp/logging.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,27 @@ type LoggingHandler struct {
8787
handler slog.Handler
8888
}
8989

90+
// discardHandler is a slog.Handler that drops all logs.
91+
type discardHandler struct{}
92+
93+
func (discardHandler) Enabled(context.Context, slog.Level) bool { return false }
94+
func (discardHandler) Handle(context.Context, slog.Record) error { return nil }
95+
func (discardHandler) WithAttrs([]slog.Attr) slog.Handler { return discardHandler{} }
96+
func (discardHandler) WithGroup(string) slog.Handler { return discardHandler{} }
97+
98+
// ensureLogger returns l if non-nil, otherwise a discard logger.
99+
func ensureLogger(l *slog.Logger) *slog.Logger {
100+
if l != nil {
101+
return l
102+
}
103+
return slog.New(discardHandler{})
104+
}
105+
106+
// internalLogger is used for package-internal logging where we don't have a
107+
// specific server/handler context. It defaults to a discard logger to avoid
108+
// unsolicited output from library code.
109+
var internalLogger = slog.New(discardHandler{})
110+
90111
// NewLoggingHandler creates a [LoggingHandler] that logs to the given [ServerSession] using a
91112
// [slog.JSONHandler].
92113
func NewLoggingHandler(ss *ServerSession, opts *LoggingHandlerOptions) *LoggingHandler {

mcp/server.go

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"encoding/json"
1313
"fmt"
1414
"iter"
15+
"log/slog"
1516
"maps"
1617
"net/url"
1718
"path/filepath"
@@ -32,8 +33,9 @@ const DefaultPageSize = 1000
3233
// sessions by using [Server.Run].
3334
type Server struct {
3435
// fixed at creation
35-
impl *Implementation
36-
opts ServerOptions
36+
impl *Implementation
37+
opts ServerOptions
38+
logger *slog.Logger
3739

3840
mu sync.Mutex
3941
prompts *featureSet[*serverPrompt]
@@ -50,6 +52,8 @@ type Server struct {
5052
type ServerOptions struct {
5153
// Optional instructions for connected clients.
5254
Instructions string
55+
// Logger is used for server-side logging. If nil, slog.Default() is used.
56+
Logger *slog.Logger
5357
// If non-nil, called when "notifications/initialized" is received.
5458
InitializedHandler func(context.Context, *ServerRequest[*InitializedParams])
5559
// PageSize is the maximum number of items to return in a single page for
@@ -108,9 +112,11 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server {
108112
if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil {
109113
panic("UnsubscribeHandler requires SubscribeHandler")
110114
}
115+
l := ensureLogger(opts.Logger)
111116
return &Server{
112117
impl: impl,
113118
opts: *opts,
119+
logger: l,
114120
prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }),
115121
tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }),
116122
resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }),
@@ -462,6 +468,7 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot
462468
sessions := slices.Collect(maps.Keys(subscribedSessions))
463469
s.mu.Unlock()
464470
notifySessions(sessions, notificationResourceUpdated, params)
471+
s.logger.Info("resource updated notification sent", "uri", params.URI, "subscriber_count", len(sessions))
465472
return nil
466473
}
467474

@@ -479,6 +486,7 @@ func (s *Server) subscribe(ctx context.Context, req *ServerRequest[*SubscribePar
479486
s.resourceSubscriptions[req.Params.URI] = make(map[*ServerSession]bool)
480487
}
481488
s.resourceSubscriptions[req.Params.URI][req.Session] = true
489+
s.logger.Info("resource subscribed", "uri", req.Params.URI, "session_id", req.Session.ID())
482490

483491
return &emptyResult{}, nil
484492
}
@@ -500,6 +508,7 @@ func (s *Server) unsubscribe(ctx context.Context, req *ServerRequest[*Unsubscrib
500508
delete(s.resourceSubscriptions, req.Params.URI)
501509
}
502510
}
511+
s.logger.Info("resource unsubscribed", "uri", req.Params.URI, "session_id", req.Session.ID())
503512

504513
return &emptyResult{}, nil
505514
}
@@ -518,8 +527,10 @@ func (s *Server) unsubscribe(ctx context.Context, req *ServerRequest[*Unsubscrib
518527
// It need not be called on servers that are used for multiple concurrent connections,
519528
// as with [StreamableHTTPHandler].
520529
func (s *Server) Run(ctx context.Context, t Transport) error {
530+
s.logger.Info("server run start")
521531
ss, err := s.Connect(ctx, t, nil)
522532
if err != nil {
533+
s.logger.Error("server connect failed", "error", err)
523534
return err
524535
}
525536

@@ -531,8 +542,14 @@ func (s *Server) Run(ctx context.Context, t Transport) error {
531542
select {
532543
case <-ctx.Done():
533544
ss.Close()
545+
s.logger.Info("server run cancelled", "error", ctx.Err())
534546
return ctx.Err()
535547
case err := <-ssClosed:
548+
if err != nil {
549+
s.logger.Error("server session ended with error", "error", err)
550+
} else {
551+
s.logger.Info("server session ended")
552+
}
536553
return err
537554
}
538555
}
@@ -548,6 +565,7 @@ func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *Serv
548565
s.mu.Lock()
549566
s.sessions = append(s.sessions, ss)
550567
s.mu.Unlock()
568+
s.logger.Info("server session connected", "session_id", ss.ID())
551569
return ss
552570
}
553571

@@ -563,6 +581,7 @@ func (s *Server) disconnect(cc *ServerSession) {
563581
for _, subscribedSessions := range s.resourceSubscriptions {
564582
delete(subscribedSessions, cc)
565583
}
584+
s.logger.Info("server session disconnected", "session_id", cc.ID())
566585
}
567586

568587
// ServerSessionOptions configures the server session.
@@ -583,7 +602,13 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOp
583602
if opts != nil {
584603
state = opts.State
585604
}
586-
return connect(ctx, t, s, state)
605+
s.logger.Info("server connecting")
606+
ss, err := connect(ctx, t, s, state)
607+
if err != nil {
608+
s.logger.Error("server connect error", "error", err)
609+
return nil, err
610+
}
611+
return ss, nil
587612
}
588613

589614
// TODO: (nit) move all ServerSession methods below the ServerSession declaration.
@@ -606,14 +631,17 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar
606631
})
607632

608633
if !wasInit {
634+
ss.server.logger.Warn("initialized before initialize")
609635
return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize)
610636
}
611637
if wasInitd {
638+
ss.server.logger.Warn("duplicate initialized notification")
612639
return nil, fmt.Errorf("duplicate %q received", notificationInitialized)
613640
}
614641
if h := ss.server.opts.InitializedHandler; h != nil {
615642
h(ctx, serverRequestFor(ss, params))
616643
}
644+
ss.server.logger.Info("session initialized")
617645
return nil, nil
618646
}
619647

@@ -798,6 +826,7 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any,
798826
case methodInitialize, methodPing, notificationInitialized:
799827
default:
800828
if !initialized {
829+
ss.server.logger.Warn("method invalid during initialization", "method", req.Method)
801830
return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method)
802831
}
803832
}
@@ -842,6 +871,7 @@ func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*e
842871
ss.updateState(func(state *ServerSessionState) {
843872
state.LogLevel = params.Level
844873
})
874+
ss.server.logger.Info("client log level set", "level", params.Level)
845875
return &emptyResult{}, nil
846876
}
847877

mcp/sse.go

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"context"
1010
"fmt"
1111
"io"
12+
"log/slog"
1213
"net/http"
1314
"net/url"
1415
"sync"
@@ -47,6 +48,7 @@ type SSEHandler struct {
4748

4849
mu sync.Mutex
4950
sessions map[string]*SSEServerTransport
51+
logger *slog.Logger
5052
}
5153

5254
// NewSSEHandler returns a new [SSEHandler] that creates and manages MCP
@@ -68,9 +70,12 @@ func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler {
6870
return &SSEHandler{
6971
getServer: getServer,
7072
sessions: make(map[string]*SSEServerTransport),
73+
logger: internalLogger,
7174
}
7275
}
7376

77+
func (h *SSEHandler) ensureLogger() { h.logger = ensureLogger(h.logger) }
78+
7479
// A SSEServerTransport is a logical SSE session created through a hanging GET
7580
// request.
7681
//
@@ -100,6 +105,10 @@ type SSEServerTransport struct {
100105
// Response is the hanging response body to the incoming GET request.
101106
Response http.ResponseWriter
102107

108+
// logger is used for per-POST diagnostics and transport-level logs.
109+
// If nil, logging is disabled.
110+
logger *slog.Logger
111+
103112
// incoming is the queue of incoming messages.
104113
// It is never closed, and by convention, incoming is non-nil if and only if
105114
// the transport is connected.
@@ -114,6 +123,8 @@ type SSEServerTransport struct {
114123
done chan struct{} // closed when the connection is closed
115124
}
116125

126+
func (t *SSEServerTransport) ensureLogger() { t.logger = ensureLogger(t.logger) }
127+
117128
// NewSSEServerTransport creates a new SSE transport for the given messages
118129
// endpoint, and hanging GET response.
119130
//
@@ -124,11 +135,13 @@ func NewSSEServerTransport(endpoint string, w http.ResponseWriter) *SSEServerTra
124135
return &SSEServerTransport{
125136
Endpoint: endpoint,
126137
Response: w,
138+
logger: ensureLogger(nil),
127139
}
128140
}
129141

130142
// ServeHTTP handles POST requests to the transport endpoint.
131143
func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) {
144+
t.ensureLogger()
132145
if t.incoming == nil {
133146
http.Error(w, "session not connected", http.StatusInternalServerError)
134147
return
@@ -137,6 +150,7 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request)
137150
// Read and parse the message.
138151
data, err := io.ReadAll(req.Body)
139152
if err != nil {
153+
t.logger.Error("sse: failed to read body", "error", err)
140154
http.Error(w, "failed to read body", http.StatusBadRequest)
141155
return
142156
}
@@ -145,11 +159,13 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request)
145159
// useful
146160
msg, err := jsonrpc2.DecodeMessage(data)
147161
if err != nil {
162+
t.logger.Error("sse: failed to parse body", "error", err)
148163
http.Error(w, "failed to parse body", http.StatusBadRequest)
149164
return
150165
}
151166
if req, ok := msg.(*jsonrpc.Request); ok {
152167
if _, err := checkRequest(req, serverMethodInfos); err != nil {
168+
t.logger.Warn("sse: request validation failed", "error", err)
153169
http.Error(w, err.Error(), http.StatusBadRequest)
154170
return
155171
}
@@ -158,6 +174,7 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request)
158174
case t.incoming <- msg:
159175
w.WriteHeader(http.StatusAccepted)
160176
case <-t.done:
177+
t.logger.Info("sse: session closed while posting message")
161178
http.Error(w, "session closed", http.StatusBadRequest)
162179
}
163180
}
@@ -181,6 +198,7 @@ func (t *SSEServerTransport) Connect(context.Context) (Connection, error) {
181198
}
182199

183200
func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
201+
h.ensureLogger()
184202
sessionID := req.URL.Query().Get("sessionid")
185203

186204
// TODO: consider checking Content-Type here. For now, we are lax.
@@ -221,11 +239,24 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
221239
sessionID = randText()
222240
endpoint, err := req.URL.Parse("?sessionid=" + sessionID)
223241
if err != nil {
242+
h.logger.Error("sse: failed to create endpoint", "error", err)
224243
http.Error(w, "internal error: failed to create endpoint", http.StatusInternalServerError)
225244
return
226245
}
227246

228-
transport := &SSEServerTransport{Endpoint: endpoint.RequestURI(), Response: w}
247+
// Determine the server instance and pick a logger for the transport.
248+
server := h.getServer(req)
249+
if server == nil {
250+
// The getServer argument to NewSSEHandler returned nil.
251+
http.Error(w, "no server available", http.StatusBadRequest)
252+
return
253+
}
254+
// Prefer the server's logger if available; otherwise use the handler's.
255+
lg := server.logger
256+
if lg == nil {
257+
lg = h.logger
258+
}
259+
transport := &SSEServerTransport{Endpoint: endpoint.RequestURI(), Response: w, logger: ensureLogger(lg)}
229260

230261
// The session is terminated when the request exits.
231262
h.mu.Lock()
@@ -236,15 +267,9 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
236267
delete(h.sessions, sessionID)
237268
h.mu.Unlock()
238269
}()
239-
240-
server := h.getServer(req)
241-
if server == nil {
242-
// The getServer argument to NewSSEHandler returned nil.
243-
http.Error(w, "no server available", http.StatusBadRequest)
244-
return
245-
}
246270
ss, err := server.Connect(req.Context(), transport, nil)
247271
if err != nil {
272+
h.logger.Error("sse: server connect failed", "error", err)
248273
http.Error(w, "connection failed", http.StatusInternalServerError)
249274
return
250275
}

mcp/transport.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"errors"
1111
"fmt"
1212
"io"
13-
"log"
1413
"net"
1514
"os"
1615
"sync"
@@ -157,7 +156,9 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H,
157156
OnDone: func() {
158157
b.disconnect(h)
159158
},
160-
OnInternalError: func(err error) { log.Printf("jsonrpc2 error: %v", err) },
159+
OnInternalError: func(err error) {
160+
internalLogger.Error("jsonrpc2 internal error", "error", err)
161+
},
161162
})
162163
assert(preempter.conn != nil, "unbound preempter")
163164
return h, nil

0 commit comments

Comments
 (0)