Skip to content

Commit 2541e8b

Browse files
committed
mcp: establish the streamable client standalone SSE stream in Connect
When Connect returns, client should be guaranteed that the streamable SSE stream is connected. Fixes #583
1 parent 1a907bc commit 2541e8b

File tree

3 files changed

+111
-134
lines changed

3 files changed

+111
-134
lines changed

mcp/streamable.go

Lines changed: 100 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,7 +1346,44 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
13461346
// § 2.5: A server using the Streamable HTTP transport MAY assign a session
13471347
// ID at initialization time, by including it in an Mcp-Session-Id header
13481348
// on the HTTP response containing the InitializeResult.
1349-
go c.handleSSE("standalone SSE stream", nil, true, nil)
1349+
c.connectStandaloneSSE()
1350+
}
1351+
1352+
func (c *streamableClientConn) connectStandaloneSSE() {
1353+
resp, err := c.connectSSE("")
1354+
if err != nil {
1355+
c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err))
1356+
return
1357+
}
1358+
1359+
// [§2.2.3]: "The server MUST either return Content-Type:
1360+
// text/event-stream in response to this HTTP GET, or else return HTTP
1361+
// 405 Method Not Allowed, indicating that the server does not offer an
1362+
// SSE stream at this endpoint."
1363+
//
1364+
// [§2.2.3]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server
1365+
if resp.StatusCode == http.StatusMethodNotAllowed {
1366+
// The server doesn't support the standalone SSE stream.
1367+
resp.Body.Close()
1368+
return
1369+
}
1370+
if resp.StatusCode == http.StatusNotFound && !c.strict {
1371+
// modelcontextprotocol/gosdk#393: some servers return NotFound instead
1372+
// of MethodNotAllowed for the standalone SSE stream.
1373+
//
1374+
// Treat this like MethodNotAllowed in non-strict mode.
1375+
if c.logger != nil {
1376+
c.logger.Warn("got 404 instead of 405 for standalone SSE stream")
1377+
}
1378+
resp.Body.Close()
1379+
return
1380+
}
1381+
summary := "standalone SSE stream"
1382+
if err := c.checkResponse(summary, resp); err != nil {
1383+
c.fail(err)
1384+
return
1385+
}
1386+
go c.handleSSE(summary, resp, true, nil)
13501387
}
13511388

13521389
// fail handles an asynchronous error while reading.
@@ -1434,22 +1471,10 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
14341471
return fmt.Errorf("%s: %v", requestSummary, err)
14351472
}
14361473

1437-
// §2.5.3: "The server MAY terminate the session at any time, after
1438-
// which it MUST respond to requests containing that session ID with HTTP
1439-
// 404 Not Found."
1440-
if resp.StatusCode == http.StatusNotFound {
1441-
// Fail the session immediately, rather than relying on jsonrpc2 to fail
1442-
// (and close) it, because we want the call to Close to know that this
1443-
// session is missing (and therefore not send the DELETE).
1444-
err := fmt.Errorf("%s: failed to send: %w", requestSummary, errSessionMissing)
1474+
if err := c.checkResponse(requestSummary, resp); err != nil {
14451475
c.fail(err)
1446-
resp.Body.Close()
14471476
return err
14481477
}
1449-
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
1450-
resp.Body.Close()
1451-
return fmt.Errorf("broken session: %v", resp.Status)
1452-
}
14531478

