Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@ type StreamableHTTPHandler struct {
getServer func(*http.Request) *Server
opts StreamableHTTPOptions

onTransportDeletion func(sessionID string) // for testing only

mu sync.Mutex
// TODO: we should store the ServerSession along with the transport, because
// we need to cancel keepalive requests when closing the transport.
mu sync.Mutex
transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header)
}

Expand All @@ -72,6 +68,11 @@ type StreamableHTTPOptions struct {
// Logger specifies the logger to use.
// If nil, do not log.
Logger *slog.Logger

// OnSessionClose is a callback function that is invoked when a [ServerSession]
// is closed. This happens when a session is ended explicitly by the MCP client
// or when it is interrupted due to a timeout or other errors.
OnSessionClose func(sessionID string)
}

// NewStreamableHTTPHandler returns a new [StreamableHTTPHandler].
Expand Down Expand Up @@ -163,7 +164,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
h.mu.Lock()
delete(h.transports, transport.SessionID)
h.mu.Unlock()
transport.connection.Close()
// TODO: consider logging this error
_ = transport.session.Close()
}
w.WriteHeader(http.StatusNoContent)
return
Expand Down Expand Up @@ -297,8 +299,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
h.mu.Lock()
delete(h.transports, transport.SessionID)
h.mu.Unlock()
if h.onTransportDeletion != nil {
h.onTransportDeletion(transport.SessionID)
if h.opts.OnSessionClose != nil {
h.opts.OnSessionClose(transport.SessionID)
}
},
}
Expand All @@ -318,6 +320,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
} else {
// Otherwise, save the transport so that it can be reused
h.mu.Lock()
transport.session = ss
h.transports[transport.SessionID] = transport
h.mu.Unlock()
}
Expand Down Expand Up @@ -386,6 +389,9 @@ type StreamableServerTransport struct {

// connection is non-nil if and only if the transport has been connected.
connection *streamableServerConn

// the server session associated with this transport.
session *ServerSession
}

// Connect implements the [Transport] interface.
Expand Down
60 changes: 53 additions & 7 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func TestStreamableTransports(t *testing.T) {
}
handler.ServeHTTP(w, r)
})))
defer httpServer.Close()
t.Cleanup(func() { httpServer.Close() })

// Create a client and connect it to the server using our StreamableClientTransport.
// Check that all requests honor a custom client.
Expand All @@ -132,7 +132,12 @@ func TestStreamableTransports(t *testing.T) {
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
defer session.Close()
t.Cleanup(func() {
err := session.Close()
if err != nil {
t.Errorf("session.Close() failed: %v", err)
}
})
sid := session.ID()
if sid == "" {
t.Fatalf("empty session ID")
Expand Down Expand Up @@ -222,7 +227,7 @@ func TestStreamableServerShutdown(t *testing.T) {
httpServer := httptest.NewUnstartedServer(handler)
httpServer.Config.RegisterOnShutdown(func() {
for session := range server.Sessions() {
session.Close()
_ = session.Close()
}
})
httpServer.Start()
Expand Down Expand Up @@ -432,10 +437,13 @@ func TestServerTransportCleanup(t *testing.T) {
},
})

handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)
handler.onTransportDeletion = func(sessionID string) {
chans[sessionID] <- struct{}{}
}
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server },
&StreamableHTTPOptions{
OnSessionClose: func(sessionID string) {
chans[sessionID] <- struct{}{}
},
},
)

httpServer := httptest.NewServer(mustNotPanic(t, handler))
defer httpServer.Close()
Expand Down Expand Up @@ -1490,6 +1498,44 @@ func TestStreamableClientContextPropagation(t *testing.T) {

}

// TestStreamableHTTPHandler_OnSessionClose_SessionDeletion tests that the
// OnSessionClose callback is called when the client closes the session.
func TestStreamableHTTPHandler_OnSessionClose_SessionDeletion(t *testing.T) {
var closedSessions []string

server := NewServer(testImpl, nil)
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{
OnSessionClose: func(sessionID string) {
closedSessions = append(closedSessions, sessionID)
},
})

httpServer := httptest.NewServer(handler)
t.Cleanup(httpServer.Close)

ctx := context.Background()
client := NewClient(testImpl, nil)
transport := &StreamableClientTransport{Endpoint: httpServer.URL}
session, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}

sessionID := session.ID()
t.Log("Closing client session")
err = session.Close()
if err != nil {
t.Fatalf("session.Close() failed: %v", err)
}

if len(closedSessions) != 1 {
t.Fatalf("got %d closed sessions, want 1", len(closedSessions))
}
if closedSessions[0] != sessionID {
t.Fatalf("got session ID %q, want %q", closedSessions[0], sessionID)
}
}

// mustNotPanic is a helper to enforce that test handlers do not panic (see
// issue #556).
func mustNotPanic(t *testing.T, h http.Handler) http.Handler {
Expand Down