diff --git a/auth/client.go b/auth/client.go new file mode 100644 index 00000000..534539c6 --- /dev/null +++ b/auth/client.go @@ -0,0 +1,104 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "log" + "net/http" + "sync" + + "github.com/modelcontextprotocol/go-sdk/internal/oauthex" + "golang.org/x/oauth2" +) + +// An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization +// is approved, or an error if not. +type OAuthHandler func(context.Context, OAuthHandlerArgs) (oauth2.TokenSource, error) + +// OAuthHandlerArgs are arguments to an [OAuthHandler]. +type OAuthHandlerArgs struct { + // The URL to fetch protected resource metadata, extracted from the WWW-Authenticate header. + // Empty if not present or there was an error obtaining it. + ResourceMetadataURL string +} + +// HTTPTransport is an [http.RoundTripper] that follows the MCP +// OAuth protocol when it encounters a 401 Unauthorized response. +type HTTPTransport struct { + handler OAuthHandler + mu sync.Mutex // protects opts.Base + opts HTTPTransportOptions +} + +// NewHTTPTransport returns a new [*HTTPTransport]. +// The handler is invoked when an HTTP request results in a 401 Unauthorized status. +// It is called only once per transport. Once a TokenSource is obtained, it is used +// for the lifetime of the transport; subsequent 401s are not processed. +func NewHTTPTransport(handler OAuthHandler, opts *HTTPTransportOptions) (*HTTPTransport, error) { + t := &HTTPTransport{} + if opts != nil { + t.opts = *opts + } + if t.opts.Base == nil { + t.opts.Base = http.DefaultTransport + } + return t, nil +} + +// HTTPTransportOptions are options to [NewHTTPTransport]. +type HTTPTransportOptions struct { + // Base is the [http.RoundTripper] to use. + // If nil, [http.DefaultTransport] is used. + Base http.RoundTripper +} + +func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + t.mu.Lock() + base := t.opts.Base + _, haveTokenSource := base.(*oauth2.Transport) + t.mu.Unlock() + + resp, err := base.RoundTrip(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusUnauthorized { + return resp, nil + } + if haveTokenSource { + // We failed to authorize even with a token source; give up. + return resp, nil + } + // Try to authorize. + t.mu.Lock() + // If we don't have a token source, get one by following the OAuth flow. + // (We may have obtained one while t.mu was not held above.) + if _, ok := t.opts.Base.(*oauth2.Transport); !ok { + authHeaders := resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")] + ts, err := t.handler(req.Context(), OAuthHandlerArgs{ + ResourceMetadataURL: extractResourceMetadataURL(authHeaders), + }) + if err != nil { + t.mu.Unlock() + return nil, err + } + t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} + } + t.mu.Unlock() + // Only one level of recursion, because we now have a token source. + return t.RoundTrip(req) +} + +func extractResourceMetadataURL(authHeaders []string) string { + cs, err := oauthex.ParseWWWAuthenticate(authHeaders) + if err != nil { + log.Printf("parsing auth headers %q: %v", authHeaders, err) + return "" + } + return oauthex.ResourceMetadataURL(cs) +} diff --git a/examples/server/auth-middleware/go.mod b/examples/server/auth-middleware/go.mod index cbc89c78..46f49f9c 100644 --- a/examples/server/auth-middleware/go.mod +++ b/examples/server/auth-middleware/go.mod @@ -1,6 +1,8 @@ module auth-middleware-example -go 1.23.0 +go 1.24.0 + +toolchain go1.24.4 require ( github.com/golang-jwt/jwt/v5 v5.2.2 @@ -10,6 +12,7 @@ require ( require ( github.com/google/jsonschema-go v0.3.0 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/oauth2 v0.31.0 // indirect ) replace github.com/modelcontextprotocol/go-sdk => ../../../ diff --git a/examples/server/auth-middleware/go.sum b/examples/server/auth-middleware/go.sum index 7b7a8e56..9d0a4841 100644 --- a/examples/server/auth-middleware/go.sum +++ b/examples/server/auth-middleware/go.sum @@ -6,5 +6,7 @@ github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIy github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo= +golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= diff --git a/examples/server/rate-limiting/go.mod b/examples/server/rate-limiting/go.mod index adf535b2..cb2c2d08 100644 --- a/examples/server/rate-limiting/go.mod +++ b/examples/server/rate-limiting/go.mod @@ -1,6 +1,8 @@ module github.com/modelcontextprotocol/go-sdk/examples/rate-limiting -go 1.23.0 +go 1.24.0 + +toolchain go1.24.4 require ( github.com/modelcontextprotocol/go-sdk v0.3.0 @@ -10,6 +12,7 @@ require ( require ( github.com/google/jsonschema-go v0.3.0 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/oauth2 v0.31.0 // indirect ) replace github.com/modelcontextprotocol/go-sdk => ../../../ diff --git a/examples/server/rate-limiting/go.sum b/examples/server/rate-limiting/go.sum index 92c27394..f8a7fed3 100644 --- a/examples/server/rate-limiting/go.sum +++ b/examples/server/rate-limiting/go.sum @@ -4,6 +4,8 @@ github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIy github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo= +golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= diff --git a/go.mod b/go.mod index b78c25e7..f2b258d0 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/modelcontextprotocol/go-sdk -go 1.23.0 +go 1.24.0 + +toolchain go1.24.4 require ( github.com/golang-jwt/jwt/v5 v5.2.2 @@ -9,3 +11,5 @@ require ( github.com/yosida95/uritemplate/v3 v3.0.2 golang.org/x/tools v0.34.0 ) + +require golang.org/x/oauth2 v0.31.0 // indirect diff --git a/go.sum b/go.sum index 2006a674..ae84ba5e 100644 --- a/go.sum +++ b/go.sum @@ -10,5 +10,7 @@ github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIy github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo= +golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= diff --git a/internal/oauthex/resource_meta.go b/internal/oauthex/resource_meta.go index 71d52cde..2387b0b5 100644 --- a/internal/oauthex/resource_meta.go +++ b/internal/oauthex/resource_meta.go @@ -146,11 +146,11 @@ func GetProtectedResourceMetadataFromHeader(ctx context.Context, header http.Hea if len(headers) == 0 { return nil, nil } - cs, err := parseWWWAuthenticate(headers) + cs, err := ParseWWWAuthenticate(headers) if err != nil { return nil, err } - url := resourceMetadataURL(cs) + url := ResourceMetadataURL(cs) if url == "" { return nil, nil } @@ -187,9 +187,9 @@ type challenge struct { Params map[string]string } -// resourceMetadataURL returns a resource metadata URL from the given challenges, +// ResourceMetadataURL returns a resource metadata URL from the given challenges, // or the empty string if there is none. -func resourceMetadataURL(cs []challenge) string { +func ResourceMetadataURL(cs []challenge) string { for _, c := range cs { if u := c.Params["resource_metadata"]; u != "" { return u @@ -198,11 +198,11 @@ func resourceMetadataURL(cs []challenge) string { return "" } -// parseWWWAuthenticate parses a WWW-Authenticate header string. +// ParseWWWAuthenticate parses a WWW-Authenticate header string. // The header format is defined in RFC 9110, Section 11.6.1, and can contain // one or more challenges, separated by commas. // It returns a slice of challenges or an error if one of the headers is malformed. -func parseWWWAuthenticate(headers []string) ([]challenge, error) { +func ParseWWWAuthenticate(headers []string) ([]challenge, error) { // GENERATED BY GEMINI 2.5 (human-tweaked) var challenges []challenge for _, h := range headers {