14541479
if sessionID := resp.Header.Get(sessionIDHeader); sessionID != "" {
14551480
c.mu.Lock()
@@ -1463,6 +1488,8 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
14631488
return fmt.Errorf("mismatching session IDs %q and %q", hadSessionID, sessionID)
14641489
}
14651490
}
1491+
// TODO(rfindley): this logic isn't quite right.
1492+
// We should keep going even if the server returns 202, if we have a call.
14661493
if resp.StatusCode == http.StatusNoContent || resp.StatusCode == http.StatusAccepted {
14671494
// [§2.1.4]: "If the input is a JSON-RPC response or notification:
14681495
// If the server accepts the input, the server MUST return HTTP status code 202 Accepted with no body."
@@ -1543,73 +1570,63 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp
15431570
//
15441571
// If forCall is set, it is the call that initiated the stream, and the
15451572
// stream is complete when we receive its response.
1546-
func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *http.Response, persistent bool, forCall *jsonrpc2.Request) {
1547-
resp := initialResp
1548-
var lastEventID string
1573+
func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) {
15491574
for {
1575+
// Connection was successful. Continue the loop with the new response.
15501576
// TODO: we should set a reasonable limit on the number of times we'll try
15511577
// getting a response for a given request.
15521578
//
15531579
// Eventually, if we don't get the response, we should stop trying and
15541580
// fail the request.
1555-
if resp != nil {
1556-
eventID, clientClosed := c.processStream(requestSummary, resp, forCall)
1557-
lastEventID = eventID
1581+
lastEventID, clientClosed := c.processStream(requestSummary, resp, forCall)
15581582

1559-
// If the connection was closed by the client, we're done.
1560-
if clientClosed {
1561-
return
1562-
}
1563-
// If the stream has ended, then do not reconnect if the stream is
1564-
// temporary (POST initiated SSE).
1565-
if lastEventID == "" && !persistent {
1566-
return
1567-
}
1583+
// If the connection was closed by the client, we're done.
1584+
if clientClosed {
1585+
return
1586+
}
1587+
// If the stream has ended, then do not reconnect if the stream is
1588+
// temporary (POST initiated SSE).
1589+
if lastEventID == "" && !persistent {
1590+
return
15681591
}
15691592

15701593
// The stream was interrupted or ended by the server. Attempt to reconnect.
1571-
newResp, err := c.reconnect(lastEventID)
1594+
newResp, err := c.connectSSE(lastEventID)
15721595
if err != nil {
15731596
// All reconnection attempts failed: fail the connection.
15741597
c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err))
15751598
return
15761599
}
15771600
resp = newResp
1578-
if resp.StatusCode == http.StatusMethodNotAllowed && persistent {
1579-
// [§2.2.3]: "The server MUST either return Content-Type:
1580-
// text/event-stream in response to this HTTP GET, or else return HTTP
1581-
// 405 Method Not Allowed, indicating that the server does not offer an
1582-
// SSE stream at this endpoint."
1583-
//
1584-
// [§2.2.3]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server
1585-
1586-
// The server doesn't support the standalone SSE stream.
1587-
resp.Body.Close()
1588-
return
1589-
}
1590-
if resp.StatusCode == http.StatusNotFound && persistent && !c.strict {
1591-
// modelcontextprotocol/gosdk#393: some servers return NotFound instead
1592-
// of MethodNotAllowed for the standalone SSE stream.
1593-
//
1594-
// Treat this like MethodNotAllowed in non-strict mode.
1595-
if c.logger != nil {
1596-
c.logger.Warn("got 404 instead of 405 for standalonw SSE stream")
1597-
}
1598-
resp.Body.Close()
1599-
return
1600-
}
1601-
// (see equivalent handling in [streamableClientConn.Write]).
1602-
if resp.StatusCode == http.StatusNotFound {
1603-
c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing))
1601+
if err := c.checkResponse(requestSummary, resp); err != nil {
1602+
c.fail(err)
16041603
return
16051604
}
1606-
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
1605+
}
1606+
}
1607+
1608+
// checkResponse checks the status code of the provided response, and
1609+
// translates it into an error if the request was unsuccessful.
1610+
//
1611+
// The response body is close if a non-nil error is returned.
1612+
func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.Response) (err error) {
1613+
defer func() {
1614+
if err != nil {
16071615
resp.Body.Close()
1608-
c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, http.StatusText(resp.StatusCode)))
1609-
return
16101616
}
1611-
// Reconnection was successful. Continue the loop with the new response.
1617+
}()
1618+
// §2.5.3: "The server MAY terminate the session at any time, after
1619+
// which it MUST respond to requests containing that session ID with HTTP
1620+
// 404 Not Found."
1621+
if resp.StatusCode == http.StatusNotFound {
1622+
// Return an errSessionMissing to avoid sending a redundant DELETE when the
1623+
// session is already gone.
1624+
return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing)
1625+
}
1626+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
1627+
return fmt.Errorf("%s: failed to connect: %v", requestSummary, http.StatusText(resp.StatusCode))
16121628
}
1629+
return nil
16131630
}
16141631

