Skip to content

Commit 6bea1d4

Browse files
authored
fix: add session idle TTL sweeper to prevent transport state leak (#724)
* fix: add session idle TTL sweeper to prevent transport state leak When clients disconnect without sending DELETE, per-session transport state (tools, resources, resource templates, log levels, request IDs) is never cleaned up, causing a memory leak. Add WithSessionIdleTTL option that enables a background sweeper to periodically remove stale entries from all per-session stores. Extract cleanupSessionState from handleDelete for shared use. * fix: call SessionIdManager.Terminate in sweeper to invalidate swept sessions Without this, stateful session ID managers still consider swept sessions as valid — clients can keep sending requests with the old session ID. Also fix two test issues flagged by CodeRabbit: - remove vacuous sessionRequestIDs assertion (never populated in POST-only test) - extend ping duration in "active sessions are not swept" so the sweeper actually ticks during the test * fix: avoid passing nil request to SessionIdManagerResolver in sweeper * test: add non-vacuous sessionTools assertion to sweeper test
1 parent 0510f0c commit 6bea1d4

File tree

2 files changed

+271
-10
lines changed

2 files changed

+271
-10
lines changed

server/streamable_http.go

Lines changed: 130 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,19 @@ func WithTLSCert(certFile, keyFile string) StreamableHTTPOption {
144144
}
145145
}
146146

147+
// WithSessionIdleTTL sets the idle TTL for per-session transport state.
148+
// When enabled, a background sweeper periodically removes entries from
149+
// per-session stores (tools, resources, resource templates, log levels,
150+
// request IDs) for sessions that have been idle longer than the given
151+
// duration. This prevents memory leaks when clients disconnect without
152+
// sending a DELETE request. A zero or negative value disables the sweeper
153+
// (the default).
154+
func WithSessionIdleTTL(ttl time.Duration) StreamableHTTPOption {
155+
return func(s *StreamableHTTPServer) {
156+
s.sessionIdleTTL = ttl
157+
}
158+
}
159+
147160
// StreamableHTTPServer implements a Streamable-http based MCP server.
148161
// It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams.
149162
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http
@@ -181,13 +194,18 @@ type StreamableHTTPServer struct {
181194
endpointPath string
182195
contextFunc HTTPContextFunc
183196
sessionIdManagerResolver SessionIdManagerResolver
197+
sessionIdManager SessionIdManager // for non-request contexts (sweeper)
184198
listenHeartbeatInterval time.Duration
185199
logger util.Logger
186200
sessionLogLevels *sessionLogLevelsStore
187201
disableStreaming bool
188202

189203
tlsCertFile string
190204
tlsKeyFile string
205+
206+
sessionIdleTTL time.Duration
207+
sessionLastActive sync.Map // sessionID → *atomic.Int64 (unix nanos)
208+
sweeperCancel context.CancelFunc
191209
}
192210

193211
// NewStreamableHTTPServer creates a new streamable-http server instance
@@ -207,6 +225,20 @@ func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *S
207225
for _, opt := range opts {
208226
opt(s)
209227
}
228+
229+
// Cache the session ID manager for use in non-request contexts (sweeper).
230+
// DefaultSessionIdManagerResolver always returns the same manager,
231+
// so resolving it once at startup is semantically identical.
232+
if r, ok := s.sessionIdManagerResolver.(*DefaultSessionIdManagerResolver); ok {
233+
s.sessionIdManager = r.manager
234+
}
235+
236+
if s.sessionIdleTTL > 0 {
237+
ctx, cancel := context.WithCancel(context.Background())
238+
s.sweeperCancel = cancel
239+
s.startSessionSweeper(ctx)
240+
}
241+
210242
return s
211243
}
212244

@@ -266,6 +298,9 @@ func (s *StreamableHTTPServer) Start(addr string) error {
266298
// Shutdown gracefully stops the server, closing all active sessions
267299
// and shutting down the HTTP server.
268300
func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error {
301+
if s.sweeperCancel != nil {
302+
s.sweeperCancel()
303+
}
269304

270305
// shutdown the server if needed (may use as a http.Handler)
271306
s.mu.RLock()
@@ -354,6 +389,8 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
354389
}
355390
}
356391

392+
s.touchSession(sessionID)
393+
357394
// For non-initialize requests, try to reuse existing registered session
358395
var session *streamableHttpSession
359396
if !isInitializeRequest {
@@ -557,6 +594,8 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
557594
defer s.activeSessions.Delete(sessionID)
558595
}
559596

