Skip to content

Commit a911cd0

Browse files
mcp/streamable: add resumability for the Streamable transport (#133)
This CL implements a retry mechanism to resume SSE streams to recover from network failures. For #10
1 parent 1465442 commit a911cd0

File tree

2 files changed

+318
-26
lines changed

2 files changed

+318
-26
lines changed

mcp/streamable.go

Lines changed: 193 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@ import (
99
"context"
1010
"fmt"
1111
"io"
12+
"math"
13+
"math/rand/v2"
1214
"net/http"
1315
"strconv"
1416
"strings"
1517
"sync"
1618
"sync/atomic"
19+
"time"
1720

1821
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
1922
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
@@ -597,12 +600,39 @@ type StreamableClientTransport struct {
597600
opts StreamableClientTransportOptions
598601
}
599602

603+
// StreamableReconnectOptions defines parameters for client reconnect attempts.
604+
type StreamableReconnectOptions struct {
605+
// MaxRetries is the maximum number of times to attempt a reconnect before giving up.
606+
// A value of 0 or less means never retry.
607+
MaxRetries int
608+
609+
// growFactor is the multiplicative factor by which the delay increases after each attempt.
610+
// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
611+
// It must be 1.0 or greater if MaxRetries is greater than 0.
612+
growFactor float64
613+
614+
// initialDelay is the base delay for the first reconnect attempt.
615+
initialDelay time.Duration
616+
617+
// maxDelay caps the backoff delay, preventing it from growing indefinitely.
618+
maxDelay time.Duration
619+
}
620+
621+
// DefaultReconnectOptions provides sensible defaults for reconnect logic.
622+
var DefaultReconnectOptions = &StreamableReconnectOptions{
623+
MaxRetries: 5,
624+
growFactor: 1.5,
625+
initialDelay: 1 * time.Second,
626+
maxDelay: 30 * time.Second,
627+
}
628+
600629
// StreamableClientTransportOptions provides options for the
601630
// [NewStreamableClientTransport] constructor.
602631
type StreamableClientTransportOptions struct {
603632
// HTTPClient is the client to use for making HTTP requests. If nil,
604633
// http.DefaultClient is used.
605-
HTTPClient *http.Client
634+
HTTPClient *http.Client
635+
ReconnectOptions *StreamableReconnectOptions
606636
}
607637

608638
// NewStreamableClientTransport returns a new client transport that connects to
@@ -628,22 +658,37 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
628658
if client == nil {
629659
client = http.DefaultClient
630660
}
631-
return &streamableClientConn{
632-
url: t.url,
633-
client: client,
634-
incoming: make(chan []byte, 100),
635-
done: make(chan struct{}),
636-
}, nil
661+
reconnOpts := t.opts.ReconnectOptions
662+
if reconnOpts == nil {
663+
reconnOpts = DefaultReconnectOptions
664+
}
665+
// Create a new cancellable context that will manage the connection's lifecycle.
666+
// This is crucial for cleanly shutting down the background SSE listener by
667+
// cancelling its blocking network operations, which prevents hangs on exit.
668+
connCtx, cancel := context.WithCancel(context.Background())
669+
conn := &streamableClientConn{
670+
url: t.url,
671+
client: client,
672+
incoming: make(chan []byte, 100),
673+
done: make(chan struct{}),
674+
ReconnectOptions: reconnOpts,
675+
ctx: connCtx,
676+
cancel: cancel,
677+
}
678+
return conn, nil
637679
}
638680

639681
type streamableClientConn struct {
640-
url string
641-
client *http.Client
642-
incoming chan []byte
643-
done chan struct{}
682+
url string
683+
client *http.Client
684+
incoming chan []byte
685+
done chan struct{}
686+
ReconnectOptions *StreamableReconnectOptions
644687

645688
closeOnce sync.Once
646689
closeErr error
690+
ctx context.Context
691+
cancel context.CancelFunc
647692

648693
mu sync.Mutex
649694
protocolVersion string
@@ -665,6 +710,12 @@ func (c *streamableClientConn) SessionID() string {
665710

666711
// Read implements the [Connection] interface.
667712
func (s *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) {
713+
s.mu.Lock()
714+
err := s.err
715+
s.mu.Unlock()
716+
if err != nil {
717+
return nil, err
718+
}
668719
select {
669720
case <-ctx.Done():
670721
return nil, ctx.Err()
@@ -745,6 +796,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
745796
sessionID = resp.Header.Get(sessionIDHeader)
746797
switch ct := resp.Header.Get("Content-Type"); ct {
747798
case "text/event-stream":
799+
// Section 2.1: The SSE stream is initiated after a POST.
748800
go s.handleSSE(resp)
749801
case "application/json":
750802
// TODO: read the body and send to s.incoming (in a select that also receives from s.done).
@@ -757,34 +809,115 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
757809
return sessionID, nil
758810
}
759811

760-
func (s *streamableClientConn) handleSSE(resp *http.Response) {
812+
// handleSSE manages the entire lifecycle of an SSE connection. It processes
813+
// an incoming Server-Sent Events stream and automatically handles reconnection
814+
// logic if the stream breaks.
815+
func (s *streamableClientConn) handleSSE(initialResp *http.Response) {
816+
resp := initialResp
817+
var lastEventID string
818+
819+
for {
820+
eventID, clientClosed := s.processStream(resp)
821+
lastEventID = eventID
822+
823+
// If the connection was closed by the client, we're done.
824+
if clientClosed {
825+
return
826+
}
827+
828+
// The stream was interrupted or ended by the server. Attempt to reconnect.
829+
newResp, err := s.reconnect(lastEventID)
830+
if err != nil {
831+
// All reconnection attempts failed. Set the final error, close the
832+
// connection, and exit the goroutine.
833+
s.mu.Lock()
834+
s.err = err
835+
s.mu.Unlock()
836+
s.Close()
837+
return
838+
}
839+
840+
// Reconnection was successful. Continue the loop with the new response.
841+
resp = newResp
842+
}
843+
}
844+
845+
// processStream reads from a single response body, sending events to the
846+
// incoming channel. It returns the ID of the last processed event, any error
847+
// that occurred, and a flag indicating if the connection was closed by the client.
848+
func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) {
761849
defer resp.Body.Close()
762850

763-
done := make(chan struct{})
764-
go func() {
765-
defer close(done)
766-
for evt, err := range scanEvents(resp.Body) {
851+
for evt, err := range scanEvents(resp.Body) {
852+
if err != nil {
853+
return lastEventID, false
854+
}
855+
856+
if evt.ID != "" {
857+
lastEventID = evt.ID
858+
}
859+
860+
select {
861+
case s.incoming <- evt.Data:
862+
case <-s.done:
863+
// The connection was closed by the client; exit gracefully.
864+
return lastEventID, true
865+
}
866+
}
867+
868+
// The loop finished without an error, indicating the server closed the stream.
869+
// We'll attempt to reconnect, so this is not a client-side close.
870+
return lastEventID, false
871+
}
872+
873+
// reconnect handles the logic of retrying a connection with an exponential
874+
// backoff strategy. It returns a new, valid HTTP response if successful, or
875+
// an error if all retries are exhausted.
876+
func (s *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) {
877+
var finalErr error
878+
879+
for attempt := 0; attempt < s.ReconnectOptions.MaxRetries; attempt++ {
880+
select {
881+
case <-s.done:
882+
return nil, fmt.Errorf("connection closed by client during reconnect")
883+
case <-time.After(calculateReconnectDelay(s.ReconnectOptions, attempt)):
884+
resp, err := s.establishSSE(lastEventID)
767885
if err != nil {
768-
// TODO: surface this error; possibly break the stream
769-
return
886+
finalErr = err // Store the error and try again.
887+
continue
770888
}
771-
select {
772-
case <-s.done:
773-
return
774-
case s.incoming <- evt.Data:
889+
890+
if !isResumable(resp) {
891+
// The server indicated we should not continue.
892+
resp.Body.Close()
893+
return nil, fmt.Errorf("reconnection failed with unresumable status: %s", resp.Status)
775894
}
895+
896+
return resp, nil
776897
}
777-
}()
898+
}
899+
// If the loop completes, all retries have failed.
900+
if finalErr != nil {
901+
return nil, fmt.Errorf("connection failed after %d attempts: %w", s.ReconnectOptions.MaxRetries, finalErr)
902+
}
903+
return nil, fmt.Errorf("connection failed after %d attempts", s.ReconnectOptions.MaxRetries)
904+
}
778905

779-
select {
780-
case <-s.done:
781-
case <-done:
906+
// isResumable checks if an HTTP response indicates a valid SSE stream that can be processed.
907+
func isResumable(resp *http.Response) bool {
908+
// Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint.
909+
if resp.StatusCode == http.StatusMethodNotAllowed {
910+
return false
782911
}
912+
913+
return strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream")
783914
}
784915

785916
// Close implements the [Connection] interface.
786917
func (s *streamableClientConn) Close() error {
787918
s.closeOnce.Do(func() {
919+
// Cancel any hanging network requests.
920+
s.cancel()
788921
close(s.done)
789922

790923
req, err := http.NewRequest(http.MethodDelete, s.url, nil)
@@ -803,3 +936,37 @@ func (s *streamableClientConn) Close() error {
803936
})
804937
return s.closeErr
805938
}
939+
940+
// establishSSE establishes the persistent SSE listening stream.
941+
// It is used for reconnect attempts using the Last-Event-ID header to
942+
// resume a broken stream where it left off.
943+
func (s *streamableClientConn) establishSSE(lastEventID string) (*http.Response, error) {
944+
req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, s.url, nil)
945+
if err != nil {
946+
return nil, err
947+
}
948+
s.mu.Lock()
949+
if s._sessionID != "" {
950+
req.Header.Set("Mcp-Session-Id", s._sessionID)
951+
}
952+
s.mu.Unlock()
953+
if lastEventID != "" {
954+
req.Header.Set("Last-Event-ID", lastEventID)
955+
}
956+
req.Header.Set("Accept", "text/event-stream")
957+
958+
return s.client.Do(req)
959+
}
960+
961+
// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
962+
func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time.Duration {
963+
// Calculate the exponential backoff using the grow factor.
964+
backoffDuration := time.Duration(float64(opts.initialDelay) * math.Pow(opts.growFactor, float64(attempt)))
965+
// Cap the backoffDuration at maxDelay.
966+
backoffDuration = min(backoffDuration, opts.maxDelay)
967+
968+
// Use a full jitter using backoffDuration
969+
jitter := rand.N(backoffDuration)
970+
971+
return backoffDuration + jitter
972+
}

0 commit comments

Comments
 (0)