16151632
// processStream reads from a single response body, sending events to the
@@ -1620,6 +1637,7 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
16201637
defer resp.Body.Close()
16211638
for evt, err := range scanEvents(resp.Body) {
16221639
if err != nil {
1640+
// TODO: we should differentiate EOF from other errors here.
16231641
break
16241642
}
16251643

@@ -1664,39 +1682,48 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
16641682
return lastEventID, false
16651683
}
16661684

1667-
// reconnect handles the logic of retrying a connection with an exponential
1668-
// backoff strategy. It returns a new, valid HTTP response if successful, or
1669-
// an error if all retries are exhausted.
1670-
func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) {
1685+
// connectSSE handles the logic of connecting a text/event-stream connection.
1686+
//
1687+
// If lastEventID is set, it is the last-event ID of a stream being resumed.
1688+
//
1689+
// If connection fails, connectSSE retries with an exponential backoff
1690+
// strategy. It returns a new, valid HTTP response if successful, or an error
1691+
// if all retries are exhausted.
1692+
func (c *streamableClientConn) connectSSE(lastEventID string) (*http.Response, error) {
16711693
var finalErr error
1672-
1673-
// We can reach the 'reconnect' path through the standlone SSE request, in which case
1674-
// lastEventID will be "".
1675-
//
1676-
// In this case, we need an initial attempt.
1694+
// If lastEventID is set, we've already connected successfully once, so
1695+
// consider that to be the first attempt.
16771696
attempt := 0
16781697
if lastEventID != "" {
16791698
attempt = 1
16801699
}
1681-
16821700
for ; attempt <= c.maxRetries; attempt++ {
16831701
select {
16841702
case <-c.done:
16851703
return nil, fmt.Errorf("connection closed by client during reconnect")
16861704
case <-time.After(calculateReconnectDelay(attempt)):
1687-
resp, err := c.establishSSE(lastEventID)
1705+
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil)
1706+
if err != nil {
1707+
return nil, err
1708+
}
1709+
c.setMCPHeaders(req)
1710+
if lastEventID != "" {
1711+
req.Header.Set("Last-Event-ID", lastEventID)
1712+
}
1713+
req.Header.Set("Accept", "text/event-stream")
1714+
resp, err := c.client.Do(req)
16881715
if err != nil {
16891716
finalErr = err // Store the error and try again.
16901717
continue
16911718
}
16921719
return resp, nil
16931720
}
16941721
}
1695-
// If the loop completes, all retries have failed.
1722+
// If the loop completes, all retries have failed, or the client is closing.
16961723
if finalErr != nil {
16971724
return nil, fmt.Errorf("connection failed after %d attempts: %w", c.maxRetries, finalErr)
16981725
}
1699-
return nil, fmt.Errorf("connection failed after %d attempts", c.maxRetries)
1726+
return nil, fmt.Errorf("connection aborted after %d attempts", c.maxRetries)
17001727
}
17011728

17021729
// Close implements the [Connection] interface.
@@ -1723,23 +1750,6 @@ func (c *streamableClientConn) Close() error {
17231750
return c.closeErr
17241751
}
17251752

