Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,7 @@ type StreamableHTTPHandler struct {
getServer func(*http.Request) *Server
opts StreamableHTTPOptions

onTransportDeletion func(sessionID string) // for testing only

mu sync.Mutex
// TODO: we should store the ServerSession along with the transport, because
// we need to cancel keepalive requests when closing the transport.
mu sync.Mutex
transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header)
}

Expand Down Expand Up @@ -77,6 +73,11 @@ type StreamableHTTPOptions struct {
//
// [§2.1.5]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
JSONResponse bool

// OnSessionClose is a callback function that is invoked when a [ServerSession]
// is closed. This happens when a session is ended explicitly by the MCP client
// or when it is interrupted due to a timeout or other errors.
OnSessionClose func(sessionID string)
}

// NewStreamableHTTPHandler returns a new [StreamableHTTPHandler].
Expand Down Expand Up @@ -166,7 +167,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
h.mu.Lock()
delete(h.transports, transport.SessionID)
h.mu.Unlock()
transport.connection.Close()
_ = transport.session.Close()
}
w.WriteHeader(http.StatusNoContent)
return
Expand Down Expand Up @@ -299,8 +300,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
h.mu.Lock()
delete(h.transports, transport.SessionID)
h.mu.Unlock()
if h.onTransportDeletion != nil {
h.onTransportDeletion(transport.SessionID)
if h.opts.OnSessionClose != nil {
h.opts.OnSessionClose(transport.SessionID)
}
},
}
Expand All @@ -320,6 +321,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
} else {
// Otherwise, save the transport so that it can be reused
h.mu.Lock()
transport.session = ss
h.transports[transport.SessionID] = transport
h.mu.Unlock()
}
Expand Down Expand Up @@ -382,6 +384,9 @@ type StreamableServerTransport struct {

// connection is non-nil if and only if the transport has been connected.
connection *streamableServerConn

// the server session associated with this transport.
session *ServerSession
}

// Connect implements the [Transport] interface.
Expand Down
54 changes: 48 additions & 6 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,12 @@ func TestStreamableTransports(t *testing.T) {
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
defer session.Close()
t.Cleanup(func() {
err := session.Close()
if err != nil {
t.Errorf("session.Close() failed: %v", err)
}
})
sid := session.ID()
if sid == "" {
t.Fatalf("empty session ID")
Expand Down Expand Up @@ -220,7 +225,7 @@ func TestStreamableServerShutdown(t *testing.T) {
httpServer := httptest.NewUnstartedServer(handler)
httpServer.Config.RegisterOnShutdown(func() {
for session := range server.Sessions() {
session.Close()
_ = session.Close()
}
})
httpServer.Start()
Expand Down Expand Up @@ -428,12 +433,11 @@ func TestServerTransportCleanup(t *testing.T) {
chans[fmt.Sprint(id)] = make(chan struct{}, 1)
return fmt.Sprint(id)
},
OnSessionClose: func(sessionID string) {
chans[sessionID] <- struct{}{}
},
})

handler.onTransportDeletion = func(sessionID string) {
chans[sessionID] <- struct{}{}
}

httpServer := httptest.NewServer(handler)
defer httpServer.Close()

Expand Down Expand Up @@ -1423,3 +1427,41 @@ func TestStreamableGET(t *testing.T) {
t.Errorf("GET with session ID: got status %d, want %d", got, want)
}
}

// TestStreamableHTTPHandler_OnSessionClose_SessionDeletion tests that the
// OnSessionClose callback is called when the client closes the session.
func TestStreamableHTTPHandler_OnSessionClose_SessionDeletion(t *testing.T) {
var closedSessions []string

server := NewServer(testImpl, nil)
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{
OnSessionClose: func(sessionID string) {
closedSessions = append(closedSessions, sessionID)
},
})

httpServer := httptest.NewServer(handler)
t.Cleanup(httpServer.Close)

ctx := context.Background()
client := NewClient(testImpl, nil)
transport := &StreamableClientTransport{Endpoint: httpServer.URL}
session, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}

sessionID := session.ID()
t.Log("Closing client session")
err = session.Close()
if err != nil {
t.Fatalf("session.Close() failed: %v", err)
}

if len(closedSessions) != 1 {
t.Fatalf("got %d closed sessions, want 1", len(closedSessions))
}
if closedSessions[0] != sessionID {
t.Fatalf("got session ID %q, want %q", closedSessions[0], sessionID)
}
}