diff --git a/mcp/client.go b/mcp/client.go index 856df567..a8943bcf 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -142,6 +142,18 @@ type ClientSession struct { client *Client initializeResult *InitializeResult keepaliveCancel context.CancelFunc + mcpConn Connection +} + +func (cs *ClientSession) setConn(c Connection) { + cs.mcpConn = c +} + +func (cs *ClientSession) ID() string { + if cs.mcpConn == nil { + return "" + } + return cs.mcpConn.SessionID() } // Close performs a graceful close of the connection, preventing new requests diff --git a/mcp/server.go b/mcp/server.go index 44ec7aa7..1e09767f 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -496,6 +496,7 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot type ServerSession struct { server *Server conn *jsonrpc2.Connection + mcpConn Connection mu sync.Mutex logLevel LoggingLevel initializeParams *InitializeParams @@ -503,6 +504,17 @@ type ServerSession struct { keepaliveCancel context.CancelFunc } +func (ss *ServerSession) setConn(c Connection) { + ss.mcpConn = c +} + +func (ss *ServerSession) ID() string { + if ss.mcpConn == nil { + return "" + } + return ss.mcpConn.SessionID() +} + // Ping pings the client. func (ss *ServerSession) Ping(ctx context.Context, params *PingParams) error { _, err := handleSend[*emptyResult](ctx, ss, methodPing, params) diff --git a/mcp/shared.go b/mcp/shared.go index a2d51470..db871ca8 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -36,6 +36,9 @@ type methodHandler any // MethodHandler[*ClientSession] | MethodHandler[*ServerS // A Session is either a ClientSession or a ServerSession. type Session interface { *ClientSession | *ServerSession + // ID returns the session ID, or the empty string if there is none. + ID() string + sendingMethodInfos() map[string]methodInfo receivingMethodInfos() map[string]methodInfo sendingMethodHandler() methodHandler diff --git a/mcp/sse.go b/mcp/sse.go index bdbb6d0f..0a1f9b1b 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -263,6 +263,9 @@ type sseServerConn struct { t *SSEServerTransport } +// TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.) +func (s sseServerConn) SessionID() string { return "" } + // Read implements jsonrpc2.Reader. func (s sseServerConn) Read(ctx context.Context) (JSONRPCMessage, error) { select { @@ -518,6 +521,9 @@ type sseClientConn struct { done chan struct{} // closed when the stream is closed } +// TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.) +func (c *sseClientConn) SessionID() string { return "" } + func (c *sseClientConn) isDone() bool { c.mu.Lock() defer c.mu.Unlock() diff --git a/mcp/streamable.go b/mcp/streamable.go index a1952f7e..da950fb2 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -161,6 +161,10 @@ func NewStreamableServerTransport(sessionID string) *StreamableServerTransport { } } +func (t *StreamableServerTransport) SessionID() string { + return t.id +} + // A StreamableServerTransport implements the [Transport] interface for a // single session. type StreamableServerTransport struct { @@ -331,7 +335,7 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) return } - var requests = make(map[JSONRPCID]struct{}) + requests := make(map[JSONRPCID]struct{}) for _, msg := range incoming { if req, ok := msg.(*JSONRPCRequest); ok && req.ID.IsValid() { requests[req.ID] = struct{}{} @@ -624,20 +628,26 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er } type streamableClientConn struct { - url string - sessionID string - client *http.Client - incoming chan []byte - done chan struct{} + url string + client *http.Client + incoming chan []byte + done chan struct{} closeOnce sync.Once closeErr error - mu sync.Mutex + mu sync.Mutex + _sessionID string // bodies map[*http.Response]io.Closer err error } +func (c *streamableClientConn) SessionID() string { + c.mu.Lock() + defer c.mu.Unlock() + return c._sessionID +} + // Read implements the [Connection] interface. func (s *streamableClientConn) Read(ctx context.Context) (JSONRPCMessage, error) { select { @@ -658,7 +668,7 @@ func (s *streamableClientConn) Write(ctx context.Context, msg JSONRPCMessage) er return s.err } - sessionID := s.sessionID + sessionID := s._sessionID if sessionID == "" { // Hold lock for the first request. defer s.mu.Unlock() @@ -681,7 +691,7 @@ func (s *streamableClientConn) Write(ctx context.Context, msg JSONRPCMessage) er if sessionID == "" { // locked - s.sessionID = gotSessionID + s._sessionID = gotSessionID } return nil @@ -753,7 +763,7 @@ func (s *streamableClientConn) Close() error { if err != nil { s.closeErr = err } else { - req.Header.Set("Mcp-Session-Id", s.sessionID) + req.Header.Set("Mcp-Session-Id", s._sessionID) if _, err := s.client.Do(req); err != nil { s.closeErr = err } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index d496f83e..12c00a2e 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -69,7 +69,10 @@ func TestStreamableTransports(t *testing.T) { t.Fatalf("client.Connect() failed: %v", err) } defer session.Close() - + sid := session.ID() + if sid == "" { + t.Error("empty session ID") + } // 4. The client calls the "greet" tool. params := &CallToolParams{ Name: "greet", @@ -79,6 +82,9 @@ func TestStreamableTransports(t *testing.T) { if err != nil { t.Fatalf("CallTool() failed: %v", err) } + if g := session.ID(); g != sid { + t.Errorf("session ID: got %q, want %q", g, sid) + } // 5. Verify that the correct response is received. want := &CallToolResult{ diff --git a/mcp/transport.go b/mcp/transport.go index 0fadca33..85bfaf65 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -50,6 +50,7 @@ type Connection interface { Read(context.Context) (JSONRPCMessage, error) Write(context.Context, JSONRPCMessage) error Close() error // may be called concurrently by both peers + SessionID() string } // A StdioTransport is a [Transport] that communicates over stdin/stdout using @@ -94,6 +95,7 @@ type binder[T handler] interface { type handler interface { handle(ctx context.Context, req *JSONRPCRequest) (any, error) + setConn(Connection) } func connect[H handler](ctx context.Context, t Transport, b binder[H]) (H, error) { @@ -124,6 +126,7 @@ func connect[H handler](ctx context.Context, t Transport, b binder[H]) (H, error }, }) assert(preempter.conn != nil, "unbound preempter") + h.setConn(conn) return h, nil } @@ -200,6 +203,8 @@ type loggingConn struct { w io.Writer } +func (c *loggingConn) SessionID() string { return c.delegate.SessionID() } + // loggingReader is a stream middleware that logs incoming messages. func (s *loggingConn) Read(ctx context.Context) (JSONRPCMessage, error) { msg, err := s.delegate.Read(ctx) @@ -285,6 +290,8 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn { } } +func (c *ioConn) SessionID() string { return "" } + // addBatch records a msgBatch for an incoming batch payload. // It returns an error if batch is malformed, containing previously seen IDs. //