Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions auth/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.

//go:build mcp_go_client_oauth

package auth

import (
"context"
"log"
"net/http"
"sync"

"github.com/modelcontextprotocol/go-sdk/internal/oauthex"
"golang.org/x/oauth2"
)

// An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization
// is approved, or an error if not.
type OAuthHandler func(context.Context, OAuthHandlerArgs) (oauth2.TokenSource, error)

// OAuthHandlerArgs are arguments to an [OAuthHandler].
type OAuthHandlerArgs struct {
// The URL to fetch protected resource metadata, extracted from the WWW-Authenticate header.
// Empty if not present or there was an error obtaining it.
ResourceMetadataURL string
}

// HTTPTransport is an [http.RoundTripper] that follows the MCP
// OAuth protocol when it encounters a 401 Unauthorized response.
type HTTPTransport struct {
handler OAuthHandler
mu sync.Mutex // protects opts.Base
opts HTTPTransportOptions
}

// NewHTTPTransport returns a new [*HTTPTransport].
// The handler is invoked when an HTTP request results in a 401 Unauthorized status.
// It is called only once per transport. Once a TokenSource is obtained, it is used
// for the lifetime of the transport; subsequent 401s are not processed.
func NewHTTPTransport(handler OAuthHandler, opts *HTTPTransportOptions) (*HTTPTransport, error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check if handler is nil, otherwise we will panic when trying to initialize the token source in RoundTrip.

t := &HTTPTransport{}
if opts != nil {
t.opts = *opts
}
if t.opts.Base == nil {
t.opts.Base = http.DefaultTransport
}
return t, nil
}

// HTTPTransportOptions are options to [NewHTTPTransport].
type HTTPTransportOptions struct {
// Base is the [http.RoundTripper] to use.
// If nil, [http.DefaultTransport] is used.
Base http.RoundTripper
}

func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
t.mu.Lock()
base := t.opts.Base
_, haveTokenSource := base.(*oauth2.Transport)
t.mu.Unlock()

resp, err := base.RoundTrip(req)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusUnauthorized {
return resp, nil
}
if haveTokenSource {
// We failed to authorize even with a token source; give up.
return resp, nil
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we return an error here explaining that we tried to authorize and it failed, or is that going to be handled higher in the stack?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Higher up. The caller will see the unauthorized status and should perform the OAuth dance again.

}
// Try to authorize.
t.mu.Lock()
// If we don't have a token source, get one by following the OAuth flow.
// (We may have obtained one while t.mu was not held above.)
if _, ok := t.opts.Base.(*oauth2.Transport); !ok {
authHeaders := resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]
ts, err := t.handler(req.Context(), OAuthHandlerArgs{
ResourceMetadataURL: extractResourceMetadataURL(authHeaders),
})
if err != nil {
t.mu.Unlock()
return nil, err
}
t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts}
}
t.mu.Unlock()
// Only one level of recursion, because we now have a token source.
return t.RoundTrip(req)
}

func extractResourceMetadataURL(authHeaders []string) string {
cs, err := oauthex.ParseWWWAuthenticate(authHeaders)
if err != nil {
log.Printf("parsing auth headers %q: %v", authHeaders, err)
return ""
}
return oauthex.ResourceMetadataURL(cs)
}
5 changes: 4 additions & 1 deletion examples/server/auth-middleware/go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module auth-middleware-example

go 1.23.0
go 1.24.0

toolchain go1.24.4

require (
github.com/golang-jwt/jwt/v5 v5.2.2
Expand All @@ -10,6 +12,7 @@ require (
require (
github.com/google/jsonschema-go v0.3.0 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
golang.org/x/oauth2 v0.31.0 // indirect
)

replace github.com/modelcontextprotocol/go-sdk => ../../../
2 changes: 2 additions & 0 deletions examples/server/auth-middleware/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,7 @@ github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIy
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo=
golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
5 changes: 4 additions & 1 deletion examples/server/rate-limiting/go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module github.com/modelcontextprotocol/go-sdk/examples/rate-limiting

go 1.23.0
go 1.24.0

toolchain go1.24.4

require (
github.com/modelcontextprotocol/go-sdk v0.3.0
Expand All @@ -10,6 +12,7 @@ require (
require (
github.com/google/jsonschema-go v0.3.0 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
golang.org/x/oauth2 v0.31.0 // indirect
)

replace github.com/modelcontextprotocol/go-sdk => ../../../
2 changes: 2 additions & 0 deletions examples/server/rate-limiting/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIy
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo=
golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
Expand Down
6 changes: 5 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module github.com/modelcontextprotocol/go-sdk

go 1.23.0
go 1.24.0

toolchain go1.24.4

require (
github.com/golang-jwt/jwt/v5 v5.2.2
Expand All @@ -9,3 +11,5 @@ require (
github.com/yosida95/uritemplate/v3 v3.0.2
golang.org/x/tools v0.34.0
)

require golang.org/x/oauth2 v0.31.0 // indirect
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@ github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIy
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo=
golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
12 changes: 6 additions & 6 deletions internal/oauthex/resource_meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,11 @@ func GetProtectedResourceMetadataFromHeader(ctx context.Context, header http.Hea
if len(headers) == 0 {
return nil, nil
}
cs, err := parseWWWAuthenticate(headers)
cs, err := ParseWWWAuthenticate(headers)
if err != nil {
return nil, err
}
url := resourceMetadataURL(cs)
url := ResourceMetadataURL(cs)
if url == "" {
return nil, nil
}
Expand Down Expand Up @@ -187,9 +187,9 @@ type challenge struct {
Params map[string]string
}

// resourceMetadataURL returns a resource metadata URL from the given challenges,
// ResourceMetadataURL returns a resource metadata URL from the given challenges,
// or the empty string if there is none.
func resourceMetadataURL(cs []challenge) string {
func ResourceMetadataURL(cs []challenge) string {
for _, c := range cs {
if u := c.Params["resource_metadata"]; u != "" {
return u
Expand All @@ -198,11 +198,11 @@ func resourceMetadataURL(cs []challenge) string {
return ""
}

// parseWWWAuthenticate parses a WWW-Authenticate header string.
// ParseWWWAuthenticate parses a WWW-Authenticate header string.
// The header format is defined in RFC 9110, Section 11.6.1, and can contain
// one or more challenges, separated by commas.
// It returns a slice of challenges or an error if one of the headers is malformed.
func parseWWWAuthenticate(headers []string) ([]challenge, error) {
func ParseWWWAuthenticate(headers []string) ([]challenge, error) {
// GENERATED BY GEMINI 2.5 (human-tweaked)
var challenges []challenge
for _, h := range headers {
Expand Down
Loading