Skip to content

Commit 1dd9c32

Browse files
committed
mcp/streamable: add persistent SSE GET listener
This CL adds the optional persistent SSE GET listener as specified in section 2.2. https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server This enables server initiated SSE streams.
1 parent 3ac4ca9 commit 1dd9c32

File tree

2 files changed

+58
-15
lines changed

2 files changed

+58
-15
lines changed

mcp/streamable.go

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,12 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
735735
ctx: connCtx,
736736
cancel: cancel,
737737
}
738+
// Start the persistent SSE listener right away.
739+
// Section 2.2: The client MAY issue an HTTP GET to the MCP endpoint.
740+
// This can be used to open an SSE stream, allowing the server to
741+
// communicate to the client, without the client first sending data via HTTP POST.
742+
go conn.handleSSE(nil, false)
743+
738744
return conn, nil
739745
}
740746

@@ -859,7 +865,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
859865
switch ct := resp.Header.Get("Content-Type"); ct {
860866
case "text/event-stream":
861867
// Section 2.1: The SSE stream is initiated after a POST.
862-
go s.handleSSE(resp)
868+
go s.handleSSE(resp, true)
863869
case "application/json":
864870
body, err := io.ReadAll(resp.Body)
865871
resp.Body.Close()
@@ -879,13 +885,11 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
879885
return sessionID, nil
880886
}
881887

882-
// handleSSE manages the entire lifecycle of an SSE connection. It processes
883-
// an incoming Server-Sent Events stream and automatically handles reconnection
884-
// logic if the stream breaks.
885-
func (s *streamableClientConn) handleSSE(initialResp *http.Response) {
888+
// handleSSE manages the lifecycle of an SSE connection. It can be either
889+
// temporary (for a POST response) or persistent (for the main GET listener).
890+
func (s *streamableClientConn) handleSSE(initialResp *http.Response, temporary bool) {
886891
resp := initialResp
887892
var lastEventID string
888-
889893
for {
890894
eventID, clientClosed := s.processStream(resp)
891895
lastEventID = eventID
@@ -894,6 +898,11 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response) {
894898
if clientClosed {
895899
return
896900
}
901+
// If the stream has ended, then do not reconnect if the stream is
902+
// temporary (POST initiated SSE).
903+
if lastEventID == "" && temporary {
904+
return
905+
}
897906

898907
// The stream was interrupted or ended by the server. Attempt to reconnect.
899908
newResp, err := s.reconnect(lastEventID)
@@ -915,9 +924,13 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response) {
915924
// processStream reads from a single response body, sending events to the
916925
// incoming channel. It returns the ID of the last processed event, any error
917926
// that occurred, and a flag indicating if the connection was closed by the client.
927+
// If resp is nil, it returns "", false.
918928
func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) {
919-
defer resp.Body.Close()
929+
if resp == nil {
930+
return "", false
931+
}
920932

933+
defer resp.Body.Close()
921934
for evt, err := range scanEvents(resp.Body) {
922935
if err != nil {
923936
return lastEventID, false
@@ -931,13 +944,11 @@ func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID s
931944
case s.incoming <- evt.Data:
932945
case <-s.done:
933946
// The connection was closed by the client; exit gracefully.
934-
return lastEventID, true
947+
return "", true
935948
}
936949
}
937-
938950
// The loop finished without an error, indicating the server closed the stream.
939-
// We'll attempt to reconnect, so this is not a client-side close.
940-
return lastEventID, false
951+
return "", false
941952
}
942953

943954
// reconnect handles the logic of retrying a connection with an exponential

mcp/streamable_test.go

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ func TestClientReplay(t *testing.T) {
168168
clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"})
169169

170170
// 4. Read and verify messages until the server signals it's ready for the proxy kill.
171-
receivedNotifications := readProgressNotifications(t, ctx, notifications, 2)
171+
receivedNotifications := readNotifications(t, ctx, notifications, 2)
172172

173173
wantReceived := []string{"msg1", "msg2"}
174174
if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" {
@@ -201,7 +201,7 @@ func TestClientReplay(t *testing.T) {
201201
// 7. Continue reading from the same connection object.
202202
// Its internal logic should successfully retry, reconnect to the new proxy,
203203
// and receive the replayed messages.
204-
recoveredNotifications := readProgressNotifications(t, ctx, notifications, 2)
204+
recoveredNotifications := readNotifications(t, ctx, notifications, 2)
205205

206206
// 8. Verify the correct messages were received on the recovered connection.
207207
wantRecovered := []string{"msg3", "msg4"}
@@ -211,8 +211,40 @@ func TestClientReplay(t *testing.T) {
211211
}
212212
}
213213

214-
// Helper to read a specific number of progress notifications.
215-
func readProgressNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string {
214+
// TestServerInitiatedSSE verifies that the persistent SSE connection remains
215+
// open and can receive multiple, non-consecutive, server-initiated events.
216+
func TestServerInitiatedSSE(t *testing.T) {
217+
notifications := make(chan string)
218+
server := NewServer(testImpl, &ServerOptions{})
219+
220+
httpServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))
221+
defer httpServer.Close()
222+
223+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
224+
defer cancel()
225+
client := NewClient(testImpl, &ClientOptions{ToolListChangedHandler: func(ctx context.Context, cc *ClientSession, params *ToolListChangedParams) {
226+
notifications <- "toolListChanged"
227+
},
228+
})
229+
clientSession, err := client.Connect(ctx, NewStreamableClientTransport(httpServer.URL, nil))
230+
if err != nil {
231+
t.Fatalf("client.Connect() failed: %v", err)
232+
}
233+
defer clientSession.Close()
234+
time.Sleep(50 * time.Millisecond)
235+
server.AddTool(&Tool{Name: "testTool", InputSchema: &jsonschema.Schema{}},
236+
func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) {
237+
return &CallToolResult{}, nil
238+
})
239+
receivedNotifications := readNotifications(t, ctx, notifications, 1)
240+
wantReceived := []string{"toolListChanged"}
241+
if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" {
242+
t.Errorf("Received notifications mismatch (-want +got):\n%s", diff)
243+
}
244+
}
245+
246+
// Helper to read a specific number of notifications.
247+
func readNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string {
216248
t.Helper()
217249
var collectedNotifications []string
218250
for {

0 commit comments

Comments
 (0)