Skip to content

Commit 9e359fc

Browse files
committed
revised OAuthHandler
1 parent a694fd3 commit 9e359fc

File tree

2 files changed

+53
-93
lines changed

2 files changed

+53
-93
lines changed

auth/client.go

Lines changed: 47 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -6,137 +6,97 @@ package auth
66

77
import (
88
"context"
9-
"fmt"
9+
"log"
1010
"net/http"
1111
"sync"
1212

1313
"github.com/modelcontextprotocol/go-sdk/internal/oauthex"
1414
"golang.org/x/oauth2"
15-
"golang.org/x/oauth2/authhandler"
1615
)
1716

17+
// An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization
18+
// is approved, or an error if not.
19+
type OAuthHandler func(context.Context, OAuthHandlerArgs) (oauth2.TokenSource, error)
20+
21+
// OAuthHandlerArgs are arguments to an [OAuthHandler].
22+
type OAuthHandlerArgs struct {
23+
// The URL to fetch protected resource metadata, extracted from the WWW-Authenticate header.
24+
// Empty if not present or there was an error obtaining it.
25+
ResourceMetadataURL string
26+
}
27+
1828
// HTTPTransport is an [http.RoundTripper] that follows the MCP
1929
// OAuth protocol when it encounters a 401 Unauthorized response.
2030
type HTTPTransport struct {
21-
mu sync.Mutex
22-
opts HTTPTransportConfig
31+
handler OAuthHandler
32+
mu sync.Mutex // protects opts.Base
33+
opts HTTPTransportOptions
2334
}
2435

