Skip to content

Commit 8254fbc

Browse files
committed
mcp: fix reconnect semantics for hanging GET
A few problems with reconnection cropped up in the review of PR #307. We should allow for the hanging GET to fail with StatusMethodNotAllowed. This simply means that the server does not support sending notifications or requests over the GET, which is allowed in the spec. Also, we should fix the initial delay of the hanging GET request: it should start with 0 delay. Fix the math for this and subsequent attempts. Incidentally, this makes the tests take 3s on my machine, down from 9s. Also address some comments from #307.
1 parent cdd1d9b commit 8254fbc

File tree

2 files changed

+42
-28
lines changed

2 files changed

+42
-28
lines changed

internal/jsonrpc2/conn.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -728,8 +728,8 @@ func (c *Connection) write(ctx context.Context, msg Message) error {
728728
var err error
729729
// Fail writes immediately if the connection is shutting down.
730730
//
731-
// TODO(rfindley): should we allow cancellation notifations through? It could
732-
// be the case that writes can still succeed.
731+
// TODO(rfindley): should we allow cancellation notifications through? It
732+
// could be the case that writes can still succeed.
733733
c.updateInFlight(func(s *inFlightState) {
734734
err = s.shuttingDown(ErrServerClosing)
735735
})

mcp/streamable.go

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ type StreamableHTTPOptions struct {
4949
// GetSessionID provides the next session ID to use for an incoming request.
5050
// If nil, a default randomly generated ID will be used.
5151
//
52+
// Session IDs should be globally unique across the scope of the server,
53+
// which may span multiple processes in the case of distributed servers.
54+
//
5255
// As a special case, if GetSessionID returns the empty string, the
5356
// Mcp-Session-Id header will not be set.
5457
GetSessionID func() string
@@ -58,7 +61,9 @@ type StreamableHTTPOptions struct {
5861
// A stateless server does not validate the Mcp-Session-Id header, and uses a
5962
// temporary session with default initialization parameters. Any
6063
// server->client request is rejected immediately as there's no way for the
61-
// client to respond.
64+
// client to respond. Server->Client notifications may reach the client if
65+
// they are made in the context of an incoming request, as described in the
66+
// documentation for [StreamableServerTransport].
6267
Stateless bool
6368

6469
// TODO: support session retention (?)
@@ -136,9 +141,10 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
136141
transport, _ = h.transports[sessionID]
137142
h.mu.Unlock()
138143
if transport == nil && !h.opts.Stateless {
139-
// In stateless mode we allow a missing transport.
144+
// Unless we're in 'stateless' mode, which doesn't perform any Session-ID
145+
// validation, we require that the session ID matches a known session.
140146
//
141-
// A synthetic transport will be created below for the transient session.
147+
// In stateless mode, a temporary transport is be created below.
142148
http.Error(w, "session not found", http.StatusNotFound)
143149
return
144150
}
@@ -201,7 +207,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
201207
// stateless servers.
202208
body, err := io.ReadAll(req.Body)
203209
if err != nil {
204-
http.Error(w, "failed to read body", http.StatusBadRequest)
210+
http.Error(w, "failed to read body", http.StatusInternalServerError)
205211
return
206212
}
207213
req.Body.Close()
@@ -272,9 +278,22 @@ type StreamableServerTransportOptions struct {
272278
// A StreamableServerTransport implements the server side of the MCP streamable
273279
// transport.
274280
//
275-
// Each StreamableServerTransport may be connected (via [Server.Connect]) at
281+
// Each StreamableServerTransport must be connected (via [Server.Connect]) at
276282
// most once, since [StreamableServerTransport.ServeHTTP] serves messages to
277283
// the connected session.
284+
//
285+
// Reads from the streamable server connection receive messages from http POST
286+
// requests from the client. Writes to the streamable server connection are
287+
// sent either to the hanging POST response, or to the hanging GET, according
288+
// to the following rules:
289+
// - JSON-RPC responses to incoming requests are always routed to the
290+
// appropriate HTTP response.
291+
// - Requests or notifications made with a context.Context value derived from
292+
// an incoming request handler, are routed to the HTTP response
293+
// corresponding to that request, unless it has already terminated, in
294+
// which case they are routed to the hanging GET.
295+
// - Requests or notifications made with a detached context.Context value are
296+
// routed to the hanging GET.
278297
type StreamableServerTransport struct {
279298
// SessionID is the ID of this session.
280299
//
@@ -285,7 +304,7 @@ type StreamableServerTransport struct {
285304
// generator to produce one, as with [crypto/rand.Text].)
286305
SessionID string
287306

288-
// Stateless controls whether the eventstore is 'Stateless'. Servers sessions
307+
// Stateless controls whether the eventstore is 'Stateless'. Server sessions
289308
// connected to a stateless transport are disallowed from making outgoing
290309
// requests.
291310
//
@@ -1228,9 +1247,18 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent
12281247
c.fail(err)
12291248
return
12301249
}
1231-
1232-
// Reconnection was successful. Continue the loop with the new response.
12331250
resp = newResp
1251+
if resp.StatusCode == http.StatusMethodNotAllowed && persistent {
1252+
// The server doesn't support the hanging GET.
1253+
resp.Body.Close()
1254+
return
1255+
}
1256+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
1257+
resp.Body.Close()
1258+
c.fail(fmt.Errorf("failed to reconnect: %v", http.StatusText(resp.StatusCode)))
1259+
return
1260+
}
1261+
// Reconnection was successful. Continue the loop with the new response.
12341262
}
12351263
}
12361264

@@ -1298,13 +1326,6 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er
12981326
finalErr = err // Store the error and try again.
12991327
continue
13001328
}
1301-
1302-
if !isResumable(resp) {
1303-
// The server indicated we should not continue.
1304-
resp.Body.Close()
1305-
return nil, fmt.Errorf("reconnection failed with unresumable status: %s", resp.Status)
1306-
}
1307-
13081329
return resp, nil
13091330
}
13101331
}
@@ -1315,16 +1336,6 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er
13151336
return nil, fmt.Errorf("connection failed after %d attempts", c.maxRetries)
13161337
}
13171338

1318-
// isResumable checks if an HTTP response indicates a valid SSE stream that can be processed.
1319-
func isResumable(resp *http.Response) bool {
1320-
// Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint.
1321-
if resp.StatusCode == http.StatusMethodNotAllowed {
1322-
return false
1323-
}
1324-
1325-
return strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream")
1326-
}
1327-
13281339
// Close implements the [Connection] interface.
13291340
func (c *streamableClientConn) Close() error {
13301341
c.closeOnce.Do(func() {
@@ -1364,8 +1375,11 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response,
13641375

13651376
// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
13661377
func calculateReconnectDelay(attempt int) time.Duration {
1378+
if attempt == 0 {
1379+
return 0
1380+
}
13671381
// Calculate the exponential backoff using the grow factor.
1368-
backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt)))
1382+
backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt-1)))
13691383
// Cap the backoffDuration at maxDelay.
13701384
backoffDuration = min(backoffDuration, reconnectMaxDelay)
13711385

0 commit comments

Comments
 (0)