@@ -9,11 +9,14 @@ import (
99 "context"
1010 "fmt"
1111 "io"
12+ "math"
13+ "math/rand/v2"
1214 "net/http"
1315 "strconv"
1416 "strings"
1517 "sync"
1618 "sync/atomic"
19+ "time"
1720
1821 "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
1922 "github.com/modelcontextprotocol/go-sdk/jsonrpc"
@@ -597,12 +600,39 @@ type StreamableClientTransport struct {
597600 opts StreamableClientTransportOptions
598601}
599602
603+ // StreamableReconnectOptions configures how a client connection retries a broken SSE stream: the retry budget and the backoff between attempts.
604+ type StreamableReconnectOptions struct {
605+ // MaxRetries is the maximum number of times to attempt a reconnect before giving up.
606+ // A value of 0 or less means never retry: reconnect then fails without making a request.
607+ MaxRetries int
608+
609+ // growFactor is the multiplicative factor by which the delay increases after each attempt.
610+ // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
611+ // It must be 1.0 or greater if MaxRetries is greater than 0.
612+ growFactor float64 // NOTE(review): unexported, like initialDelay and maxDelay; a literal built outside this package leaves all three at zero.
613+
614+ // initialDelay is the base delay, before jitter, for the first reconnect attempt (attempt 0).
615+ initialDelay time.Duration
616+
617+ // maxDelay caps the exponential backoff before jitter is added, so the actual wait can exceed it.
618+ maxDelay time.Duration
619+ }
620+
621+ // DefaultReconnectOptions is used by Connect when StreamableClientTransportOptions.ReconnectOptions is nil.
622+ var DefaultReconnectOptions = & StreamableReconnectOptions { // Connect stores this pointer as-is, so mutating it affects every connection using the defaults.
623+ MaxRetries : 5 , // at most five reconnect attempts per broken stream
624+ growFactor : 1.5 , // each attempt's backoff is 1.5x the previous one
625+ initialDelay : 1 * time .Second , // backoff of the first attempt, before jitter
626+ maxDelay : 30 * time .Second , // upper bound on the backoff, before jitter
627+ }
628+
600629// StreamableClientTransportOptions provides options for the
601630// [NewStreamableClientTransport] constructor.
602631type StreamableClientTransportOptions struct {
603632 // HTTPClient is the client to use for making HTTP requests. If nil,
604633 // http.DefaultClient is used.
605- HTTPClient * http.Client
634+ HTTPClient * http.Client
635+ ReconnectOptions * StreamableReconnectOptions // ReconnectOptions configures SSE reconnect attempts. If nil, DefaultReconnectOptions is used.
606636}
607637
608638// NewStreamableClientTransport returns a new client transport that connects to
@@ -628,22 +658,37 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
628658 if client == nil {
629659 client = http .DefaultClient
630660 }
631- return & streamableClientConn {
632- url : t .url ,
633- client : client ,
634- incoming : make (chan []byte , 100 ),
635- done : make (chan struct {}),
636- }, nil
661+ reconnOpts := t .opts .ReconnectOptions
662+ if reconnOpts == nil {
663+ reconnOpts = DefaultReconnectOptions
664+ }
665+ // Create a new cancellable context that will manage the connection's lifecycle.
666+ // This is crucial for cleanly shutting down the background SSE listener by
667+ // cancelling its blocking network operations, which prevents hangs on exit.
668+ connCtx , cancel := context .WithCancel (context .Background ())
669+ conn := & streamableClientConn {
670+ url : t .url ,
671+ client : client ,
672+ incoming : make (chan []byte , 100 ),
673+ done : make (chan struct {}),
674+ ReconnectOptions : reconnOpts ,
675+ ctx : connCtx ,
676+ cancel : cancel ,
677+ }
678+ return conn , nil
637679}
638680
639681type streamableClientConn struct {
640- url string
641- client * http.Client
642- incoming chan []byte
643- done chan struct {}
682+ url string
683+ client * http.Client
684+ incoming chan []byte
685+ done chan struct {}
686+ ReconnectOptions * StreamableReconnectOptions
644687
645688 closeOnce sync.Once
646689 closeErr error
690+ ctx context.Context
691+ cancel context.CancelFunc
647692
648693 mu sync.Mutex
649694 protocolVersion string
@@ -665,6 +710,12 @@ func (c *streamableClientConn) SessionID() string {
665710
666711// Read implements the [Connection] interface.
667712func (s * streamableClientConn ) Read (ctx context.Context ) (jsonrpc.Message , error ) {
713+ s .mu .Lock ()
714+ err := s .err
715+ s .mu .Unlock ()
716+ if err != nil {
717+ return nil , err
718+ }
668719 select {
669720 case <- ctx .Done ():
670721 return nil , ctx .Err ()
@@ -745,6 +796,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
745796 sessionID = resp .Header .Get (sessionIDHeader )
746797 switch ct := resp .Header .Get ("Content-Type" ); ct {
747798 case "text/event-stream" :
799+ // Section 2.1: The SSE stream is initiated after a POST.
748800 go s .handleSSE (resp )
749801 case "application/json" :
750802 // TODO: read the body and send to s.incoming (in a select that also receives from s.done).
@@ -757,34 +809,115 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
757809 return sessionID , nil
758810}
759811
760- func (s * streamableClientConn ) handleSSE (resp * http.Response ) {
812+ // handleSSE consumes the SSE stream in initialResp and, whenever a stream ends
813+ // without the client having closed the connection, reconnects and keeps consuming.
814+ // It returns when the client closes the connection or when reconnect fails.
815+ func (s * streamableClientConn ) handleSSE (initialResp * http.Response ) {
816+ resp := initialResp
817+ var lastEventID string
818+
819+ for {
820+ eventID , clientClosed := s .processStream (resp )
821+ lastEventID = eventID // NOTE(review): eventID is empty when this stream carried no event ID, which discards the resume point saved from an earlier stream.
822+
823+ // If the connection was closed by the client, we're done.
824+ if clientClosed {
825+ return
826+ }
827+
828+ // The stream was interrupted or ended by the server. Attempt to reconnect.
829+ newResp , err := s .reconnect (lastEventID )
830+ if err != nil {
831+ // reconnect gave up (retries exhausted, unresumable response, or client close).
832+ // Record the error so Read reports it, close the connection, and exit the goroutine.
833+ s .mu .Lock ()
834+ s .err = err
835+ s .mu .Unlock ()
836+ s .Close ()
837+ return
838+ }
839+
840+ // Reconnection was successful. Continue the loop with the new response.
841+ resp = newResp
842+ }
843+ }
844+
845+ // processStream reads events from a single response body, forwards each event's data to the
846+ // incoming channel, and closes the body on return. It returns the last non-empty event ID seen
847+ // (empty if none) and whether the client closed the connection. Scan errors are not returned.
848+ func (s * streamableClientConn ) processStream (resp * http.Response ) (lastEventID string , clientClosed bool ) {
761849 defer resp .Body .Close ()
762850
763- done := make (chan struct {})
764- go func () {
765- defer close (done )
766- for evt , err := range scanEvents (resp .Body ) {
851+ for evt , err := range scanEvents (resp .Body ) {
852+ if err != nil {
853+ return lastEventID , false
854+ }
855+
856+ if evt .ID != "" {
857+ lastEventID = evt .ID
858+ }
859+
860+ select {
861+ case s .incoming <- evt .Data :
862+ case <- s .done :
863+ // The connection was closed by the client; exit gracefully.
864+ return lastEventID , true
865+ }
866+ }
867+
868+ // The loop finished without an error, indicating the server closed the stream.
869+ // The caller will attempt to reconnect, so this is not a client-side close.
870+ return lastEventID , false
871+ }
872+
873+ // reconnect makes up to MaxRetries attempts to re-establish the SSE stream, waiting
874+ // calculateReconnectDelay before every attempt, including the first. It returns the first resumable
875+ // response, or an error if the client closes, a response is not resumable, or every attempt fails.
876+ func (s * streamableClientConn ) reconnect (lastEventID string ) (* http.Response , error ) {
877+ var finalErr error
878+
879+ for attempt := 0 ; attempt < s .ReconnectOptions .MaxRetries ; attempt ++ {
880+ select {
881+ case <- s .done :
882+ return nil , fmt .Errorf ("connection closed by client during reconnect" )
883+ case <- time .After (calculateReconnectDelay (s .ReconnectOptions , attempt )):
884+ resp , err := s .establishSSE (lastEventID )
767885 if err != nil {
768- // TODO: surface this error; possibly break the stream
769- return
886+ finalErr = err // Remember the most recent error and try again.
887+ continue
770888 }
771- select {
772- case <- s .done :
773- return
774- case s .incoming <- evt .Data :
889+
890+ if ! isResumable (resp ) {
891+ // The response is not a usable SSE stream: give up immediately instead of retrying.
892+ resp .Body .Close ()
893+ return nil , fmt .Errorf ("reconnection failed with unresumable status: %s" , resp .Status )
775894 }
895+
896+ return resp , nil
776897 }
777- }()
898+ }
899+ // If the loop completes, every attempt failed (or MaxRetries <= 0 and none was made).
900+ if finalErr != nil {
901+ return nil , fmt .Errorf ("connection failed after %d attempts: %w" , s .ReconnectOptions .MaxRetries , finalErr )
902+ }
903+ return nil , fmt .Errorf ("connection failed after %d attempts" , s .ReconnectOptions .MaxRetries )
904+ }
778905
779- select {
780- case <- s .done :
781- case <- done :
906+ // isResumable reports whether resp can be read as an SSE stream: its status is not 405 and its Content-Type contains text/event-stream.
907+ func isResumable (resp * http.Response ) bool {
908+ // Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint. NOTE(review): other non-2xx statuses are not rejected here.
909+ if resp .StatusCode == http .StatusMethodNotAllowed {
910+ return false
782911 }
912+
913+ return strings .Contains (resp .Header .Get ("Content-Type" ), "text/event-stream" )
783914}
784915
785916// Close implements the [Connection] interface.
786917func (s * streamableClientConn ) Close () error {
787918 s .closeOnce .Do (func () {
919+ // Cancel any hanging network requests.
920+ s .cancel ()
788921 close (s .done )
789922
790923 req , err := http .NewRequest (http .MethodDelete , s .url , nil )
@@ -803,3 +936,37 @@ func (s *streamableClientConn) Close() error {
803936 })
804937 return s .closeErr
805938}
939+
940+ // establishSSE issues a GET to s.url that accepts text/event-stream. The request is bound to s.ctx,
941+ // which Close cancels. It sends the session ID when one is known and, when lastEventID is non-empty,
942+ // a Last-Event-ID header so the server can resume a broken stream. The caller must close the response body.
943+ func (s * streamableClientConn ) establishSSE (lastEventID string ) (* http.Response , error ) {
944+ req , err := http .NewRequestWithContext (s .ctx , http .MethodGet , s .url , nil )
945+ if err != nil {
946+ return nil , err
947+ }
948+ s .mu .Lock ()
949+ if s ._sessionID != "" {
950+ req .Header .Set ("Mcp-Session-Id" , s ._sessionID ) // NOTE(review): postMessage reads this header via sessionIDHeader; presumably the same name, so consider reusing the constant.
951+ }
952+ s .mu .Unlock ()
953+ if lastEventID != "" {
954+ req .Header .Set ("Last-Event-ID" , lastEventID )
955+ }
956+ req .Header .Set ("Accept" , "text/event-stream" )
957+
958+ return s .client .Do (req )
959+ }
960+
961+ // calculateReconnectDelay returns the wait before the given 0-based attempt: the capped exponential backoff plus a random jitter in [0, backoff). The result lies in [backoff, 2*backoff) and can exceed maxDelay.
962+ func calculateReconnectDelay (opts * StreamableReconnectOptions , attempt int ) time.Duration {
963+ // Exponential backoff: initialDelay * growFactor^attempt. NOTE(review): for large attempts the product can overflow the float64-to-Duration conversion.
964+ backoffDuration := time .Duration (float64 (opts .initialDelay ) * math .Pow (opts .growFactor , float64 (attempt )))
965+ // Cap the backoff (not the final delay) at maxDelay.
966+ backoffDuration = min (backoffDuration , opts .maxDelay )
967+
968+ // Draw the jitter from [0, backoffDuration). NOTE(review): rand.N panics when its argument is <= 0, e.g. when initialDelay or maxDelay is zero; this needs a guard.
969+ jitter := rand .N (backoffDuration )
970+
971+ return backoffDuration + jitter
972+ }
0 commit comments