diff --git a/mcp/streamable.go b/mcp/streamable.go index d8ce45e7..190b5828 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1346,7 +1346,44 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { // § 2.5: A server using the Streamable HTTP transport MAY assign a session // ID at initialization time, by including it in an Mcp-Session-Id header // on the HTTP response containing the InitializeResult. - go c.handleSSE("standalone SSE stream", nil, true, nil) + c.connectStandaloneSSE() +} + +func (c *streamableClientConn) connectStandaloneSSE() { + resp, err := c.connectSSE("") + if err != nil { + c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err)) + return + } + + // [§2.2.3]: "The server MUST either return Content-Type: + // text/event-stream in response to this HTTP GET, or else return HTTP + // 405 Method Not Allowed, indicating that the server does not offer an + // SSE stream at this endpoint." + // + // [§2.2.3]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server + if resp.StatusCode == http.StatusMethodNotAllowed { + // The server doesn't support the standalone SSE stream. + resp.Body.Close() + return + } + if resp.StatusCode == http.StatusNotFound && !c.strict { + // modelcontextprotocol/gosdk#393: some servers return NotFound instead + // of MethodNotAllowed for the standalone SSE stream. + // + // Treat this like MethodNotAllowed in non-strict mode. + if c.logger != nil { + c.logger.Warn("got 404 instead of 405 for standalone SSE stream") + } + resp.Body.Close() + return + } + summary := "standalone SSE stream" + if err := c.checkResponse(summary, resp); err != nil { + c.fail(err) + return + } + go c.handleSSE(summary, resp, true, nil) } // fail handles an asynchronous error while reading. @@ -1434,22 +1471,10 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return fmt.Errorf("%s: %v", requestSummary, err) } - // §2.5.3: "The server MAY terminate the session at any time, after - // which it MUST respond to requests containing that session ID with HTTP - // 404 Not Found." - if resp.StatusCode == http.StatusNotFound { - // Fail the session immediately, rather than relying on jsonrpc2 to fail - // (and close) it, because we want the call to Close to know that this - // session is missing (and therefore not send the DELETE). - err := fmt.Errorf("%s: failed to send: %w", requestSummary, errSessionMissing) + if err := c.checkResponse(requestSummary, resp); err != nil { c.fail(err) - resp.Body.Close() return err } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - resp.Body.Close() - return fmt.Errorf("broken session: %v", resp.Status) - } if sessionID := resp.Header.Get(sessionIDHeader); sessionID != "" { c.mu.Lock() @@ -1463,6 +1488,8 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return fmt.Errorf("mismatching session IDs %q and %q", hadSessionID, sessionID) } } + // TODO(rfindley): this logic isn't quite right. + // We should keep going even if the server returns 202, if we have a call. if resp.StatusCode == http.StatusNoContent || resp.StatusCode == http.StatusAccepted { // [§2.1.4]: "If the input is a JSON-RPC response or notification: // If the server accepts the input, the server MUST return HTTP status code 202 Accepted with no body." @@ -1543,73 +1570,63 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp // // If forCall is set, it is the call that initiated the stream, and the // stream is complete when we receive its response. -func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *http.Response, persistent bool, forCall *jsonrpc2.Request) { - resp := initialResp - var lastEventID string +func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) { for { + // Connection was successful. Continue the loop with the new response. // TODO: we should set a reasonable limit on the number of times we'll try // getting a response for a given request. // // Eventually, if we don't get the response, we should stop trying and // fail the request. - if resp != nil { - eventID, clientClosed := c.processStream(requestSummary, resp, forCall) - lastEventID = eventID + lastEventID, clientClosed := c.processStream(requestSummary, resp, forCall) - // If the connection was closed by the client, we're done. - if clientClosed { - return - } - // If the stream has ended, then do not reconnect if the stream is - // temporary (POST initiated SSE). - if lastEventID == "" && !persistent { - return - } + // If the connection was closed by the client, we're done. + if clientClosed { + return + } + // If the stream has ended, then do not reconnect if the stream is + // temporary (POST initiated SSE). + if lastEventID == "" && !persistent { + return } // The stream was interrupted or ended by the server. Attempt to reconnect. - newResp, err := c.reconnect(lastEventID) + newResp, err := c.connectSSE(lastEventID) if err != nil { // All reconnection attempts failed: fail the connection. c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err)) return } resp = newResp - if resp.StatusCode == http.StatusMethodNotAllowed && persistent { - // [§2.2.3]: "The server MUST either return Content-Type: - // text/event-stream in response to this HTTP GET, or else return HTTP - // 405 Method Not Allowed, indicating that the server does not offer an - // SSE stream at this endpoint." - // - // [§2.2.3]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server - - // The server doesn't support the standalone SSE stream. - resp.Body.Close() - return - } - if resp.StatusCode == http.StatusNotFound && persistent && !c.strict { - // modelcontextprotocol/gosdk#393: some servers return NotFound instead - // of MethodNotAllowed for the standalone SSE stream. - // - // Treat this like MethodNotAllowed in non-strict mode. - if c.logger != nil { - c.logger.Warn("got 404 instead of 405 for standalonw SSE stream") - } - resp.Body.Close() - return - } - // (see equivalent handling in [streamableClientConn.Write]). - if resp.StatusCode == http.StatusNotFound { - c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing)) + if err := c.checkResponse(requestSummary, resp); err != nil { + c.fail(err) return } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { + } +} + +// checkResponse checks the status code of the provided response, and +// translates it into an error if the request was unsuccessful. +// +// The response body is close if a non-nil error is returned. +func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.Response) (err error) { + defer func() { + if err != nil { resp.Body.Close() - c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, http.StatusText(resp.StatusCode))) - return } - // Reconnection was successful. Continue the loop with the new response. + }() + // §2.5.3: "The server MAY terminate the session at any time, after + // which it MUST respond to requests containing that session ID with HTTP + // 404 Not Found." + if resp.StatusCode == http.StatusNotFound { + // Return an errSessionMissing to avoid sending a redundant DELETE when the + // session is already gone. + return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("%s: failed to connect: %v", requestSummary, http.StatusText(resp.StatusCode)) } + return nil } // processStream reads from a single response body, sending events to the @@ -1620,6 +1637,7 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R defer resp.Body.Close() for evt, err := range scanEvents(resp.Body) { if err != nil { + // TODO: we should differentiate EOF from other errors here. break } @@ -1664,27 +1682,36 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R return lastEventID, false } -// reconnect handles the logic of retrying a connection with an exponential -// backoff strategy. It returns a new, valid HTTP response if successful, or -// an error if all retries are exhausted. -func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) { +// connectSSE handles the logic of connecting a text/event-stream connection. +// +// If lastEventID is set, it is the last-event ID of a stream being resumed. +// +// If connection fails, connectSSE retries with an exponential backoff +// strategy. It returns a new, valid HTTP response if successful, or an error +// if all retries are exhausted. +func (c *streamableClientConn) connectSSE(lastEventID string) (*http.Response, error) { var finalErr error - - // We can reach the 'reconnect' path through the standlone SSE request, in which case - // lastEventID will be "". - // - // In this case, we need an initial attempt. + // If lastEventID is set, we've already connected successfully once, so + // consider that to be the first attempt. attempt := 0 if lastEventID != "" { attempt = 1 } - for ; attempt <= c.maxRetries; attempt++ { select { case <-c.done: return nil, fmt.Errorf("connection closed by client during reconnect") case <-time.After(calculateReconnectDelay(attempt)): - resp, err := c.establishSSE(lastEventID) + req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil) + if err != nil { + return nil, err + } + c.setMCPHeaders(req) + if lastEventID != "" { + req.Header.Set("Last-Event-ID", lastEventID) + } + req.Header.Set("Accept", "text/event-stream") + resp, err := c.client.Do(req) if err != nil { finalErr = err // Store the error and try again. continue @@ -1692,11 +1719,11 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er return resp, nil } } - // If the loop completes, all retries have failed. + // If the loop completes, all retries have failed, or the client is closing. if finalErr != nil { return nil, fmt.Errorf("connection failed after %d attempts: %w", c.maxRetries, finalErr) } - return nil, fmt.Errorf("connection failed after %d attempts", c.maxRetries) + return nil, fmt.Errorf("connection aborted after %d attempts", c.maxRetries) } // Close implements the [Connection] interface. @@ -1723,23 +1750,6 @@ func (c *streamableClientConn) Close() error { return c.closeErr } -// establishSSE establishes the persistent SSE listening stream. -// It is used for reconnect attempts using the Last-Event-ID header to -// resume a broken stream where it left off. -func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response, error) { - req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil) - if err != nil { - return nil, err - } - c.setMCPHeaders(req) - if lastEventID != "" { - req.Header.Set("Last-Event-ID", lastEventID) - } - req.Header.Set("Accept", "text/event-stream") - - return c.client.Do(req) -} - // calculateReconnectDelay calculates a delay using exponential backoff with full jitter. func calculateReconnectDelay(attempt int) time.Duration { if attempt == 0 { diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 4a4f5c65..42472c5e 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -13,7 +13,6 @@ import ( "strings" "sync" "testing" - "time" "github.com/google/go-cmp/cmp" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" @@ -156,7 +155,6 @@ func TestStreamableClientTransportLifecycle(t *testing.T) { header: header{ "Content-Type": "text/event-stream", }, - optional: true, wantProtocolVersion: latestProtocolVersion, }, {"DELETE", "123", ""}: {}, @@ -205,8 +203,7 @@ func TestStreamableClientRedundantDelete(t *testing.T) { wantProtocolVersion: latestProtocolVersion, }, {"GET", "123", ""}: { - status: http.StatusMethodNotAllowed, - optional: true, + status: http.StatusMethodNotAllowed, }, {"POST", "123", methodListTools}: { status: http.StatusNotFound, @@ -268,14 +265,6 @@ func TestStreamableClientGETHandling(t *testing.T) { status: test.status, wantProtocolVersion: latestProtocolVersion, }, - {"POST", "123", methodListTools}: { - header: header{ - "Content-Type": "application/json", - sessionIDHeader: "123", - }, - body: jsonBody(t, resp(2, &ListToolsResult{Tools: []*Tool{}}, nil)), - optional: true, - }, {"DELETE", "123", ""}: {optional: true}, }, } @@ -285,36 +274,18 @@ func TestStreamableClientGETHandling(t *testing.T) { transport := &StreamableClientTransport{Endpoint: httpServer.URL} client := NewClient(testImpl, nil) session, err := client.Connect(ctx, transport, nil) - if err != nil { - t.Fatalf("client.Connect() failed: %v", err) + if err == nil { + defer session.Close() } - - // Since we need the client to observe the result of the hanging GET, - // wait for all requests to be handled. - start := time.Now() - delay := 1 * time.Millisecond - for range 10 { - if len(fake.missingRequests()) == 0 { - break + if test.wantErrorContaining != "" { + if err == nil { + t.Fatalf("Connect succeeded unexpectedly, want error containing %q", test.wantErrorContaining) } - time.Sleep(delay) - delay *= 2 - } - if missing := fake.missingRequests(); len(missing) > 0 { - t.Errorf("did not receive expected requests after %s: %v", time.Since(start), missing) - } - - _, err = session.ListTools(ctx, nil) - if (err != nil) != (test.wantErrorContaining != "") { - t.Errorf("After initialization, got error %v, want containing %q", err, test.wantErrorContaining) - } else if err != nil { - if !strings.Contains(err.Error(), test.wantErrorContaining) { - t.Errorf("After initialization, got error %s, want containing %q", err, test.wantErrorContaining) + if got := err.Error(); !strings.Contains(got, test.wantErrorContaining) { + t.Errorf("Connect error = %q, want containing %q", got, test.wantErrorContaining) } - } - - if err := session.Close(); err != nil { - t.Errorf("closing session: %v", err) + } else if err != nil { + t.Fatalf("Connect failed: %v", err) } }) } @@ -329,13 +300,12 @@ func TestStreamableClientStrictness(t *testing.T) { initializedStatus int getStatus int wantConnectError bool - wantListError bool }{ - {"conformant server", true, http.StatusAccepted, http.StatusMethodNotAllowed, false, false}, - {"strict initialized", true, http.StatusOK, http.StatusMethodNotAllowed, true, false}, - {"unstrict initialized", false, http.StatusOK, http.StatusMethodNotAllowed, false, false}, - {"strict GET", true, http.StatusAccepted, http.StatusNotFound, false, true}, - {"unstrict GET", false, http.StatusOK, http.StatusNotFound, false, false}, + {"conformant server", true, http.StatusAccepted, http.StatusMethodNotAllowed, false}, + {"strict initialized", true, http.StatusOK, http.StatusMethodNotAllowed, true}, + {"unstrict initialized", false, http.StatusOK, http.StatusMethodNotAllowed, false}, + {"strict GET", true, http.StatusAccepted, http.StatusNotFound, true}, + {"unstrict GET", false, http.StatusOK, http.StatusNotFound, false}, } for _, test := range tests { t.Run(test.label, func(t *testing.T) { @@ -383,23 +353,9 @@ func TestStreamableClientStrictness(t *testing.T) { if err != nil { return } - // Since we need the client to observe the result of the hanging GET, - // wait for all requests to be handled. - start := time.Now() - delay := 1 * time.Millisecond - for range 10 { - if len(fake.missingRequests()) == 0 { - break - } - time.Sleep(delay) - delay *= 2 - } - if missing := fake.missingRequests(); len(missing) > 0 { - t.Errorf("did not receive expected requests after %s: %v", time.Since(start), missing) - } _, err = session.ListTools(ctx, nil) - if (err != nil) != test.wantListError { - t.Errorf("ListTools returned error %v; want error: %t", err, test.wantListError) + if err != nil { + t.Errorf("ListTools failed: %v", err) } if err := session.Close(); err != nil { t.Errorf("closing session: %v", err) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 0f38a0f4..bd5d2c57 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -594,14 +594,7 @@ func TestServerInitiatedSSE(t *testing.T) { notifications := make(chan string) server := NewServer(testImpl, nil) - opts := &StreamableHTTPOptions{ - // TODO(#583): for now, this is required for guaranteed message delivery. - // However, it shouldn't be necessary to use replay here, as we should be - // guaranteed that the standalone SSE stream is started by the time the - // client is connected. - EventStore: NewMemoryEventStore(nil), - } - httpServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, opts))) + httpServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))) defer httpServer.Close() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)