Skip to content

Commit 881274b

Browse files
committed
mcp/streamable: add resumability for the Streamable transport
This CL implements a retry mechanism to resume SSE streams to recover from network failures.
1 parent de4b788 commit 881274b

File tree

2 files changed

+312
-14
lines changed

2 files changed

+312
-14
lines changed

mcp/streamable.go

Lines changed: 166 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@ import (
99
"context"
1010
"fmt"
1111
"io"
12+
"math"
13+
"math/rand/v2"
1214
"net/http"
1315
"strconv"
1416
"strings"
1517
"sync"
1618
"sync/atomic"
19+
"time"
1720

1821
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
1922
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
@@ -594,12 +597,39 @@ type StreamableClientTransport struct {
594597
opts StreamableClientTransportOptions
595598
}
596599

600+
// StreamableReconnectOptions defines parameters for client reconnect attempts.
601+
type StreamableReconnectOptions struct {
602+
// MaxRetries is the maximum number of times to attempt a reconnect before giving up.
603+
// A value of 0 or less means never retry.
604+
MaxRetries int
605+
606+
// growFactor is the multiplicative factor by which the delay increases after each attempt.
607+
// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
608+
// It must be 1.0 or greater if MaxRetries is greater than 0.
609+
growFactor float64
610+
611+
// initialDelay is the base delay for the first reconnect attempt.
612+
initialDelay time.Duration
613+
614+
// maxDelay caps the backoff delay, preventing it from growing indefinitely.
615+
maxDelay time.Duration
616+
}
617+
618+
// DefaultReconnectOptions provides sensible defaults for reconnect logic.
619+
var DefaultReconnectOptions = &StreamableReconnectOptions{
620+
MaxRetries: 5,
621+
growFactor: 1.5,
622+
initialDelay: 1 * time.Second,
623+
maxDelay: 30 * time.Second,
624+
}
625+
597626
// StreamableClientTransportOptions provides options for the
598627
// [NewStreamableClientTransport] constructor.
599628
type StreamableClientTransportOptions struct {
600629
// HTTPClient is the client to use for making HTTP requests. If nil,
601630
// http.DefaultClient is used.
602-
HTTPClient *http.Client
631+
HTTPClient *http.Client
632+
ReconnectOptions *StreamableReconnectOptions
603633
}
604634

605635
// NewStreamableClientTransport returns a new client transport that connects to
@@ -625,22 +655,37 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
625655
if client == nil {
626656
client = http.DefaultClient
627657
}
628-
return &streamableClientConn{
629-
url: t.url,
630-
client: client,
631-
incoming: make(chan []byte, 100),
632-
done: make(chan struct{}),
633-
}, nil
658+
reconnOpts := t.opts.ReconnectOptions
659+
if reconnOpts == nil {
660+
reconnOpts = DefaultReconnectOptions
661+
}
662+
// Create a new cancellable context that will manage the connection's lifecycle.
663+
// This is crucial for cleanly shutting down the background SSE listener by
664+
// cancelling its blocking network operations, which prevents hangs on exit.
665+
connCtx, cancel := context.WithCancel(context.Background())
666+
conn := &streamableClientConn{
667+
url: t.url,
668+
client: client,
669+
incoming: make(chan []byte, 100),
670+
done: make(chan struct{}),
671+
ReconnectOptions: reconnOpts,
672+
ctx: connCtx,
673+
cancel: cancel,
674+
}
675+
return conn, nil
634676
}
635677

