From c5907a306f17ca124f183265a7130f538b2f5a60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Friedrich=20Gro=C3=9Fe?= Date: Wed, 17 Sep 2025 15:22:31 +0200 Subject: [PATCH 1/6] mcp: add OnConnectionClose.OnConnectionClose(...) callback Adding a new callback that callers can optionally set in order to get notified about a connection created by a `StreamableHTTPHandler` was closed. Connections can be closed by the MCP client as part of the regular connection lifecycle (happy path) or when there was a connection error (e.g., a timeout). Fixes #479 --- mcp/streamable.go | 30 +++++++++++++++++++---- mcp/streamable_test.go | 54 +++++++++++++++++++++++++++++++++++++----- 2 files changed, 73 insertions(+), 11 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 8ac6f59a..97b6d3e0 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -40,8 +40,6 @@ type StreamableHTTPHandler struct { getServer func(*http.Request) *Server opts StreamableHTTPOptions - onTransportDeletion func(sessionID string) // for testing only - mu sync.Mutex // TODO: we should store the ServerSession along with the transport, because // we need to cancel keepalive requests when closing the transport. @@ -77,6 +75,11 @@ type StreamableHTTPOptions struct { // // [§2.1.5]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server JSONResponse bool + + // OnConnectionClose is a callback function that is invoked when a [Connection] + // is closed. A connection is closed when the session is ended explicitly by + // the client or when it is interrupted due to a timeout or other errors. + OnConnectionClose func(sessionID string) } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -166,7 +169,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque h.mu.Lock() delete(h.transports, transport.SessionID) h.mu.Unlock() - transport.connection.Close() + _ = transport.Close() } w.WriteHeader(http.StatusNoContent) return @@ -299,8 +302,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque h.mu.Lock() delete(h.transports, transport.SessionID) h.mu.Unlock() - if h.onTransportDeletion != nil { - h.onTransportDeletion(transport.SessionID) + if h.opts.OnConnectionClose != nil { + h.opts.OnConnectionClose(transport.SessionID) } }, } @@ -320,6 +323,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } else { // Otherwise, save the transport so that it can be reused h.mu.Lock() + transport.session = ss h.transports[transport.SessionID] = transport h.mu.Unlock() } @@ -382,6 +386,9 @@ type StreamableServerTransport struct { // connection is non-nil if and only if the transport has been connected. connection *streamableServerConn + + // the server session associated with this transport + session *ServerSession } // Connect implements the [Transport] interface. @@ -563,6 +570,19 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R } } +// Close releases resources related to this transport if it has already been connected. +func (t *StreamableServerTransport) Close() error { + var sessionErr, connErr error + if t.session != nil { + sessionErr = t.session.Close() + } + if t.connection != nil { + connErr = t.connection.Close() + } + + return errors.Join(sessionErr, connErr) +} + // serveGET streams messages to a hanging http GET, with stream ID and last // message parsed from the Last-Event-ID header. // diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index e077308c..066e4b95 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -130,7 +130,12 @@ func TestStreamableTransports(t *testing.T) { if err != nil { t.Fatalf("client.Connect() failed: %v", err) } - defer session.Close() + t.Cleanup(func() { + err := session.Close() + if err != nil { + t.Errorf("session.Close() failed: %v", err) + } + }) sid := session.ID() if sid == "" { t.Fatalf("empty session ID") @@ -220,7 +225,7 @@ func TestStreamableServerShutdown(t *testing.T) { httpServer := httptest.NewUnstartedServer(handler) httpServer.Config.RegisterOnShutdown(func() { for session := range server.Sessions() { - session.Close() + _ = session.Close() } }) httpServer.Start() @@ -428,12 +433,11 @@ func TestServerTransportCleanup(t *testing.T) { chans[fmt.Sprint(id)] = make(chan struct{}, 1) return fmt.Sprint(id) }, + OnConnectionClose: func(sessionID string) { + chans[sessionID] <- struct{}{} + }, }) - handler.onTransportDeletion = func(sessionID string) { - chans[sessionID] <- struct{}{} - } - httpServer := httptest.NewServer(handler) defer httpServer.Close() @@ -1423,3 +1427,41 @@ func TestStreamableGET(t *testing.T) { t.Errorf("GET with session ID: got status %d, want %d", got, want) } } + +// TestStreamableHTTPHandler_OnConnectionClose_SessionDeletion tests that the +// OnConnectionClose callback is called when the client closes the session. +func TestStreamableHTTPHandler_OnConnectionClose_SessionDeletion(t *testing.T) { + var closedConnections []string + + server := NewServer(testImpl, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + OnConnectionClose: func(sessionID string) { + closedConnections = append(closedConnections, sessionID) + }, + }) + + httpServer := httptest.NewServer(handler) + t.Cleanup(httpServer.Close) + + ctx := context.Background() + client := NewClient(testImpl, nil) + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + + sessionID := session.ID() + t.Log("Closing client session") + err = session.Close() + if err != nil { + t.Fatalf("session.Close() failed: %v", err) + } + + if len(closedConnections) != 1 { + t.Fatalf("got %d connections, want 1", len(closedConnections)) + } + if closedConnections[0] != sessionID { + t.Fatalf("got session ID %q, want %q", closedConnections[0], sessionID) + } +} From c999042ddd73f9bea234582f7649a34b7fcf9ecf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Friedrich=20Gro=C3=9Fe?= Date: Wed, 17 Sep 2025 16:20:47 +0200 Subject: [PATCH 2/6] mcp: rename OnConnectionClose to OnSessionClose and remove StreamableServerTransport.Close() --- mcp/streamable.go | 29 +++++++---------------------- mcp/streamable_test.go | 10 +++++----- 2 files changed, 12 insertions(+), 27 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 97b6d3e0..05a80f59 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -40,9 +40,7 @@ type StreamableHTTPHandler struct { getServer func(*http.Request) *Server opts StreamableHTTPOptions - mu sync.Mutex - // TODO: we should store the ServerSession along with the transport, because - // we need to cancel keepalive requests when closing the transport. + mu sync.Mutex transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) } @@ -76,10 +74,10 @@ type StreamableHTTPOptions struct { // [§2.1.5]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server JSONResponse bool - // OnConnectionClose is a callback function that is invoked when a [Connection] + // OnSessionClose is a callback function that is invoked when a [Connection] // is closed. A connection is closed when the session is ended explicitly by // the client or when it is interrupted due to a timeout or other errors. - OnConnectionClose func(sessionID string) + OnSessionClose func(sessionID string) } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -169,7 +167,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque h.mu.Lock() delete(h.transports, transport.SessionID) h.mu.Unlock() - _ = transport.Close() + _ = transport.session.Close() } w.WriteHeader(http.StatusNoContent) return @@ -302,8 +300,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque h.mu.Lock() delete(h.transports, transport.SessionID) h.mu.Unlock() - if h.opts.OnConnectionClose != nil { - h.opts.OnConnectionClose(transport.SessionID) + if h.opts.OnSessionClose != nil { + h.opts.OnSessionClose(transport.SessionID) } }, } @@ -387,7 +385,7 @@ type StreamableServerTransport struct { // connection is non-nil if and only if the transport has been connected. connection *streamableServerConn - // the server session associated with this transport + // the server session associated with this transport. session *ServerSession } @@ -570,19 +568,6 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R } } -// Close releases resources related to this transport if it has already been connected. -func (t *StreamableServerTransport) Close() error { - var sessionErr, connErr error - if t.session != nil { - sessionErr = t.session.Close() - } - if t.connection != nil { - connErr = t.connection.Close() - } - - return errors.Join(sessionErr, connErr) -} - // serveGET streams messages to a hanging http GET, with stream ID and last // message parsed from the Last-Event-ID header. // diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 066e4b95..7036d129 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -433,7 +433,7 @@ func TestServerTransportCleanup(t *testing.T) { chans[fmt.Sprint(id)] = make(chan struct{}, 1) return fmt.Sprint(id) }, - OnConnectionClose: func(sessionID string) { + OnSessionClose: func(sessionID string) { chans[sessionID] <- struct{}{} }, }) @@ -1428,14 +1428,14 @@ func TestStreamableGET(t *testing.T) { } } -// TestStreamableHTTPHandler_OnConnectionClose_SessionDeletion tests that the -// OnConnectionClose callback is called when the client closes the session. -func TestStreamableHTTPHandler_OnConnectionClose_SessionDeletion(t *testing.T) { +// TestStreamableHTTPHandler_OnSessionClose_SessionDeletion tests that the +// OnSessionClose callback is called when the client closes the session. +func TestStreamableHTTPHandler_OnSessionClose_SessionDeletion(t *testing.T) { var closedConnections []string server := NewServer(testImpl, nil) handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ - OnConnectionClose: func(sessionID string) { + OnSessionClose: func(sessionID string) { closedConnections = append(closedConnections, sessionID) }, }) From 6788e58292755e0b50c5c8e2dba37d4342475c72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Friedrich=20Gro=C3=9Fe?= Date: Wed, 17 Sep 2025 16:43:47 +0200 Subject: [PATCH 3/6] mcp: improve comment on OnSessionClose callback option --- mcp/streamable.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 05a80f59..ee1e7a29 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -74,9 +74,9 @@ type StreamableHTTPOptions struct { // [§2.1.5]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server JSONResponse bool - // OnSessionClose is a callback function that is invoked when a [Connection] - // is closed. A connection is closed when the session is ended explicitly by - // the client or when it is interrupted due to a timeout or other errors. + // OnSessionClose is a callback function that is invoked when a [ServerSession] + // is closed. This happens when a session is ended explicitly by the MCP client + // or when it is interrupted due to a timeout or other errors. OnSessionClose func(sessionID string) } From 50b33370cdf6279ce2aaf00613a8ad2e4c6d1830 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Friedrich=20Gro=C3=9Fe?= Date: Wed, 17 Sep 2025 16:46:04 +0200 Subject: [PATCH 4/6] mcp: use "closed session" instead of "closed connection" in unit tests --- mcp/streamable_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 7036d129..5ebe4e39 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1431,12 +1431,12 @@ func TestStreamableGET(t *testing.T) { // TestStreamableHTTPHandler_OnSessionClose_SessionDeletion tests that the // OnSessionClose callback is called when the client closes the session. func TestStreamableHTTPHandler_OnSessionClose_SessionDeletion(t *testing.T) { - var closedConnections []string + var closedSessions []string server := NewServer(testImpl, nil) handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ OnSessionClose: func(sessionID string) { - closedConnections = append(closedConnections, sessionID) + closedSessions = append(closedSessions, sessionID) }, }) @@ -1458,10 +1458,10 @@ func TestStreamableHTTPHandler_OnSessionClose_SessionDeletion(t *testing.T) { t.Fatalf("session.Close() failed: %v", err) } - if len(closedConnections) != 1 { - t.Fatalf("got %d connections, want 1", len(closedConnections)) + if len(closedSessions) != 1 { + t.Fatalf("got %d closed sessions, want 1", len(closedSessions)) } - if closedConnections[0] != sessionID { - t.Fatalf("got session ID %q, want %q", closedConnections[0], sessionID) + if closedSessions[0] != sessionID { + t.Fatalf("got session ID %q, want %q", closedSessions[0], sessionID) } } From eab40f7c356fa8a50978c4b898cc716fe7db723c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Friedrich=20Gro=C3=9Fe?= Date: Thu, 18 Sep 2025 15:43:29 +0200 Subject: [PATCH 5/6] mcp: add TODO for missing error logging --- mcp/streamable.go | 1 + 1 file changed, 1 insertion(+) diff --git a/mcp/streamable.go b/mcp/streamable.go index ee1e7a29..fc5d0ffa 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -167,6 +167,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque h.mu.Lock() delete(h.transports, transport.SessionID) h.mu.Unlock() + // TODO: consider logging this error _ = transport.session.Close() } w.WriteHeader(http.StatusNoContent) From f3b5d9367bed69076e6a38880fa1d736ea08bc82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Friedrich=20Gro=C3=9Fe?= Date: Thu, 18 Sep 2025 21:15:44 +0200 Subject: [PATCH 6/6] mcp: fix failing unit test --- mcp/streamable_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 328b22ea..ed619704 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -103,7 +103,7 @@ func TestStreamableTransports(t *testing.T) { } handler.ServeHTTP(w, r) })) - defer httpServer.Close() + t.Cleanup(func() { httpServer.Close() }) // Create a client and connect it to the server using our StreamableClientTransport. // Check that all requests honor a custom client.