1726-
// establishSSE establishes the persistent SSE listening stream.
1727-
// It is used for reconnect attempts using the Last-Event-ID header to
1728-
// resume a broken stream where it left off.
1729-
func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response, error) {
1730-
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil)
1731-
if err != nil {
1732-
return nil, err
1733-
}
1734-
c.setMCPHeaders(req)
1735-
if lastEventID != "" {
1736-
req.Header.Set("Last-Event-ID", lastEventID)
1737-
}
1738-
req.Header.Set("Accept", "text/event-stream")
1739-
1740-
return c.client.Do(req)
1741-
}
1742-
17431753
// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
17441754
func calculateReconnectDelay(attempt int) time.Duration {
17451755
if attempt == 0 {

mcp/streamable_client_test.go

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -268,14 +268,6 @@ func TestStreamableClientGETHandling(t *testing.T) {
268268
status: test.status,
269269
wantProtocolVersion: latestProtocolVersion,
270270
},
271-
{"POST", "123", methodListTools}: {
272-
header: header{
273-
"Content-Type": "application/json",
274-
sessionIDHeader: "123",
275-
},
276-
body: jsonBody(t, resp(2, &ListToolsResult{Tools: []*Tool{}}, nil)),
277-
optional: true,
278-
},
279271
{"DELETE", "123", ""}: {optional: true},
280272
},
281273
}
@@ -285,36 +277,18 @@ func TestStreamableClientGETHandling(t *testing.T) {
285277
transport := &StreamableClientTransport{Endpoint: httpServer.URL}
286278
client := NewClient(testImpl, nil)
287279
session, err := client.Connect(ctx, transport, nil)
288-
if err != nil {
289-
t.Fatalf("client.Connect() failed: %v", err)
280+
if err == nil {
281+
defer session.Close()
290282
}
291-
292-
// Since we need the client to observe the result of the hanging GET,
293-
// wait for all requests to be handled.
294-
start := time.Now()
295-
delay := 1 * time.Millisecond
296-
for range 10 {
297-
if len(fake.missingRequests()) == 0 {
298-
break
283+
if test.wantErrorContaining != "" {
284+
if err == nil {
285+
t.Fatalf("Connect succeeded unexpectedly, want error containing %q", test.wantErrorContaining)
299286
}
300-
time.Sleep(delay)
301-
delay *= 2
302-
}
303-
if missing := fake.missingRequests(); len(missing) > 0 {
304-
t.Errorf("did not receive expected requests after %s: %v", time.Since(start), missing)
305-
}
306-
307-
_, err = session.ListTools(ctx, nil)
308-
if (err != nil) != (test.wantErrorContaining != "") {
309-
t.Errorf("After initialization, got error %v, want containing %q", err, test.wantErrorContaining)
310-
} else if err != nil {
311-
if !strings.Contains(err.Error(), test.wantErrorContaining) {
312-
t.Errorf("After initialization, got error %s, want containing %q", err, test.wantErrorContaining)
287+
if got := err.Error(); !strings.Contains(got, test.wantErrorContaining) {
288+
t.Errorf("Connect error = %q, want containing %q", got, test.wantErrorContaining)
313289
}
314-
}
315-
316-
if err := session.Close(); err != nil {
317-
t.Errorf("closing session: %v", err)
290+
} else if err != nil {
291+
t.Fatalf("Connect failed: %v", err)
318292
}
319293
})
320294
}
@@ -334,7 +308,7 @@ func TestStreamableClientStrictness(t *testing.T) {
334308
{"conformant server", true, http.StatusAccepted, http.StatusMethodNotAllowed, false, false},
335309
{"strict initialized", true, http.StatusOK, http.StatusMethodNotAllowed, true, false},
336310
{"unstrict initialized", false, http.StatusOK, http.StatusMethodNotAllowed, false, false},
337-
{"strict GET", true, http.StatusAccepted, http.StatusNotFound, false, true},
311+
{"strict GET", true, http.StatusAccepted, http.StatusNotFound, true, false},
338312
{"unstrict GET", false, http.StatusOK, http.StatusNotFound, false, false},
339313
}
340314
for _, test := range tests {

mcp/streamable_test.go

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -594,14 +594,7 @@ func TestServerInitiatedSSE(t *testing.T) {
594594
notifications := make(chan string)
595595
server := NewServer(testImpl, nil)
596596

597-
opts := &StreamableHTTPOptions{
598-
// TODO(#583): for now, this is required for guaranteed message delivery.
599-
// However, it shouldn't be necessary to use replay here, as we should be
600-
// guaranteed that the standalone SSE stream is started by the time the
601-
// client is connected.
602-
EventStore: NewMemoryEventStore(nil),
603-
}
604-
httpServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, opts)))
597+
httpServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)))
605598
defer httpServer.Close()
606599

607600
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)

0 commit comments

Comments
 (0)