597+
s.touchSession(sessionID)
598+
560599
// Set the client context before handling the message
561600
w.Header().Set("Content-Type", "text/event-stream")
562601
w.Header().Set("Cache-Control", "no-cache")
@@ -671,6 +710,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
671710
return
672711
}
673712
flusher.Flush()
713+
s.touchSession(sessionID)
674714
case <-r.Context().Done():
675715
return
676716
}
@@ -691,15 +731,7 @@ func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Reque
691731
return
692732
}
693733

694-
// remove the session relateddata from the sessionToolsStore
695-
s.sessionTools.delete(sessionID)
696-
s.sessionResources.delete(sessionID)
697-
s.sessionResourceTemplates.delete(sessionID)
698-
s.sessionLogLevels.delete(sessionID)
699-
// remove current session's requstID information
700-
s.sessionRequestIDs.Delete(sessionID)
701-
s.activeSessions.Delete(sessionID)
702-
s.server.UnregisterSession(r.Context(), sessionID)
734+
s.cleanupSessionState(r.Context(), sessionID)
703735

704736
w.WriteHeader(http.StatusOK)
705737
}
@@ -843,6 +875,92 @@ func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 {
843875
return counter.Add(1)
844876
}
845877

878+
// touchSession records the current time as the last activity for the given session.
879+
// It is a no-op when the sweeper is disabled (sessionIdleTTL <= 0) or sessionID is empty.
880+
func (s *StreamableHTTPServer) touchSession(sessionID string) {
881+
if sessionID == "" || s.sessionIdleTTL <= 0 {
882+
return
883+
}
884+
now := time.Now().UnixNano()
885+
actual, _ := s.sessionLastActive.LoadOrStore(sessionID, new(atomic.Int64))
886+
actual.(*atomic.Int64).Store(now)
887+
}
888+
889+
// cleanupSessionState removes all per-session transport state for the given session ID.
890+
func (s *StreamableHTTPServer) cleanupSessionState(ctx context.Context, sessionID string) {
891+
// Unregister first to stop notification routing before deleting data.
892+
s.server.UnregisterSession(ctx, sessionID)
893+
s.activeSessions.Delete(sessionID)
894+
s.sessionTools.delete(sessionID)
895+
s.sessionResources.delete(sessionID)
896+
s.sessionResourceTemplates.delete(sessionID)
897+
s.sessionLogLevels.delete(sessionID)
898+
s.sessionRequestIDs.Delete(sessionID)
899+
s.sessionLastActive.Delete(sessionID)
900+
}
901+
902+
// startSessionSweeper launches a background goroutine that periodically removes
903+
// transport state for sessions that have been idle longer than sessionIdleTTL.
904+
func (s *StreamableHTTPServer) startSessionSweeper(ctx context.Context) {
905+
interval := max(s.sessionIdleTTL/2, time.Second)
906+
907+
go func() {
908+
ticker := time.NewTicker(interval)
909+
defer ticker.Stop()
910+
911+
for {
912+
select {
913+
case <-ctx.Done():
914+
return
915+
case <-ticker.C:
916+
s.sweepExpiredSessions()
917+
}
918+
}
919+
}()
920+
}
921+
922+
// sweepExpiredSessions iterates all tracked sessions and cleans up those
923+
// whose last activity exceeds sessionIdleTTL.
924+
func (s *StreamableHTTPServer) sweepExpiredSessions() {
925+
now := time.Now().UnixNano()
926+
ttlNanos := s.sessionIdleTTL.Nanoseconds()
927+
928+
s.sessionLastActive.Range(func(key, value any) bool {
929+
sessionID, ok := key.(string)
930+
if !ok {
931+
s.sessionLastActive.Delete(key)
932+
return true
933+
}
934+
lastActive, ok := value.(*atomic.Int64)
935+
if !ok {
936+
s.sessionLastActive.Delete(key)
937+
return true
938+
}
939+
940+
capturedLastActive := lastActive.Load()
941+
if now-capturedLastActive < ttlNanos {
942+
return true
943+
}
944+
945+
// Re-check: if lastActive changed since we read it, the session
946+
// was touched concurrently — skip it. A small TOCTOU window
947+
// remains between this check and cleanup, but it is acceptable
948+
// for a distributed best-effort sweeper.
949+
if lastActive.Load() != capturedLastActive {
950+
return true
951+
}
952+
953+
s.logger.Infof("Sweeping expired session: %s", sessionID)
954+
mgr := s.sessionIdManager
955+
if mgr == nil {
956+
mgr = s.sessionIdManagerResolver.ResolveSessionIdManager(nil)
957+
}
958+
_, _ = mgr.Terminate(sessionID)
959+
s.cleanupSessionState(context.Background(), sessionID)
960+
return true
961+
})
962+
}
963+
846964
// --- session ---
847965
type sessionLogLevelsStore struct {
848966
mu sync.RWMutex
@@ -1286,7 +1404,9 @@ var _ SessionWithRoots = (*streamableHttpSession)(nil)
12861404

12871405
// --- session id manager ---
12881406

1289-
// SessionIdManagerResolver resolves a SessionIdManager based on the HTTP request
1407+
// SessionIdManagerResolver resolves a SessionIdManager based on the HTTP request.
1408+
// Implementations must handle a nil r, which may be passed from non-request
1409+
// contexts such as the session idle TTL sweeper.
12901410
type SessionIdManagerResolver interface {
12911411
ResolveSessionIdManager(r *http.Request) SessionIdManager
12921412
}

server/streamable_http_test.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2612,3 +2612,144 @@ func TestStreamableHTTP_DrainNotifications(t *testing.T) {
26122612
_ = drainLoopCalled
26132613
})
26142614
}
2615+
2616+
func TestStreamableHTTP_SessionIdleTTLSweeper(t *testing.T) {
2617+
t.Run("expired sessions are swept", func(t *testing.T) {
2618+
mcpServer := NewMCPServer("test", "1.0")
2619+
httpServer := NewStreamableHTTPServer(mcpServer,
2620+
WithStateful(true),
2621+
WithSessionIdleTTL(100*time.Millisecond),
2622+
)
2623+
defer func() { _ = httpServer.Shutdown(context.Background()) }()
2624+
ts := httptest.NewServer(httpServer)
2625+
defer ts.Close()
2626+
2627+
// Initialize a session
2628+
resp, err := postJSON(ts.URL, initRequest)
2629+
require.NoError(t, err)
2630+
defer resp.Body.Close()
2631+
require.Equal(t, http.StatusOK, resp.StatusCode)
2632+
sessionID := resp.Header.Get(HeaderKeySessionID)
2633+
require.NotEmpty(t, sessionID)
2634+
2635+
// Verify the session has transport state
2636+
_, hasActivity := httpServer.sessionLastActive.Load(sessionID)
2637+
assert.True(t, hasActivity, "session should be tracked after initialize")
2638+
2639+
// Populate sessionTools so we can verify cleanup is not vacuously true
2640+
httpServer.sessionTools.set(sessionID, map[string]ServerTool{"test": {}})
2641+
2642+
// Wait for the sweeper to clean the expired session
2643+
assert.Eventually(t, func() bool {
2644+
_, exists := httpServer.sessionLastActive.Load(sessionID)
2645+
return !exists
2646+
}, 2*time.Second, 50*time.Millisecond, "session should be swept after idle TTL")
2647+
2648+
// Verify per-session transport state is cleaned
2649+
_, hasActiveSession := httpServer.activeSessions.Load(sessionID)
2650+
assert.False(t, hasActiveSession, "activeSessions should be cleaned")
2651+
tools := httpServer.sessionTools.get(sessionID)
2652+
assert.Empty(t, tools, "sessionTools should be cleaned")
2653+
})
2654+
2655+
t.Run("active sessions are not swept", func(t *testing.T) {
2656+
mcpServer := NewMCPServer("test", "1.0")
2657+
httpServer := NewStreamableHTTPServer(mcpServer,
2658+
WithStateful(true),
2659+
WithSessionIdleTTL(200*time.Millisecond),
2660+
)
2661+
defer func() { _ = httpServer.Shutdown(context.Background()) }()
2662+
ts := httptest.NewServer(httpServer)
2663+
defer ts.Close()
2664+
2665+
// Initialize a session
2666+
resp, err := postJSON(ts.URL, initRequest)
2667+
require.NoError(t, err)
2668+
defer resp.Body.Close()
2669+
require.Equal(t, http.StatusOK, resp.StatusCode)
2670+
sessionID := resp.Header.Get(HeaderKeySessionID)
2671+
require.NotEmpty(t, sessionID)
2672+
2673+
// TTL=200ms → sweep interval = max(100ms, 1s) = 1s.
2674+
// Ping for 2s so at least one sweep fires while the session is still active.
2675+
deadline := time.Now().Add(2 * time.Second)
2676+
for time.Now().Before(deadline) {
2677+
time.Sleep(80 * time.Millisecond)
2678+
pingResp, pingErr := postSessionJSON(ts.URL, sessionID, map[string]any{
2679+
"jsonrpc": "2.0",
2680+
"id": 1,
2681+
"method": "ping",
2682+
"params": map[string]any{},
2683+
})
2684+
require.NoError(t, pingErr)
2685+
pingResp.Body.Close()
2686+
}
2687+
2688+
// Session should still be tracked — the sweeper fired but the session
2689+
// was touched within TTL each time.
2690+
_, hasActivity := httpServer.sessionLastActive.Load(sessionID)
2691+
assert.True(t, hasActivity, "active session should not be swept")
2692+
})
2693+
2694+
t.Run("disabled sweeper does not clean sessions", func(t *testing.T) {
2695+
mcpServer := NewMCPServer("test", "1.0")
2696+
httpServer := NewStreamableHTTPServer(mcpServer,
2697+
WithStateful(true),
2698+
// No WithSessionIdleTTL — sweeper disabled
2699+
)
2700+
defer func() { _ = httpServer.Shutdown(context.Background()) }()
2701+
ts := httptest.NewServer(httpServer)
2702+
defer ts.Close()
2703+
2704+
// Initialize a session
2705+
resp, err := postJSON(ts.URL, initRequest)
2706+
require.NoError(t, err)
2707+
defer resp.Body.Close()
2708+
require.Equal(t, http.StatusOK, resp.StatusCode)
2709+
sessionID := resp.Header.Get(HeaderKeySessionID)
2710+
require.NotEmpty(t, sessionID)
2711+
2712+
// sessionLastActive should NOT be populated when sweeper is disabled
2713+
_, hasActivity := httpServer.sessionLastActive.Load(sessionID)
2714+
assert.False(t, hasActivity, "touchSession should be a no-op when sweeper is disabled")
2715+
2716+
// Session should still be in activeSessions (not cleaned)
2717+
time.Sleep(200 * time.Millisecond)
2718+
_, hasActiveSession := httpServer.activeSessions.Load(sessionID)
2719+
assert.True(t, hasActiveSession, "session should remain in activeSessions")
2720+
})
2721+
2722+
t.Run("DELETE still cleans session when sweeper is enabled", func(t *testing.T) {
2723+
mcpServer := NewMCPServer("test", "1.0")
2724+
httpServer := NewStreamableHTTPServer(mcpServer,
2725+
WithStateful(true),
2726+
WithSessionIdleTTL(10*time.Second), // long TTL so sweeper won't fire
2727+
)
2728+
defer func() { _ = httpServer.Shutdown(context.Background()) }()
2729+
ts := httptest.NewServer(httpServer)
2730+
defer ts.Close()
2731+
2732+
// Initialize a session
2733+
resp, err := postJSON(ts.URL, initRequest)
2734+
require.NoError(t, err)
2735+
defer resp.Body.Close()
2736+
require.Equal(t, http.StatusOK, resp.StatusCode)
2737+
sessionID := resp.Header.Get(HeaderKeySessionID)
2738+
require.NotEmpty(t, sessionID)
2739+
2740+
// Send DELETE
2741+
req, err := http.NewRequest(http.MethodDelete, ts.URL, nil)
2742+
require.NoError(t, err)
2743+
req.Header.Set(HeaderKeySessionID, sessionID)
2744+
delResp, err := http.DefaultClient.Do(req)
2745+
require.NoError(t, err)
2746+
delResp.Body.Close()
2747+
require.Equal(t, http.StatusOK, delResp.StatusCode)
2748+
2749+
// Verify all state is cleaned immediately
2750+
_, hasActivity := httpServer.sessionLastActive.Load(sessionID)
2751+
assert.False(t, hasActivity, "sessionLastActive should be cleaned after DELETE")
2752+
_, hasActiveSession := httpServer.activeSessions.Load(sessionID)
2753+
assert.False(t, hasActiveSession, "activeSessions should be cleaned after DELETE")
2754+
})
2755+
}

0 commit comments

Comments
 (0)