Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,12 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
ctx: connCtx,
cancel: cancel,
}
// Start the persistent SSE listener right away.
// Section 2.2: The client MAY issue an HTTP GET to the MCP endpoint.
// This can be used to open an SSE stream, allowing the server to
// communicate to the client, without the client first sending data via HTTP POST.
go conn.handleSSE(nil, true)

return conn, nil
}

Expand Down Expand Up @@ -859,7 +865,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
switch ct := resp.Header.Get("Content-Type"); ct {
case "text/event-stream":
// Section 2.1: The SSE stream is initiated after a POST.
go s.handleSSE(resp)
go s.handleSSE(resp, false)
case "application/json":
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
Expand All @@ -879,13 +885,11 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
return sessionID, nil
}

// handleSSE manages the entire lifecycle of an SSE connection. It processes
// an incoming Server-Sent Events stream and automatically handles reconnection
// logic if the stream breaks.
func (s *streamableClientConn) handleSSE(initialResp *http.Response) {
// handleSSE manages the lifecycle of an SSE connection. It can be either
// persistent (for the main GET listener) or temporary (for a POST response).
func (s *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool) {
resp := initialResp
var lastEventID string

for {
eventID, clientClosed := s.processStream(resp)
lastEventID = eventID
Expand All @@ -894,6 +898,11 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response) {
if clientClosed {
return
}
// If the stream has ended, then do not reconnect if the stream is
// temporary (POST initiated SSE).
if lastEventID == "" && !persistent {
return
}

// The stream was interrupted or ended by the server. Attempt to reconnect.
newResp, err := s.reconnect(lastEventID)
Expand All @@ -915,9 +924,13 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response) {
// processStream reads from a single response body, sending events to the
// incoming channel. It returns the ID of the last processed event, any error
// that occurred, and a flag indicating if the connection was closed by the client.
// If resp is nil, it returns "", false.
func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) {
defer resp.Body.Close()
if resp == nil {
return "", false
}

defer resp.Body.Close()
for evt, err := range scanEvents(resp.Body) {
if err != nil {
return lastEventID, false
Expand All @@ -931,13 +944,11 @@ func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID s
case s.incoming <- evt.Data:
case <-s.done:
// The connection was closed by the client; exit gracefully.
return lastEventID, true
return "", true
}
}

// The loop finished without an error, indicating the server closed the stream.
// We'll attempt to reconnect, so this is not a client-side close.
return lastEventID, false
return "", false
}

// reconnect handles the logic of retrying a connection with an exponential
Expand Down
39 changes: 35 additions & 4 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func TestClientReplay(t *testing.T) {
clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"})

// 4. Read and verify messages until the server signals it's ready for the proxy kill.
receivedNotifications := readProgressNotifications(t, ctx, notifications, 2)
receivedNotifications := readNotifications(t, ctx, notifications, 2)

wantReceived := []string{"msg1", "msg2"}
if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" {
Expand Down Expand Up @@ -201,7 +201,7 @@ func TestClientReplay(t *testing.T) {
// 7. Continue reading from the same connection object.
// Its internal logic should successfully retry, reconnect to the new proxy,
// and receive the replayed messages.
recoveredNotifications := readProgressNotifications(t, ctx, notifications, 2)
recoveredNotifications := readNotifications(t, ctx, notifications, 2)

// 8. Verify the correct messages were received on the recovered connection.
wantRecovered := []string{"msg3", "msg4"}
Expand All @@ -211,8 +211,39 @@ func TestClientReplay(t *testing.T) {
}
}

// Helper to read a specific number of progress notifications.
func readProgressNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string {
// TestServerInitiatedSSE verifies that the persistent SSE connection remains
// open and can receive server-initiated events.
func TestServerInitiatedSSE(t *testing.T) {
notifications := make(chan string)
server := NewServer(testImpl, nil)

httpServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))
defer httpServer.Close()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
client := NewClient(testImpl, &ClientOptions{ToolListChangedHandler: func(ctx context.Context, cc *ClientSession, params *ToolListChangedParams) {
notifications <- "toolListChanged"
},
})
clientSession, err := client.Connect(ctx, NewStreamableClientTransport(httpServer.URL, nil))
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
defer clientSession.Close()
server.AddTool(&Tool{Name: "testTool", InputSchema: &jsonschema.Schema{}},
func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) {
return &CallToolResult{}, nil
})
receivedNotifications := readNotifications(t, ctx, notifications, 1)
wantReceived := []string{"toolListChanged"}
if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" {
t.Errorf("Received notifications mismatch (-want +got):\n%s", diff)
}
}

// Helper to read a specific number of notifications.
func readNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string {
t.Helper()
var collectedNotifications []string
for {
Expand Down
Loading