@@ -20,7 +20,9 @@ import (
2020 "time"
2121
2222 "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
23+ "github.com/modelcontextprotocol/go-sdk/internal/util"
2324 "github.com/modelcontextprotocol/go-sdk/jsonrpc"
25+ "golang.org/x/oauth2/authhandler"
2426)
2527
2628const (
@@ -683,7 +685,7 @@ type StreamableReconnectOptions struct {
683685}
684686
685687// DefaultReconnectOptions provides sensible defaults for reconnect logic.
686- var DefaultReconnectOptions = & StreamableReconnectOptions {
688+ var DefaultReconnectOptions = StreamableReconnectOptions {
687689 MaxRetries : 5 ,
688690 growFactor : 1.5 ,
689691 initialDelay : 1 * time .Second ,
@@ -693,10 +695,18 @@ var DefaultReconnectOptions = &StreamableReconnectOptions{
693695// StreamableClientTransportOptions provides options for the
694696// [NewStreamableClientTransport] constructor.
695697type StreamableClientTransportOptions struct {
696- // HTTPClient is the client to use for making HTTP requests. If nil,
697- // http.DefaultClient is used.
698- HTTPClient * http.Client
699- ReconnectOptions * StreamableReconnectOptions
698+ // ReconnectOptions control the transport's behavior when it is disconnected
699+ // from the server.
700+ ReconnectOptions StreamableReconnectOptions
701+ // HTTPClient is the client to use for making unauthenticaed HTTP requests.
702+ // If nil, http.DefaultClient is used.
703+ // For authenticated requests, a shallow clone of the client will be used,
704+ // with a different transport. The cookie jar will not be copied.
705+ HTTPClient * http.Client
706+ // AuthHandler is a function that handles the user interaction part of the OAuth 2.1 flow.
707+ // It should prompt the user at the given URL and return the expected OAuth values.
708+ // See [authhandler.AuthorizationHandler] for more.
709+ AuthHandler authhandler.AuthorizationHandler
700710}
701711
702712// NewStreamableClientTransport returns a new client transport that connects to
@@ -706,6 +716,12 @@ func NewStreamableClientTransport(url string, opts *StreamableClientTransportOpt
706716 if opts != nil {
707717 t .opts = * opts
708718 }
719+ if t .opts .HTTPClient == nil {
720+ t .opts .HTTPClient = http .DefaultClient
721+ }
722+ if t .opts .ReconnectOptions == (StreamableReconnectOptions {}) {
723+ t .opts .ReconnectOptions = DefaultReconnectOptions
724+ }
709725 return t
710726}
711727
@@ -718,26 +734,17 @@ func NewStreamableClientTransport(url string, opts *StreamableClientTransportOpt
718734// When closed, the connection issues a DELETE request to terminate the logical
719735// session.
720736func (t * StreamableClientTransport ) Connect (ctx context.Context ) (Connection , error ) {
721- client := t .opts .HTTPClient
722- if client == nil {
723- client = http .DefaultClient
724- }
725- reconnOpts := t .opts .ReconnectOptions
726- if reconnOpts == nil {
727- reconnOpts = DefaultReconnectOptions
728- }
729737 // Create a new cancellable context that will manage the connection's lifecycle.
730738 // This is crucial for cleanly shutting down the background SSE listener by
731739 // cancelling its blocking network operations, which prevents hangs on exit.
732740 connCtx , cancel := context .WithCancel (context .Background ())
733741 conn := & streamableClientConn {
734- url : t .url ,
735- client : client ,
736- incoming : make (chan []byte , 100 ),
737- done : make (chan struct {}),
738- ReconnectOptions : reconnOpts ,
739- ctx : connCtx ,
740- cancel : cancel ,
742+ url : t .url ,
743+ opts : t .opts ,
744+ incoming : make (chan []byte , 100 ),
745+ done : make (chan struct {}),
746+ ctx : connCtx ,
747+ cancel : cancel ,
741748 }
742749 // Start the persistent SSE listener right away.
743750 // Section 2.2: The client MAY issue an HTTP GET to the MCP endpoint.
@@ -749,11 +756,11 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
749756}
750757
751758type streamableClientConn struct {
752- url string
753- client * http. Client
754- incoming chan [] byte
755- done chan struct {}
756- ReconnectOptions * StreamableReconnectOptions
759+ url string
760+ opts StreamableClientTransportOptions
761+ authClient * http. Client
762+ incoming chan [] byte
763+ done chan struct {}
757764
758765 closeOnce sync.Once
759766 closeErr error
@@ -833,9 +840,11 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
833840 return nil
834841}
835842
836- // postMessage POSTs msg to the server and reads the response.
837- // It returns the session ID from the response.
838- func (s * streamableClientConn ) postMessage (ctx context.Context , sessionID string , msg jsonrpc.Message ) (string , error ) {
843+ // postMessage makes a POST request to the server with msg as the body.
844+ // It returns the session ID.
845+ func (s * streamableClientConn ) postMessage (ctx context.Context , sessionID string , msg jsonrpc.Message ) (_ string , err error ) {
846+ defer util .Wrapf (& err , "MCP client posting message, session ID %q" , sessionID )
847+
839848 data , err := jsonrpc2 .EncodeMessage (msg )
840849 if err != nil {
841850 return "" , err
@@ -854,14 +863,46 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
854863 req .Header .Set ("Content-Type" , "application/json" )
855864 req .Header .Set ("Accept" , "application/json, text/event-stream" )
856865
857- resp , err := s .client .Do (req )
866+ // Use an HTTP client that does authentication, if there is one.
867+ // Otherwise, use the one provided by the user.
868+ client := s .authClient
869+ if client == nil {
870+ client = s .opts .HTTPClient
871+ }
872+ // TODO: Resource Indicators, as in
873+ // https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization#resource-parameter-implementation
874+ resp , err := client .Do (req )
858875 if err != nil {
859876 return "" , err
860877 }
878+ bodyClosed := false // avoid a second call to Close: undefined behavior (see [io.Closer])
879+ defer func () {
880+ if resp != nil && ! bodyClosed {
881+ resp .Body .Close ()
882+ }
883+ }()
884+
885+ if resp .StatusCode == http .StatusUnauthorized {
886+ if client == s .authClient {
887+ return "" , errors .New ("got StatusUnauthorized when already authorized" )
888+ }
889+ tokenSource , err := doOauth (ctx , resp .Header , s .opts .HTTPClient , s .opts .AuthHandler )
890+ if err != nil {
891+ return "" , err
892+ }
893+ s .authClient = newAuthClient (s .opts .HTTPClient , tokenSource )
894+ resp .Body .Close () // because we're about to replace resp
895+ resp , err = s .authClient .Do (req )
896+ if err != nil {
897+ return "" , err
898+ }
899+ if resp .StatusCode == http .StatusUnauthorized {
900+ return "" , errors .New ("got StatusUnauthorized just after authorization" )
901+ }
902+ }
861903
862904 if resp .StatusCode < 200 || resp .StatusCode >= 300 {
863905 // TODO: do a best effort read of the body here, and format it in the error.
864- resp .Body .Close ()
865906 return "" , fmt .Errorf ("broken session: %v" , resp .Status )
866907 }
867908
@@ -883,7 +924,6 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
883924 }
884925 return sessionID , nil
885926 default :
886- resp .Body .Close ()
887927 return "" , fmt .Errorf ("unsupported content type %q" , ct )
888928 }
889929 return sessionID , nil
@@ -960,12 +1000,13 @@ func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID s
9601000// an error if all retries are exhausted.
9611001func (s * streamableClientConn ) reconnect (lastEventID string ) (* http.Response , error ) {
9621002 var finalErr error
1003+ maxRetries := s .opts .ReconnectOptions .MaxRetries
9631004
964- for attempt := 0 ; attempt < s . ReconnectOptions . MaxRetries ; attempt ++ {
1005+ for attempt := 0 ; attempt < maxRetries ; attempt ++ {
9651006 select {
9661007 case <- s .done :
9671008 return nil , fmt .Errorf ("connection closed by client during reconnect" )
968- case <- time .After (calculateReconnectDelay (s .ReconnectOptions , attempt )):
1009+ case <- time .After (calculateReconnectDelay (& s . opts .ReconnectOptions , attempt )):
9691010 resp , err := s .establishSSE (lastEventID )
9701011 if err != nil {
9711012 finalErr = err // Store the error and try again.
@@ -983,9 +1024,9 @@ func (s *streamableClientConn) reconnect(lastEventID string) (*http.Response, er
9831024 }
9841025 // If the loop completes, all retries have failed.
9851026 if finalErr != nil {
986- return nil , fmt .Errorf ("connection failed after %d attempts: %w" , s . ReconnectOptions . MaxRetries , finalErr )
1027+ return nil , fmt .Errorf ("connection failed after %d attempts: %w" , maxRetries , finalErr )
9871028 }
988- return nil , fmt .Errorf ("connection failed after %d attempts" , s . ReconnectOptions . MaxRetries )
1029+ return nil , fmt .Errorf ("connection failed after %d attempts" , maxRetries )
9891030}
9901031
9911032// isResumable checks if an HTTP response indicates a valid SSE stream that can be processed.
@@ -1014,7 +1055,7 @@ func (s *streamableClientConn) Close() error {
10141055 req .Header .Set (protocolVersionHeader , s .protocolVersion )
10151056 }
10161057 req .Header .Set (sessionIDHeader , s ._sessionID )
1017- if _ , err := s .client .Do (req ); err != nil {
1058+ if _ , err := s .opts . HTTPClient .Do (req ); err != nil {
10181059 s .closeErr = err
10191060 }
10201061 }
@@ -1040,7 +1081,7 @@ func (s *streamableClientConn) establishSSE(lastEventID string) (*http.Response,
10401081 }
10411082 req .Header .Set ("Accept" , "text/event-stream" )
10421083
1043- return s .client .Do (req )
1084+ return s .opts . HTTPClient .Do (req )
10441085}
10451086
10461087// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
0 commit comments