Skip to content

Commit a9a503f

Browse files
mcp/streamable: add persistent SSE GET listener (#206)
This CL adds the optional persistent SSE GET listener as specified in section 2.2. https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server This enables server initiated SSE streams. For #10
1 parent 56734ed commit a9a503f

File tree

2 files changed

+57
-15
lines changed

2 files changed

+57
-15
lines changed

mcp/streamable.go

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,12 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
743743
ctx: connCtx,
744744
cancel: cancel,
745745
}
746+
// Start the persistent SSE listener right away.
747+
// Section 2.2: The client MAY issue an HTTP GET to the MCP endpoint.
748+
// This can be used to open an SSE stream, allowing the server to
749+
// communicate to the client, without the client first sending data via HTTP POST.
750+
go conn.handleSSE(nil, true)
751+
746752
return conn, nil
747753
}
748754

@@ -867,7 +873,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
867873
switch ct := resp.Header.Get("Content-Type"); ct {
868874
case "text/event-stream":
869875
// Section 2.1: The SSE stream is initiated after a POST.
870-
go s.handleSSE(resp)
876+
go s.handleSSE(resp, false)
871877
case "application/json":
872878
body, err := io.ReadAll(resp.Body)
873879
resp.Body.Close()
@@ -887,13 +893,11 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
887893
return sessionID, nil
888894
}
889895

890-
// handleSSE manages the entire lifecycle of an SSE connection. It processes
891-
// an incoming Server-Sent Events stream and automatically handles reconnection
892-
// logic if the stream breaks.
893-
func (s *streamableClientConn) handleSSE(initialResp *http.Response) {
896+
// handleSSE manages the lifecycle of an SSE connection. It can be either
897+
// persistent (for the main GET listener) or temporary (for a POST response).
898+
func (s *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool) {
894899
resp := initialResp
895900
var lastEventID string
896-
897901
for {
898902
eventID, clientClosed := s.processStream(resp)
899903
lastEventID = eventID
@@ -902,6 +906,11 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response) {
902906
if clientClosed {
903907
return
904908
}
909+
// If the stream has ended, then do not reconnect if the stream is
910+
// temporary (POST initiated SSE).
911+
if lastEventID == "" && !persistent {
912+
return
913+
}
905914

906915
// The stream was interrupted or ended by the server. Attempt to reconnect.
907916
newResp, err := s.reconnect(lastEventID)
@@ -923,9 +932,13 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response) {
923932
// processStream reads from a single response body, sending events to the
924933
// incoming channel. It returns the ID of the last processed event, any error
925934
// that occurred, and a flag indicating if the connection was closed by the client.
935+
// If resp is nil, it returns "", false.
926936
func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) {
927-
defer resp.Body.Close()
937+
if resp == nil {
938+
return "", false
939+
}
928940

941+
defer resp.Body.Close()
929942
for evt, err := range scanEvents(resp.Body) {
930943
if err != nil {
931944
return lastEventID, false
@@ -939,13 +952,11 @@ func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID s
939952
case s.incoming <- evt.Data:
940953
case <-s.done:
941954
// The connection was closed by the client; exit gracefully.
942-
return lastEventID, true
955+
return "", true
943956
}
944957
}
945-
946958
// The loop finished without an error, indicating the server closed the stream.
947-
// We'll attempt to reconnect, so this is not a client-side close.
948-
return lastEventID, false
959+
return "", false
949960
}
950961

951962
// reconnect handles the logic of retrying a connection with an exponential

mcp/streamable_test.go

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ func TestClientReplay(t *testing.T) {
168168
clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"})
169169

170170
// 4. Read and verify messages until the server signals it's ready for the proxy kill.
171-
receivedNotifications := readProgressNotifications(t, ctx, notifications, 2)
171+
receivedNotifications := readNotifications(t, ctx, notifications, 2)
172172

173173
wantReceived := []string{"msg1", "msg2"}
174174
if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" {
@@ -201,7 +201,7 @@ func TestClientReplay(t *testing.T) {
201201
// 7. Continue reading from the same connection object.
202202
// Its internal logic should successfully retry, reconnect to the new proxy,
203203
// and receive the replayed messages.
204-
recoveredNotifications := readProgressNotifications(t, ctx, notifications, 2)
204+
recoveredNotifications := readNotifications(t, ctx, notifications, 2)
205205

206206
// 8. Verify the correct messages were received on the recovered connection.
207207
wantRecovered := []string{"msg3", "msg4"}
@@ -211,8 +211,39 @@ func TestClientReplay(t *testing.T) {
211211
}
212212
}
213213

214-
// Helper to read a specific number of progress notifications.
215-
func readProgressNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string {
214+
// TestServerInitiatedSSE verifies that the persistent SSE connection remains
215+
// open and can receive server-initiated events.
216+
func TestServerInitiatedSSE(t *testing.T) {
217+
notifications := make(chan string)
218+
server := NewServer(testImpl, nil)
219+
220+
httpServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))
221+
defer httpServer.Close()
222+
223+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
224+
defer cancel()
225+
client := NewClient(testImpl, &ClientOptions{ToolListChangedHandler: func(ctx context.Context, cc *ClientSession, params *ToolListChangedParams) {
226+
notifications <- "toolListChanged"
227+
},
228+
})
229+
clientSession, err := client.Connect(ctx, NewStreamableClientTransport(httpServer.URL, nil))
230+
if err != nil {
231+
t.Fatalf("client.Connect() failed: %v", err)
232+
}
233+
defer clientSession.Close()
234+
server.AddTool(&Tool{Name: "testTool", InputSchema: &jsonschema.Schema{}},
235+
func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) {
236+
return &CallToolResult{}, nil
237+
})
238+
receivedNotifications := readNotifications(t, ctx, notifications, 1)
239+
wantReceived := []string{"toolListChanged"}
240+
if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" {
241+
t.Errorf("Received notifications mismatch (-want +got):\n%s", diff)
242+
}
243+
}
244+
245+
// Helper to read a specific number of notifications.
246+
func readNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string {
216247
t.Helper()
217248
var collectedNotifications []string
218249
for {

0 commit comments

Comments
 (0)