Skip to content

Commit d87e349

Browse files
committed
mcp: add client-side OAuth flow (preliminary)
This is a preliminary implementation of OAuth 2.1 for the client. When a StreamableClientTransport encounters a 401 Unauthorized response from the server, it initiates the OAuth flow described in thec authorization section of the MCP spec (https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization). On success, the transport obtains an access token which it passes to all subsequent requests. Much remains to be done here: - Dynamic client registration is not implemented. Since it is optional, we also need another way of supplying the client ID and secret to this code. - Resource Indicators, as described in section 2.5.1 of the MCP spec. - There is no way for the user to provide a redirect URL. - All of this is unexported, so it is available only to our own StreamingClientTransport. We should add API so people can use it with their own transports. - And, of course, tests. We should test against fake implementations but also, if we can find any, real reference implementations.
1 parent f50dbe3 commit d87e349

File tree

5 files changed

+155
-40
lines changed

5 files changed

+155
-40
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ go 1.23.0
55
require (
66
github.com/google/go-cmp v0.7.0
77
github.com/yosida95/uritemplate/v3 v3.0.2
8+
golang.org/x/oauth2 v0.30.0
89
golang.org/x/tools v0.34.0
910
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,7 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
22
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
33
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
44
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
5+
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
6+
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
57
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
68
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=

internal/oauthex/resource_meta.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,11 @@ func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string,
145145
// If there is no URL in the request, it returns nil, nil.
146146
func GetProtectedResourceMetadataFromHeader(ctx context.Context, header http.Header, c *http.Client) (_ *ProtectedResourceMetadata, err error) {
147147
defer util.Wrapf(&err, "GetProtectedResourceMetadataFromHeader")
148-
headers := header[http.CanonicalHeaderKey("WWW-Authenticate")]
149-
if len(headers) == 0 {
148+
authHeaders := header[http.CanonicalHeaderKey("WWW-Authenticate")]
149+
if len(authHeaders) == 0 {
150150
return nil, nil
151151
}
152-
cs, err := parseWWWAuthenticate(headers)
152+
cs, err := parseWWWAuthenticate(authHeaders)
153153
if err != nil {
154154
return nil, err
155155
}

mcp/auth.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Copyright 2025 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package mcp
6+
7+
import (
8+
"context"
9+
"fmt"
10+
"net/http"
11+
12+
"github.com/modelcontextprotocol/go-sdk/internal/oauthex"
13+
"golang.org/x/oauth2"
14+
"golang.org/x/oauth2/authhandler"
15+
)
16+
17+
// newAuthClient returns a shallow copy of c with its tranport replaced by one that
18+
// authorizes with the token source.
19+
func newAuthClient(c *http.Client, ts oauth2.TokenSource) *http.Client {
20+
c2 := *c
21+
c2.Transport = &oauth2.Transport{
22+
Base: c.Transport,
23+
Source: ts,
24+
}
25+
return &c2
26+
}
27+
28+
// doOauth runs the OAuth 2.1 flow for MCP as described in
29+
// https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization.
30+
// It returns the resulting TokenSource.
31+
func doOauth(ctx context.Context, header http.Header, c *http.Client, oauthHandler authhandler.AuthorizationHandler) (oauth2.TokenSource, error) {
32+
prm, err := oauthex.GetProtectedResourceMetadataFromHeader(ctx, header, c)
33+
if err != nil {
34+
return nil, err
35+
}
36+
if len(prm.AuthorizationServers) == 0 {
37+
return nil, fmt.Errorf("resource %s provided no authorization servers", prm.Resource)
38+
}
39+
// TODO: try more than one?
40+
authServer := prm.AuthorizationServers[0]
41+
// TODO: which scopes to ask for? All of them?
42+
scopes := prm.ScopesSupported
43+
asm, err := oauthex.GetAuthServerMeta(ctx, authServer, c)
44+
if err != nil {
45+
return nil, err
46+
}
47+
// TODO: register the client with the auth server if not registered yet,
48+
// or find another way to get the client ID and secret.
49+
50+
// Get an access token from the auth server.
51+
config := &oauth2.Config{
52+
ClientID: "TODO: from registration",
53+
ClientSecret: "TODO: from registration",
54+
Endpoint: oauth2.Endpoint{
55+
AuthURL: asm.AuthorizationEndpoint,
56+
TokenURL: asm.TokenEndpoint,
57+
// DeviceAuthURL: "",
58+
// AuthStyle: "from auth meta?",
59+
},
60+
RedirectURL: "", // ???
61+
Scopes: scopes,
62+
}
63+
v := oauth2.GenerateVerifier()
64+
pkceParams := authhandler.PKCEParams{
65+
ChallengeMethod: "S256",
66+
Challenge: oauth2.S256ChallengeFromVerifier(v),
67+
Verifier: v,
68+
}
69+
state := randText()
70+
return authhandler.TokenSourceWithPKCE(ctx, config, state, oauthHandler, &pkceParams), nil
71+
}

mcp/streamable.go

Lines changed: 78 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ import (
2020
"time"
2121

2222
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
23+
"github.com/modelcontextprotocol/go-sdk/internal/util"
2324
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
25+
"golang.org/x/oauth2/authhandler"
2426
)
2527

2628
const (
@@ -683,7 +685,7 @@ type StreamableReconnectOptions struct {
683685
}
684686

685687
// DefaultReconnectOptions provides sensible defaults for reconnect logic.
686-
var DefaultReconnectOptions = &StreamableReconnectOptions{
688+
var DefaultReconnectOptions = StreamableReconnectOptions{
687689
MaxRetries: 5,
688690
growFactor: 1.5,
689691
initialDelay: 1 * time.Second,
@@ -693,10 +695,18 @@ var DefaultReconnectOptions = &StreamableReconnectOptions{
693695
// StreamableClientTransportOptions provides options for the
694696
// [NewStreamableClientTransport] constructor.
695697
type StreamableClientTransportOptions struct {
696-
// HTTPClient is the client to use for making HTTP requests. If nil,
697-
// http.DefaultClient is used.
698-
HTTPClient *http.Client
699-
ReconnectOptions *StreamableReconnectOptions
698+
// ReconnectOptions control the transport's behavior when it is disconnected
699+
// from the server.
700+
ReconnectOptions StreamableReconnectOptions
701+
// HTTPClient is the client to use for making unauthenticaed HTTP requests.
702+
// If nil, http.DefaultClient is used.
703+
// For authenticated requests, a shallow clone of the client will be used,
704+
// with a different transport. The cookie jar will not be copied.
705+
HTTPClient *http.Client
706+
// AuthHandler is a function that handles the user interaction part of the OAuth 2.1 flow.
707+
// It should prompt the user at the given URL and return the expected OAuth values.
708+
// See [authhandler.AuthorizationHandler] for more.
709+
AuthHandler authhandler.AuthorizationHandler
700710
}
701711

702712
// NewStreamableClientTransport returns a new client transport that connects to
@@ -706,6 +716,12 @@ func NewStreamableClientTransport(url string, opts *StreamableClientTransportOpt
706716
if opts != nil {
707717
t.opts = *opts
708718
}
719+
if t.opts.HTTPClient == nil {
720+
t.opts.HTTPClient = http.DefaultClient
721+
}
722+
if t.opts.ReconnectOptions == (StreamableReconnectOptions{}) {
723+
t.opts.ReconnectOptions = DefaultReconnectOptions
724+
}
709725
return t
710726
}
711727

@@ -718,26 +734,17 @@ func NewStreamableClientTransport(url string, opts *StreamableClientTransportOpt
718734
// When closed, the connection issues a DELETE request to terminate the logical
719735
// session.
720736
func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, error) {
721-
client := t.opts.HTTPClient
722-
if client == nil {
723-
client = http.DefaultClient
724-
}
725-
reconnOpts := t.opts.ReconnectOptions
726-
if reconnOpts == nil {
727-
reconnOpts = DefaultReconnectOptions
728-
}
729737
// Create a new cancellable context that will manage the connection's lifecycle.
730738
// This is crucial for cleanly shutting down the background SSE listener by
731739
// cancelling its blocking network operations, which prevents hangs on exit.
732740
connCtx, cancel := context.WithCancel(context.Background())
733741
conn := &streamableClientConn{
734-
url: t.url,
735-
client: client,
736-
incoming: make(chan []byte, 100),
737-
done: make(chan struct{}),
738-
ReconnectOptions: reconnOpts,
739-
ctx: connCtx,
740-
cancel: cancel,
742+
url: t.url,
743+
opts: t.opts,
744+
incoming: make(chan []byte, 100),
745+
done: make(chan struct{}),
746+
ctx: connCtx,
747+
cancel: cancel,
741748
}
742749
// Start the persistent SSE listener right away.
743750
// Section 2.2: The client MAY issue an HTTP GET to the MCP endpoint.
@@ -749,11 +756,11 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
749756
}
750757

751758
type streamableClientConn struct {
752-
url string
753-
client *http.Client
754-
incoming chan []byte
755-
done chan struct{}
756-
ReconnectOptions *StreamableReconnectOptions
759+
url string
760+
opts StreamableClientTransportOptions
761+
authClient *http.Client
762+
incoming chan []byte
763+
done chan struct{}
757764

758765
closeOnce sync.Once
759766
closeErr error
@@ -833,9 +840,11 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
833840
return nil
834841
}
835842

836-
// postMessage POSTs msg to the server and reads the response.
837-
// It returns the session ID from the response.
838-
func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) {
843+
// postMessage makes a POST request to the server with msg as the body.
844+
// It returns the session ID.
845+
func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (_ string, err error) {
846+
defer util.Wrapf(&err, "MCP client posting message, session ID %q", sessionID)
847+
839848
data, err := jsonrpc2.EncodeMessage(msg)
840849
if err != nil {
841850
return "", err
@@ -854,14 +863,46 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
854863
req.Header.Set("Content-Type", "application/json")
855864
req.Header.Set("Accept", "application/json, text/event-stream")
856865

857-
resp, err := s.client.Do(req)
866+
// Use an HTTP client that does authentication, if there is one.
867+
// Otherwise, use the one provided by the user.
868+
client := s.authClient
869+
if client == nil {
870+
client = s.opts.HTTPClient
871+
}
872+
// TODO: Resource Indicators, as in
873+
// https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization#resource-parameter-implementation
874+
resp, err := client.Do(req)
858875
if err != nil {
859876
return "", err
860877
}
878+
bodyClosed := false // avoid a second call to Close: undefined behavior (see [io.Closer])
879+
defer func() {
880+
if resp != nil && !bodyClosed {
881+
resp.Body.Close()
882+
}
883+
}()
884+
885+
if resp.StatusCode == http.StatusUnauthorized {
886+
if client == s.authClient {
887+
return "", errors.New("got StatusUnauthorized when already authorized")
888+
}
889+
tokenSource, err := doOauth(ctx, resp.Header, s.opts.HTTPClient, s.opts.AuthHandler)
890+
if err != nil {
891+
return "", err
892+
}
893+
s.authClient = newAuthClient(s.opts.HTTPClient, tokenSource)
894+
resp.Body.Close() // because we're about to replace resp
895+
resp, err = s.authClient.Do(req)
896+
if err != nil {
897+
return "", err
898+
}
899+
if resp.StatusCode == http.StatusUnauthorized {
900+
return "", errors.New("got StatusUnauthorized just after authorization")
901+
}
902+
}
861903

862904
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
863905
// TODO: do a best effort read of the body here, and format it in the error.
864-
resp.Body.Close()
865906
return "", fmt.Errorf("broken session: %v", resp.Status)
866907
}
867908

@@ -883,7 +924,6 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
883924
}
884925
return sessionID, nil
885926
default:
886-
resp.Body.Close()
887927
return "", fmt.Errorf("unsupported content type %q", ct)
888928
}
889929
return sessionID, nil
@@ -960,12 +1000,13 @@ func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID s
9601000
// an error if all retries are exhausted.
9611001
func (s *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) {
9621002
var finalErr error
1003+
maxRetries := s.opts.ReconnectOptions.MaxRetries
9631004

964-
for attempt := 0; attempt < s.ReconnectOptions.MaxRetries; attempt++ {
1005+
for attempt := 0; attempt < maxRetries; attempt++ {
9651006
select {
9661007
case <-s.done:
9671008
return nil, fmt.Errorf("connection closed by client during reconnect")
968-
case <-time.After(calculateReconnectDelay(s.ReconnectOptions, attempt)):
1009+
case <-time.After(calculateReconnectDelay(&s.opts.ReconnectOptions, attempt)):
9691010
resp, err := s.establishSSE(lastEventID)
9701011
if err != nil {
9711012
finalErr = err // Store the error and try again.
@@ -983,9 +1024,9 @@ func (s *streamableClientConn) reconnect(lastEventID string) (*http.Response, er
9831024
}
9841025
// If the loop completes, all retries have failed.
9851026
if finalErr != nil {
986-
return nil, fmt.Errorf("connection failed after %d attempts: %w", s.ReconnectOptions.MaxRetries, finalErr)
1027+
return nil, fmt.Errorf("connection failed after %d attempts: %w", maxRetries, finalErr)
9871028
}
988-
return nil, fmt.Errorf("connection failed after %d attempts", s.ReconnectOptions.MaxRetries)
1029+
return nil, fmt.Errorf("connection failed after %d attempts", maxRetries)
9891030
}
9901031

9911032
// isResumable checks if an HTTP response indicates a valid SSE stream that can be processed.
@@ -1014,7 +1055,7 @@ func (s *streamableClientConn) Close() error {
10141055
req.Header.Set(protocolVersionHeader, s.protocolVersion)
10151056
}
10161057
req.Header.Set(sessionIDHeader, s._sessionID)
1017-
if _, err := s.client.Do(req); err != nil {
1058+
if _, err := s.opts.HTTPClient.Do(req); err != nil {
10181059
s.closeErr = err
10191060
}
10201061
}
@@ -1040,7 +1081,7 @@ func (s *streamableClientConn) establishSSE(lastEventID string) (*http.Response,
10401081
}
10411082
req.Header.Set("Accept", "text/event-stream")
10421083

1043-
return s.client.Do(req)
1084+
return s.opts.HTTPClient.Do(req)
10441085
}
10451086

10461087
// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.

0 commit comments

Comments
 (0)