@@ -9,11 +9,14 @@ import (
99 "context"
1010 "fmt"
1111 "io"
12+ "math"
13+ "math/rand/v2"
1214 "net/http"
1315 "strconv"
1416 "strings"
1517 "sync"
1618 "sync/atomic"
19+ "time"
1720
1821 "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
1922 "github.com/modelcontextprotocol/go-sdk/jsonrpc"
@@ -594,12 +597,38 @@ type StreamableClientTransport struct {
594597 opts StreamableClientTransportOptions
595598}
596599
600+ // StreamableClientReconnectionOptions defines parameters for client reconnection attempts.
601+ type StreamableClientReconnectionOptions struct {
602+ // InitialDelay is the base delay for the first reconnection attempt
603+ InitialDelay time.Duration // default: 1 second
604+
605+ // MaxDelay caps the backoff delay, preventing it from growing indefinitely.
606+ MaxDelay time.Duration // default: 30 seconds
607+
608+ // GrowFactor is the multiplicative factor by which the delay increases after each attempt.
609+ // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
610+ GrowFactor float64 // default: 1.5
611+
612+ // MaxRetries is the maximum number of times to attempt reconnection before giving up.
613+ // A value of 0 or less means never retry.
614+ MaxRetries int // default: 5
615+ }
616+
617+ // DefaultReconnectionOptions provides sensible defaults for reconnection logic.
618+ var DefaultReconnectionOptions = & StreamableClientReconnectionOptions {
619+ InitialDelay : 1 * time .Second ,
620+ MaxDelay : 30 * time .Second ,
621+ GrowFactor : 1.5 ,
622+ MaxRetries : 5 ,
623+ }
624+
597625// StreamableClientTransportOptions provides options for the
598626// [NewStreamableClientTransport] constructor.
599627type StreamableClientTransportOptions struct {
600628 // HTTPClient is the client to use for making HTTP requests. If nil,
601629 // http.DefaultClient is used.
602- HTTPClient * http.Client
630+ HTTPClient * http.Client
631+ ReconnectionOptions * StreamableClientReconnectionOptions
603632}
604633
605634// NewStreamableClientTransport returns a new client transport that connects to
@@ -625,19 +654,26 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
625654 if client == nil {
626655 client = http .DefaultClient
627656 }
628- return & streamableClientConn {
629- url : t .url ,
630- client : client ,
631- incoming : make (chan []byte , 100 ),
632- done : make (chan struct {}),
633- }, nil
657+ reconnOpts := t .opts .ReconnectionOptions
658+ if reconnOpts == nil {
659+ reconnOpts = DefaultReconnectionOptions
660+ }
661+ conn := & streamableClientConn {
662+ url : t .url ,
663+ client : client ,
664+ incoming : make (chan []byte , 100 ),
665+ done : make (chan struct {}),
666+ reconnectionOptions : reconnOpts ,
667+ }
668+ return conn , nil
634669}
635670
636671type streamableClientConn struct {
637- url string
638- client * http.Client
639- incoming chan []byte
640- done chan struct {}
672+ url string
673+ client * http.Client
674+ incoming chan []byte
675+ done chan struct {}
676+ reconnectionOptions * StreamableClientReconnectionOptions
641677
642678 closeOnce sync.Once
643679 closeErr error
@@ -704,11 +740,19 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
704740 if sessionID == "" {
705741 // locked
706742 s ._sessionID = gotSessionID
743+ // With the session now established, launch the persistent background listener for server-pushed events.
744+ go s .establishSSE (context .Background (), & startSSEOptions {})
707745 }
708746
709747 return nil
710748}
711749
750+ // startSSEOptions holds parameters for initiating an SSE stream.
751+ type startSSEOptions struct {
752+ lastEventID string
753+ attempt int
754+ }
755+
712756func (s * streamableClientConn ) postMessage (ctx context.Context , sessionID string , msg jsonrpc.Message ) (string , error ) {
713757 data , err := jsonrpc2 .EncodeMessage (msg )
714758 if err != nil {
@@ -742,7 +786,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
742786 sessionID = resp .Header .Get (sessionIDHeader )
743787 switch ct := resp .Header .Get ("Content-Type" ); ct {
744788 case "text/event-stream" :
745- go s .handleSSE (resp )
789+ go s .handleSSE (context . Background (), resp , & startSSEOptions {} )
746790 case "application/json" :
747791 // TODO: read the body and send to s.incoming (in a select that also receives from s.done).
748792 resp .Body .Close ()
@@ -754,17 +798,20 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
754798 return sessionID , nil
755799}
756800
757- func (s * streamableClientConn ) handleSSE (resp * http.Response ) {
801+ // handleSSE processes an incoming Server-Sent Events stream, pushing received messages to the client's channel.
802+ // If the stream breaks, it uses the last received event ID to automatically trigger the reconnection logic.
803+ func (s * streamableClientConn ) handleSSE (ctx context.Context , resp * http.Response , opts * startSSEOptions ) {
758804 defer resp .Body .Close ()
759805
760806 done := make (chan struct {})
761807 go func () {
762808 defer close (done )
763809 for evt , err := range scanEvents (resp .Body ) {
764810 if err != nil {
765- // TODO: surface this error; possibly break the stream
811+ s . scheduleReconnection ( ctx , opts )
766812 return
767813 }
814+ opts .lastEventID = evt .id
768815 select {
769816 case <- s .done :
770817 return
@@ -800,3 +847,85 @@ func (s *streamableClientConn) Close() error {
800847 })
801848 return s .closeErr
802849}
850+
851+ // establishSSE creates and manages the persistent SSE listening stream.
852+ // It is used for both the initial connection and all subsequent reconnection attempts,
853+ // using the Last-Event-ID header to resume a broken stream where it left off.
854+ // On a successful response, it delegates to handleSSE to process events;
855+ // on failure, it triggers the client's reconnection logic.
856+ func (s * streamableClientConn ) establishSSE (ctx context.Context , opts * startSSEOptions ) {
857+ select {
858+ case <- s .done :
859+ return
860+ case <- ctx .Done ():
861+ return
862+ default :
863+ }
864+
865+ req , err := http .NewRequestWithContext (ctx , http .MethodGet , s .url , nil )
866+ if err != nil {
867+ return
868+ }
869+ s .mu .Lock ()
870+ if s ._sessionID != "" {
871+ req .Header .Set ("Mcp-Session-Id" , s ._sessionID )
872+ }
873+ s .mu .Unlock ()
874+ if opts .lastEventID != "" {
875+ req .Header .Set ("Last-Event-ID" , opts .lastEventID )
876+ }
877+ req .Header .Set ("Accept" , "text/event-stream" )
878+
879+ resp , err := s .client .Do (req )
880+ if err != nil {
881+ // On connection error, schedule a retry.
882+ s .scheduleReconnection (ctx , opts )
883+ return
884+ }
885+
886+ // Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint.
887+ if resp .StatusCode == http .StatusMethodNotAllowed {
888+ resp .Body .Close ()
889+ return
890+ }
891+
892+ if ! strings .Contains (resp .Header .Get ("Content-Type" ), "text/event-stream" ) {
893+ resp .Body .Close ()
894+ s .scheduleReconnection (ctx , opts )
895+ return
896+ }
897+
898+ s .handleSSE (ctx , resp , opts )
899+ }
900+
901+ // scheduleReconnection schedules the next SSE reconnection attempt after a delay.
902+ func (s * streamableClientConn ) scheduleReconnection (ctx context.Context , opts * startSSEOptions ) {
903+ reconnOpts := s .reconnectionOptions
904+ if opts .attempt >= reconnOpts .MaxRetries {
905+ return
906+ }
907+
908+ delay := calculateReconnectionDelay (reconnOpts , opts .attempt )
909+
910+ select {
911+ case <- s .done :
912+ return
913+ case <- time .After (delay ):
914+ opts .attempt ++
915+ s .establishSSE (ctx , opts )
916+ }
917+ }
918+
919+ // calculateReconnectionDelay calculates a delay using exponential backoff with full jitter.
920+ func calculateReconnectionDelay (opts * StreamableClientReconnectionOptions , attempt int ) time.Duration {
921+ // Calculate the exponential backoff using the grow factor.
922+ backoffDuration := time .Duration (float64 (opts .InitialDelay ) * math .Pow (opts .GrowFactor , float64 (attempt )))
923+
924+ // Cap the backoffDuration at maxDelay.
925+ backoffDuration = min (backoffDuration , opts .MaxDelay )
926+
927+ // Use a full jitter using backoffDuration
928+ jitter := rand .N (backoffDuration )
929+
930+ return backoffDuration + jitter
931+ }
0 commit comments