From 93cb62325c6573d5daeaf7ded543e429777b0685 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 7 Jul 2025 09:52:59 -0400 Subject: [PATCH] mcp: add SSE session IDs ClientSessions and ServerSessions created by SSE have non-empty session IDs. Fixes #98. --- mcp/sse.go | 18 ++++++++++-------- mcp/sse_test.go | 11 +++++++++-- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/mcp/sse.go b/mcp/sse.go index 0a1f9b1b..6c7324bb 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -110,8 +110,9 @@ func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler { // [SSEServerTransport.ServeHTTP]. // - Close terminates the hanging GET. type SSEServerTransport struct { - endpoint string - incoming chan JSONRPCMessage // queue of incoming messages; never closed + endpoint string + sessionID string + incoming chan JSONRPCMessage // queue of incoming messages; never closed // We must guard both pushes to the incoming queue and writes to the response // writer, because incoming POST requests are arbitrarily concurrent and we @@ -228,6 +229,7 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } transport := NewSSEServerTransport(endpoint.RequestURI(), w) + transport.sessionID = sessionID // The session is terminated when the request exits. h.mu.Lock() @@ -263,8 +265,7 @@ type sseServerConn struct { t *SSEServerTransport } -// TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.) -func (s sseServerConn) SessionID() string { return "" } +func (s sseServerConn) SessionID() string { return s.t.sessionID } // Read implements jsonrpc2.Reader. func (s sseServerConn) Read(ctx context.Context) (JSONRPCMessage, error) { @@ -393,6 +394,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { s := &sseClientConn{ sseEndpoint: c.sseEndpoint, msgEndpoint: msgEndpoint, + sessionID: msgEndpoint.Query().Get("sessionid"), incoming: make(chan []byte, 100), body: resp.Body, done: make(chan struct{}), @@ -511,8 +513,9 @@ func scanEvents(r io.Reader) iter.Seq2[event, error] { // - Reads are SSE 'message' events, and pushes them onto a buffered channel. // - Close terminates the GET request. type sseClientConn struct { - sseEndpoint *url.URL // SSE endpoint for the GET - msgEndpoint *url.URL // session endpoint for POSTs + sseEndpoint *url.URL // SSE endpoint for the GET + msgEndpoint *url.URL // session endpoint for POSTs + sessionID string incoming chan []byte // queue of incoming messages mu sync.Mutex @@ -521,8 +524,7 @@ type sseClientConn struct { done chan struct{} // closed when the stream is closed } -// TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.) -func (c *sseClientConn) SessionID() string { return "" } +func (c *sseClientConn) SessionID() string { return c.sessionID } func (c *sseClientConn) isDone() bool { c.mu.Lock() diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 23621931..2d12e72e 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -25,9 +25,12 @@ func TestSSEServer(t *testing.T) { sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }) conns := make(chan *ServerSession, 1) - sseHandler.onConnection = func(cc *ServerSession) { + sseHandler.onConnection = func(ss *ServerSession) { + if ss.ID() == "" { + t.Error("ServerSession has empty session ID") + } select { - case conns <- cc: + case conns <- ss: default: } } @@ -41,6 +44,10 @@ func TestSSEServer(t *testing.T) { if err != nil { t.Fatal(err) } + if cs.ID() == "" { + t.Error("ClientSession has empty ID") + } + if err := cs.Ping(ctx, nil); err != nil { t.Fatal(err) }