25-
func NewHTTPTransport(opts *HTTPTransportConfig) (*HTTPTransport, error) {
36+
// NewHTTPTransport returns a new [*HTTPTransport].
37+
// The handler is invoked when an HTTP request results in a 401 Unauthorized status.
38+
// It is called only once per transport. Once a TokenSource is obtained, it is used
39+
// for the lifetime of the transport; subsequent 401s are not processed.
40+
func NewHTTPTransport(handler OAuthHandler, opts *HTTPTransportOptions) (*HTTPTransport, error) {
2641
t := &HTTPTransport{}
2742
if opts != nil {
2843
t.opts = *opts
2944
}
3045
if t.opts.Base == nil {
3146
t.opts.Base = http.DefaultTransport
3247
}
33-
if t.opts.OAuthClient == nil {
34-
t.opts.OAuthClient = http.DefaultClient
35-
}
3648
return t, nil
3749
}
3850

39-
type HTTPTransportConfig struct {
40-
// OAuthHandler is conducts the OAuth flow, using information obtained from the
41-
// MCP server and the auth server that it refers to.
42-
OAuthHandler func(context.Context, *oauthex.ProtectedResourceMetadata, *oauthex.AuthServerMeta) (oauth2.TokenSource, error)
43-
// Base is the [http.RoundTripper] to use initially, before credentials are obtained.
44-
// (After the OAuth flow is completed, an [oauth2.Transport] with the resulting
45-
// [oauth2.TokenSource] is used.)
51+
// HTTPTransportOptions are options to [NewHTTPTransport].
52+
type HTTPTransportOptions struct {
53+
// Base is the [http.RoundTripper] to use.
4654
// If nil, [http.DefaultTransport] is used.
4755
Base http.RoundTripper
48-
// OAuth is used for HTTP requests that are part of the OAuth protocol,
49-
// such as requests to the authorization server. If nil, http.DefaultClient
50-
// is used.
51-
OAuthClient *http.Client
5256
}
5357

5458
func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
55-
// baseRoundTrip calls RoundTrip on the base transport.
56-
// If we should do OAuth to fix a 401 Unauthorized, it returns nil, nil.
57-
baseRoundTrip := func() (*http.Response, error) {
58-
t.mu.Lock()
59-
base := t.opts.Base
60-
_, haveTokenSource := base.(*oauth2.Transport)
61-
t.mu.Unlock()
59+
t.mu.Lock()
60+
base := t.opts.Base
61+
_, haveTokenSource := base.(*oauth2.Transport)
62+
t.mu.Unlock()
6263

63-
resp, err := base.RoundTrip(req)
64-
if err != nil {
65-
return nil, err
66-
}
67-
if resp.StatusCode != http.StatusUnauthorized {
68-
return resp, nil
69-
}
70-
if haveTokenSource {
71-
// We failed to authorize even with a token source; give up.
72-
return resp, nil
73-
}
74-
return nil, nil
64+
resp, err := base.RoundTrip(req)
65+
if err != nil {
66+
return nil, err
7567
}
76-
77-
resp, err := baseRoundTrip()
78-
if resp != nil || err != nil {
79-
return resp, err
68+
if resp.StatusCode != http.StatusUnauthorized {
69+
return resp, nil
70+
}
71+
if haveTokenSource {
72+
// We failed to authorize even with a token source; give up.
73+
return resp, nil
8074
}
81-
8275
// Try to authorize.
8376
t.mu.Lock()
8477
// If we don't have a token source, get one by following the OAuth flow.
8578
// (We may have obtained one while t.mu was not held above.)
8679
if _, ok := t.opts.Base.(*oauth2.Transport); !ok {
87-
ts, err := t.doOauth(req.Context(), resp.Header)
80+
authHeaders := resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]
81+
ts, err := t.handler(req.Context(), OAuthHandlerArgs{
82+
ResourceMetadataURL: extractResourceMetadataURL(authHeaders),
83+
})
8884
if err != nil {
8985
t.mu.Unlock()
9086
return nil, err
9187
}
9288
t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts}
9389
}
9490
t.mu.Unlock()
95-
// This will not return (nil, nil), because once we have a TokenSource we never lose it.
96-
return baseRoundTrip()
91+
// Only one level of recursion, because we now have a token source.
92+
return t.RoundTrip(req)
9793
}
9894

99-
// doOauth runs the OAuth 2.1 flow for MCP as described in
100-
// https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization.
101-
// It returns the resulting TokenSource.
102-
func (t *HTTPTransport) doOauth(ctx context.Context, header http.Header) (oauth2.TokenSource, error) {
103-
prm, err := oauthex.GetProtectedResourceMetadataFromHeader(ctx, header, c)
104-
if err != nil {
105-
return nil, err
106-
}
107-
if len(prm.AuthorizationServers) == 0 {
108-
return nil, fmt.Errorf("resource %s provided no authorization servers", prm.Resource)
109-
}
110-
// TODO: try more than one?
111-
authServer := prm.AuthorizationServers[0]
112-
// TODO: which scopes to ask for? All of them?
113-
scopes := prm.ScopesSupported
114-
asm, err := oauthex.GetAuthServerMeta(ctx, authServer, c)
95+
func extractResourceMetadataURL(authHeaders []string) string {
96+
cs, err := oauthex.ParseWWWAuthenticate(authHeaders)
11597
if err != nil {
116-
return nil, err
117-
}
118-
// TODO: register the client with the auth server if not registered yet,
119-
// or find another way to get the client ID and secret.
120-
121-
// Get an access token from the auth server.
122-
config := &oauth2.Config{
123-
ClientID: "TODO: from registration",
124-
ClientSecret: "TODO: from registration",
125-
Endpoint: oauth2.Endpoint{
126-
AuthURL: asm.AuthorizationEndpoint,
127-
TokenURL: asm.TokenEndpoint,
128-
// DeviceAuthURL: "",
129-
// AuthStyle: "from auth meta?",
130-
},
131-
RedirectURL: "", // ???
132-
Scopes: scopes,
133-
}
134-
v := oauth2.GenerateVerifier()
135-
pkceParams := authhandler.PKCEParams{
136-
ChallengeMethod: "S256",
137-
Challenge: oauth2.S256ChallengeFromVerifier(v),
138-
Verifier: v,
98+
log.Printf("parsing auth headers %q: %v", authHeaders, err)
99+
return ""
139100
}
140-
state := randText()
141-
return authhandler.TokenSourceWithPKCE(ctx, config, state, oauthHandler, &pkceParams), nil
101+
return oauthex.ResourceMetadataURL(cs)
142102
}

internal/oauthex/resource_meta.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,11 @@ func GetProtectedResourceMetadataFromHeader(ctx context.Context, header http.Hea
149149
if len(authHeaders) == 0 {
150150
return nil, nil
151151
}
152-
cs, err := parseWWWAuthenticate(authHeaders)
152+
cs, err := ParseWWWAuthenticate(authHeaders)
153153
if err != nil {
154154
return nil, err
155155
}
156-
url := resourceMetadataURL(cs)
156+
url := ResourceMetadataURL(cs)
157157
if url == "" {
158158
return nil, nil
159159
}
@@ -210,9 +210,9 @@ type challenge struct {
210210
Params map[string]string
211211
}
212212

213-
// resourceMetadataURL returns a resource metadata URL from the given challenges,
213+
// ResourceMetadataURL returns a resource metadata URL from the given challenges,
214214
// or the empty string if there is none.
215-
func resourceMetadataURL(cs []challenge) string {
215+
func ResourceMetadataURL(cs []challenge) string {
216216
for _, c := range cs {
217217
if u := c.Params["resource_metadata"]; u != "" {
218218
return u
@@ -221,11 +221,11 @@ func resourceMetadataURL(cs []challenge) string {
221221
return ""
222222
}
223223

224-
// parseWWWAuthenticate parses a WWW-Authenticate header string.
224+
// ParseWWWAuthenticate parses a WWW-Authenticate header string.
225225
// The header format is defined in RFC 9110, Section 11.6.1, and can contain
226226
// one or more challenges, separated by commas.
227227
// It returns a slice of challenges or an error if one of the headers is malformed.
228-
func parseWWWAuthenticate(headers []string) ([]challenge, error) {
228+
func ParseWWWAuthenticate(headers []string) ([]challenge, error) {
229229
// GENERATED BY GEMINI 2.5 (human-tweaked)
230230
var challenges []challenge
231231
for _, h := range headers {

0 commit comments

Comments
 (0)