Skip to content

Commit bb33f9e

Browse files
committed
client oauth WIP
1 parent 15942be commit bb33f9e

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed

auth/client.go

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
2+
// Use of this source code is governed by an MIT-style
3+
// license that can be found in the LICENSE file.
4+
5+
package auth
6+
7+
import (
8+
"context"
9+
"fmt"
10+
"net/http"
11+
"sync"
12+
13+
"golang.org/x/oauth2"
14+
"golang.org/x/oauth2/authhandler"
15+
)
16+
17+
// HTTPTransport is an [http.RoundTripper] that follows the MCP
18+
// OAuth protocol when it encounters a 401 Unauthorized response.
19+
type HTTPTransport struct {
20+
mu sync.Mutex
21+
opts HTTPTransportConfig
22+
}
23+
24+
func NewHTTPTransport(opts *HTTPTransportConfig) (*HTTPTransport, error) {
25+
t := &HTTPTransport{}
26+
if opts != nil {
27+
t.opts = *opts
28+
}
29+
if t.opts.Base == nil {
30+
t.opts.Base = http.DefaultTransport
31+
}
32+
if t.opts.OAuthClient == nil {
33+
t.opts.OAuthClient = http.DefaultClient
34+
}
35+
return t, nil
36+
}
37+
38+
type HTTPTransportConfig struct {
39+
AuthHandler authhandler.AuthorizationHandler
40+
// Base is the [http.RoundTripper] to use initially, before credentials are obtained.
41+
// (After the OAuth flow is completed, an [oauth2.Transport] with the resulting
42+
// [oauth2.TokenSource] is used.)
43+
// If nil, [http.DefaultTransport] is used.
44+
Base http.RoundTripper
45+
// OAuth is used for HTTP requests that are part of the OAuth protocol,
46+
// such as requests to the authorization server. If nil, http.DefaultClient
47+
// is used.
48+
OAuthClient *http.Client
49+
}
50+
51+
func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
52+
// baseRoundTrip calls RoundTrip on the base transport.
53+
// If we should do OAuth to fix a 401 Unauthorized, it returns nil, nil.
54+
baseRoundTrip := func() (*http.Response, error) {
55+
t.mu.Lock()
56+
base := t.opts.Base
57+
_, haveTokenSource := base.(*oauth2.Transport)
58+
t.mu.Unlock()
59+
60+
resp, err := base.RoundTrip(req)
61+
if err != nil {
62+
return nil, err
63+
}
64+
if resp.StatusCode != http.StatusUnauthorized {
65+
return resp, nil
66+
}
67+
if haveTokenSource {
68+
// We failed to authorize even with a token source; give up.
69+
return resp, nil
70+
}
71+
return nil, nil
72+
}
73+
74+
resp, err := baseRoundTrip()
75+
if resp != nil || err != nil {
76+
return resp, err
77+
}
78+
79+
// Try to authorize.
80+
t.mu.Lock()
81+
// If we don't have a token source, get one by following the OAuth flow.
82+
// (We may have obtained one while t.mu was not held above.)
83+
if _, ok := t.opts.Base.(*oauth2.Transport); !ok {
84+
ts, err := t.doOauth(req.Context(), resp.Header)
85+
if err != nil {
86+
t.mu.Unlock()
87+
return nil, err
88+
}
89+
t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts}
90+
}
91+
t.mu.Unlock()
92+
// This will not return (nil, nil), because once we have a TokenSource we never lose it.
93+
return baseRoundTrip()
94+
}
95+
96+
// doOauth runs the OAuth 2.1 flow for MCP as described in
97+
// https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization.
98+
// It returns the resulting TokenSource.
99+
func (t *HTTPTransport) doOauth(ctx context.Context, header http.Header) (oauth2.TokenSource, error) {
100+
prm, err := oauthex.GetProtectedResourceMetadataFromHeader(ctx, header, c)
101+
if err != nil {
102+
return nil, err
103+
}
104+
if len(prm.AuthorizationServers) == 0 {
105+
return nil, fmt.Errorf("resource %s provided no authorization servers", prm.Resource)
106+
}
107+
// TODO: try more than one?
108+
authServer := prm.AuthorizationServers[0]
109+
// TODO: which scopes to ask for? All of them?
110+
scopes := prm.ScopesSupported
111+
asm, err := oauthex.GetAuthServerMeta(ctx, authServer, c)
112+
if err != nil {
113+
return nil, err
114+
}
115+
// TODO: register the client with the auth server if not registered yet,
116+
// or find another way to get the client ID and secret.
117+
118+
// Get an access token from the auth server.
119+
config := &oauth2.Config{
120+
ClientID: "TODO: from registration",
121+
ClientSecret: "TODO: from registration",
122+
Endpoint: oauth2.Endpoint{
123+
AuthURL: asm.AuthorizationEndpoint,
124+
TokenURL: asm.TokenEndpoint,
125+
// DeviceAuthURL: "",
126+
// AuthStyle: "from auth meta?",
127+
},
128+
RedirectURL: "", // ???
129+
Scopes: scopes,
130+
}
131+
v := oauth2.GenerateVerifier()
132+
pkceParams := authhandler.PKCEParams{
133+
ChallengeMethod: "S256",
134+
Challenge: oauth2.S256ChallengeFromVerifier(v),
135+
Verifier: v,
136+
}
137+
state := randText()
138+
return authhandler.TokenSourceWithPKCE(ctx, config, state, oauthHandler, &pkceParams), nil
139+
}

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@ require (
88
golang.org/x/oauth2 v0.30.0
99
golang.org/x/tools v0.34.0
1010
)
11+
12+
require golang.org/x/oauth2 v0.30.0 // indirect

0 commit comments

Comments
 (0)