636678
type streamableClientConn struct {
637-
url string
638-
client *http.Client
639-
incoming chan []byte
640-
done chan struct{}
679+
url string
680+
client *http.Client
681+
incoming chan []byte
682+
done chan struct{}
683+
ReconnectOptions *StreamableReconnectOptions
641684

642685
closeOnce sync.Once
643686
closeErr error
687+
ctx context.Context
688+
cancel context.CancelFunc
644689

645690
mu sync.Mutex
646691
protocolVersion string
@@ -662,6 +707,12 @@ func (c *streamableClientConn) SessionID() string {
662707

663708
// Read implements the [Connection] interface.
664709
func (s *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) {
710+
s.mu.Lock()
711+
err := s.err
712+
s.mu.Unlock()
713+
if err != nil {
714+
return nil, err
715+
}
665716
select {
666717
case <-ctx.Done():
667718
return nil, ctx.Err()
@@ -701,14 +752,26 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
701752
return err
702753
}
703754

755+
// The session has just been initialized.
704756
if sessionID == "" {
705757
// locked
706758
s._sessionID = gotSessionID
759+
// Section 2.2: The client MAY issue an HTTP GET to the MCP endpoint.
760+
// This can be used to open an SSE stream, allowing the server to
761+
// communicate to the client, without the client first sending data via
762+
// HTTP POST.
763+
go s.establishSSE(&startSSEState{})
707764
}
708765

709766
return nil
710767
}
711768

769+
// startSSEState holds the state for initiating an SSE stream.
770+
type startSSEState struct {
771+
lastEventID string
772+
attempt int
773+
}
774+
712775
func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) {
713776
data, err := jsonrpc2.EncodeMessage(msg)
714777
if err != nil {
@@ -742,7 +805,8 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
742805
sessionID = resp.Header.Get(sessionIDHeader)
743806
switch ct := resp.Header.Get("Content-Type"); ct {
744807
case "text/event-stream":
745-
go s.handleSSE(resp)
808+
// Section 2.1: The SSE stream is initiated after a POST.
809+
go s.handleSSE(resp, &startSSEState{})
746810
case "application/json":
747811
// TODO: read the body and send to s.incoming (in a select that also receives from s.done).
748812
resp.Body.Close()
@@ -754,17 +818,20 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
754818
return sessionID, nil
755819
}
756820

757-
func (s *streamableClientConn) handleSSE(resp *http.Response) {
821+
// handleSSE processes an incoming Server-Sent Events stream, pushing received messages to the client's channel.
822+
// If the stream breaks, it uses the last received event ID to automatically trigger the reconnect logic.
823+
func (s *streamableClientConn) handleSSE(resp *http.Response, opts *startSSEState) {
758824
defer resp.Body.Close()
759825

760826
done := make(chan struct{})
761827
go func() {
762828
defer close(done)
763829
for evt, err := range scanEvents(resp.Body) {
764830
if err != nil {
765-
// TODO: surface this error; possibly break the stream
831+
s.scheduleReconnect(opts)
766832
return
767833
}
834+
opts.lastEventID = evt.id
768835
select {
769836
case <-s.done:
770837
return
@@ -782,6 +849,8 @@ func (s *streamableClientConn) handleSSE(resp *http.Response) {
782849
// Close implements the [Connection] interface.
783850
func (s *streamableClientConn) Close() error {
784851
s.closeOnce.Do(func() {
852+
// Cancel any hanging network requests.
853+
s.cancel()
785854
close(s.done)
786855

787856
req, err := http.NewRequest(http.MethodDelete, s.url, nil)
@@ -800,3 +869,86 @@ func (s *streamableClientConn) Close() error {
800869
})
801870
return s.closeErr
802871
}
872+
873+
// establishSSE creates and manages the persistent SSE listening stream.
874+
// It is used for both the initial connection and all subsequent reconnect attempts,
875+
// using the Last-Event-ID header to resume a broken stream where it left off.
876+
// On a successful response, it delegates to handleSSE to process events;
877+
// on failure, it triggers the client's reconnect logic.
878+
func (s *streamableClientConn) establishSSE(opts *startSSEState) {
879+
select {
880+
case <-s.done:
881+
return
882+
default:
883+
}
884+
885+
req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, s.url, nil)
886+
if err != nil {
887+
return
888+
}
889+
s.mu.Lock()
890+
if s._sessionID != "" {
891+
req.Header.Set("Mcp-Session-Id", s._sessionID)
892+
}
893+
s.mu.Unlock()
894+
if opts.lastEventID != "" {
895+
req.Header.Set("Last-Event-ID", opts.lastEventID)
896+
}
897+
req.Header.Set("Accept", "text/event-stream")
898+
899+
resp, err := s.client.Do(req)
900+
if err != nil {
901+
// On connection error, schedule a retry.
902+
s.scheduleReconnect(opts)
903+
return
904+
}
905+
906+
// Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint.
907+
if resp.StatusCode == http.StatusMethodNotAllowed {
908+
resp.Body.Close()
909+
return
910+
}
911+
912+
if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
913+
resp.Body.Close()
914+
return
915+
}
916+
917+
s.handleSSE(resp, opts)
918+
}
919+
920+
// scheduleReconnect schedules the next SSE reconnect attempt after a delay.
921+
func (s *streamableClientConn) scheduleReconnect(opts *startSSEState) {
922+
reconnOpts := s.ReconnectOptions
923+
if opts.attempt >= reconnOpts.MaxRetries {
924+
s.mu.Lock()
925+
s.err = fmt.Errorf("connection failed after %d attempts", reconnOpts.MaxRetries)
926+
s.mu.Unlock()
927+
s.Close() // Close the connection to unblock any readers.
928+
return
929+
}
930+
931+
delay := calculateReconnectDelay(reconnOpts, opts.attempt)
932+
933+
select {
934+
case <-s.done:
935+
return
936+
case <-time.After(delay):
937+
opts.attempt++
938+
s.establishSSE(opts)
939+
}
940+
}
941+
942+
// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
943+
func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time.Duration {
944+
// Calculate the exponential backoff using the grow factor.
945+
backoffDuration := time.Duration(float64(opts.initialDelay) * math.Pow(opts.growFactor, float64(attempt)))
946+
947+
// Cap the backoffDuration at maxDelay.
948+
backoffDuration = min(backoffDuration, opts.maxDelay)
949+
950+
// Use a full jitter using backoffDuration
951+
jitter := rand.N(backoffDuration)
952+
953+
return backoffDuration + jitter
954+
}

0 commit comments

Comments
 (0)