diff --git a/mcp/streamable.go b/mcp/streamable.go index 20eb13d5..8e83c77d 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -41,11 +41,7 @@ type StreamableHTTPHandler struct { getServer func(*http.Request) *Server opts StreamableHTTPOptions - onTransportDeletion func(sessionID string) // for testing only - - mu sync.Mutex - // TODO: we should store the ServerSession along with the transport, because - // we need to cancel keepalive requests when closing the transport. + mu sync.Mutex transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) } @@ -72,6 +68,11 @@ type StreamableHTTPOptions struct { // Logger specifies the logger to use. // If nil, do not log. Logger *slog.Logger + + // OnSessionClose is a callback function that is invoked when a [ServerSession] + // is closed. This happens when a session is ended explicitly by the MCP client + // or when it is interrupted due to a timeout or other errors. + OnSessionClose func(sessionID string) } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -163,7 +164,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque h.mu.Lock() delete(h.transports, transport.SessionID) h.mu.Unlock() - transport.connection.Close() + // TODO: consider logging this error + _ = transport.session.Close() } w.WriteHeader(http.StatusNoContent) return @@ -297,8 +299,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque h.mu.Lock() delete(h.transports, transport.SessionID) h.mu.Unlock() - if h.onTransportDeletion != nil { - h.onTransportDeletion(transport.SessionID) + if h.opts.OnSessionClose != nil { + h.opts.OnSessionClose(transport.SessionID) } }, } @@ -318,6 +320,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } else { // Otherwise, save the transport so that it can be reused h.mu.Lock() + transport.session = ss h.transports[transport.SessionID] = transport h.mu.Unlock() } @@ -386,6 +389,9 @@ type StreamableServerTransport struct { // connection is non-nil if and only if the transport has been connected. connection *streamableServerConn + + // the server session associated with this transport. + session *ServerSession } // Connect implements the [Transport] interface. diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 6ccaebf7..853488d4 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -105,7 +105,7 @@ func TestStreamableTransports(t *testing.T) { } handler.ServeHTTP(w, r) }))) - defer httpServer.Close() + t.Cleanup(func() { httpServer.Close() }) // Create a client and connect it to the server using our StreamableClientTransport. // Check that all requests honor a custom client. @@ -132,7 +132,12 @@ func TestStreamableTransports(t *testing.T) { if err != nil { t.Fatalf("client.Connect() failed: %v", err) } - defer session.Close() + t.Cleanup(func() { + err := session.Close() + if err != nil { + t.Errorf("session.Close() failed: %v", err) + } + }) sid := session.ID() if sid == "" { t.Fatalf("empty session ID") @@ -222,7 +227,7 @@ func TestStreamableServerShutdown(t *testing.T) { httpServer := httptest.NewUnstartedServer(handler) httpServer.Config.RegisterOnShutdown(func() { for session := range server.Sessions() { - session.Close() + _ = session.Close() } }) httpServer.Start() @@ -432,10 +437,13 @@ func TestServerTransportCleanup(t *testing.T) { }, }) - handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil) - handler.onTransportDeletion = func(sessionID string) { - chans[sessionID] <- struct{}{} - } + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, + &StreamableHTTPOptions{ + OnSessionClose: func(sessionID string) { + chans[sessionID] <- struct{}{} + }, + }, + ) httpServer := httptest.NewServer(mustNotPanic(t, handler)) defer httpServer.Close() @@ -1490,6 +1498,44 @@ func TestStreamableClientContextPropagation(t *testing.T) { } +// TestStreamableHTTPHandler_OnSessionClose_SessionDeletion tests that the +// OnSessionClose callback is called when the client closes the session. +func TestStreamableHTTPHandler_OnSessionClose_SessionDeletion(t *testing.T) { + var closedSessions []string + + server := NewServer(testImpl, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + OnSessionClose: func(sessionID string) { + closedSessions = append(closedSessions, sessionID) + }, + }) + + httpServer := httptest.NewServer(handler) + t.Cleanup(httpServer.Close) + + ctx := context.Background() + client := NewClient(testImpl, nil) + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + + sessionID := session.ID() + t.Log("Closing client session") + err = session.Close() + if err != nil { + t.Fatalf("session.Close() failed: %v", err) + } + + if len(closedSessions) != 1 { + t.Fatalf("got %d closed sessions, want 1", len(closedSessions)) + } + if closedSessions[0] != sessionID { + t.Fatalf("got session ID %q, want %q", closedSessions[0], sessionID) + } +} + // mustNotPanic is a helper to enforce that test handlers do not panic (see // issue #556). func mustNotPanic(t *testing.T, h http.Handler) http.Handler {