@@ -6,137 +6,97 @@ package auth
66
77import (
88 "context"
9- "fmt "
9+ "log "
1010 "net/http"
1111 "sync"
1212
1313 "github.com/modelcontextprotocol/go-sdk/internal/oauthex"
1414 "golang.org/x/oauth2"
15- "golang.org/x/oauth2/authhandler"
1615)
1716
17+ // An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization
18+ // is approved, or an error if not.
19+ type OAuthHandler func (context.Context , OAuthHandlerArgs ) (oauth2.TokenSource , error )
20+
21+ // OAuthHandlerArgs are arguments to an [OAuthHandler].
22+ type OAuthHandlerArgs struct {
23+ // The URL to fetch protected resource metadata, extracted from the WWW-Authenticate header.
24+ // Empty if not present or there was an error obtaining it.
25+ ResourceMetadataURL string
26+ }
27+
1828// HTTPTransport is an [http.RoundTripper] that follows the MCP
1929// OAuth protocol when it encounters a 401 Unauthorized response.
2030type HTTPTransport struct {
21- mu sync.Mutex
22- opts HTTPTransportConfig
31+ handler OAuthHandler
32+ mu sync.Mutex // protects opts.Base
33+ opts HTTPTransportOptions
2334}
2435
25- func NewHTTPTransport (opts * HTTPTransportConfig ) (* HTTPTransport , error ) {
36+ // NewHTTPTransport returns a new [*HTTPTransport].
37+ // The handler is invoked when an HTTP request results in a 401 Unauthorized status.
38+ // It is called only once per transport. Once a TokenSource is obtained, it is used
39+ // for the lifetime of the transport; subsequent 401s are not processed.
40+ func NewHTTPTransport (handler OAuthHandler , opts * HTTPTransportOptions ) (* HTTPTransport , error ) {
2641 t := & HTTPTransport {}
2742 if opts != nil {
2843 t .opts = * opts
2944 }
3045 if t .opts .Base == nil {
3146 t .opts .Base = http .DefaultTransport
3247 }
33- if t .opts .OAuthClient == nil {
34- t .opts .OAuthClient = http .DefaultClient
35- }
3648 return t , nil
3749}
3850
39- type HTTPTransportConfig struct {
40- // OAuthHandler is conducts the OAuth flow, using information obtained from the
41- // MCP server and the auth server that it refers to.
42- OAuthHandler func (context.Context , * oauthex.ProtectedResourceMetadata , * oauthex.AuthServerMeta ) (oauth2.TokenSource , error )
43- // Base is the [http.RoundTripper] to use initially, before credentials are obtained.
44- // (After the OAuth flow is completed, an [oauth2.Transport] with the resulting
45- // [oauth2.TokenSource] is used.)
51+ // HTTPTransportOptions are options to [NewHTTPTransport].
52+ type HTTPTransportOptions struct {
53+ // Base is the [http.RoundTripper] to use.
4654 // If nil, [http.DefaultTransport] is used.
4755 Base http.RoundTripper
48- // OAuth is used for HTTP requests that are part of the OAuth protocol,
49- // such as requests to the authorization server. If nil, http.DefaultClient
50- // is used.
51- OAuthClient * http.Client
5256}
5357
5458func (t * HTTPTransport ) RoundTrip (req * http.Request ) (* http.Response , error ) {
55- // baseRoundTrip calls RoundTrip on the base transport.
56- // If we should do OAuth to fix a 401 Unauthorized, it returns nil, nil.
57- baseRoundTrip := func () (* http.Response , error ) {
58- t .mu .Lock ()
59- base := t .opts .Base
60- _ , haveTokenSource := base .(* oauth2.Transport )
61- t .mu .Unlock ()
59+ t .mu .Lock ()
60+ base := t .opts .Base
61+ _ , haveTokenSource := base .(* oauth2.Transport )
62+ t .mu .Unlock ()
6263
63- resp , err := base .RoundTrip (req )
64- if err != nil {
65- return nil , err
66- }
67- if resp .StatusCode != http .StatusUnauthorized {
68- return resp , nil
69- }
70- if haveTokenSource {
71- // We failed to authorize even with a token source; give up.
72- return resp , nil
73- }
74- return nil , nil
64+ resp , err := base .RoundTrip (req )
65+ if err != nil {
66+ return nil , err
7567 }
76-
77- resp , err := baseRoundTrip ()
78- if resp != nil || err != nil {
79- return resp , err
68+ if resp .StatusCode != http .StatusUnauthorized {
69+ return resp , nil
70+ }
71+ if haveTokenSource {
72+ // We failed to authorize even with a token source; give up.
73+ return resp , nil
8074 }
81-
8275 // Try to authorize.
8376 t .mu .Lock ()
8477 // If we don't have a token source, get one by following the OAuth flow.
8578 // (We may have obtained one while t.mu was not held above.)
8679 if _ , ok := t .opts .Base .(* oauth2.Transport ); ! ok {
87- ts , err := t .doOauth (req .Context (), resp .Header )
80+ authHeaders := resp .Header [http .CanonicalHeaderKey ("WWW-Authenticate" )]
81+ ts , err := t .handler (req .Context (), OAuthHandlerArgs {
82+ ResourceMetadataURL : extractResourceMetadataURL (authHeaders ),
83+ })
8884 if err != nil {
8985 t .mu .Unlock ()
9086 return nil , err
9187 }
9288 t .opts .Base = & oauth2.Transport {Base : t .opts .Base , Source : ts }
9389 }
9490 t .mu .Unlock ()
95- // This will not return (nil, nil), because once we have a TokenSource we never lose it .
96- return baseRoundTrip ( )
91+ // Only one level of recursion, because we now have a token source .
92+ return t . RoundTrip ( req )
9793}
9894
99- // doOauth runs the OAuth 2.1 flow for MCP as described in
100- // https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization.
101- // It returns the resulting TokenSource.
102- func (t * HTTPTransport ) doOauth (ctx context.Context , header http.Header ) (oauth2.TokenSource , error ) {
103- prm , err := oauthex .GetProtectedResourceMetadataFromHeader (ctx , header , c )
104- if err != nil {
105- return nil , err
106- }
107- if len (prm .AuthorizationServers ) == 0 {
108- return nil , fmt .Errorf ("resource %s provided no authorization servers" , prm .Resource )
109- }
110- // TODO: try more than one?
111- authServer := prm .AuthorizationServers [0 ]
112- // TODO: which scopes to ask for? All of them?
113- scopes := prm .ScopesSupported
114- asm , err := oauthex .GetAuthServerMeta (ctx , authServer , c )
95+ func extractResourceMetadataURL (authHeaders []string ) string {
96+ cs , err := oauthex .ParseWWWAuthenticate (authHeaders )
11597 if err != nil {
116- return nil , err
117- }
118- // TODO: register the client with the auth server if not registered yet,
119- // or find another way to get the client ID and secret.
120-
121- // Get an access token from the auth server.
122- config := & oauth2.Config {
123- ClientID : "TODO: from registration" ,
124- ClientSecret : "TODO: from registration" ,
125- Endpoint : oauth2.Endpoint {
126- AuthURL : asm .AuthorizationEndpoint ,
127- TokenURL : asm .TokenEndpoint ,
128- // DeviceAuthURL: "",
129- // AuthStyle: "from auth meta?",
130- },
131- RedirectURL : "" , // ???
132- Scopes : scopes ,
133- }
134- v := oauth2 .GenerateVerifier ()
135- pkceParams := authhandler.PKCEParams {
136- ChallengeMethod : "S256" ,
137- Challenge : oauth2 .S256ChallengeFromVerifier (v ),
138- Verifier : v ,
98+ log .Printf ("parsing auth headers %q: %v" , authHeaders , err )
99+ return ""
139100 }
140- state := randText ()
141- return authhandler .TokenSourceWithPKCE (ctx , config , state , oauthHandler , & pkceParams ), nil
101+ return oauthex .ResourceMetadataURL (cs )
142102}
0 commit comments