Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ var ErrOAuth = errors.New("oauth error")

// A TokenVerifier checks the validity of a bearer token, and extracts information
// from it. If verification fails, it should return an error that unwraps to ErrInvalidToken.
type TokenVerifier func(ctx context.Context, token string) (*TokenInfo, error)
// The HTTP request is provided in case verifying the token involves checking it.
type TokenVerifier func(ctx context.Context, token string, req *http.Request) (*TokenInfo, error)

// RequireBearerTokenOptions are options for [RequireBearerToken].
type RequireBearerTokenOptions struct {
Expand All @@ -55,14 +56,16 @@ func TokenInfoFromContext(ctx context.Context) *TokenInfo {
// If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds.
// If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header
// is populated to enable [protected resource metadata].
//

//
// [protected resource metadata]: https://datatracker.ietf.org/doc/rfc9728
func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions) func(http.Handler) http.Handler {
// Based on typescript-sdk/src/server/auth/middleware/bearerAuth.ts.

return func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenInfo, errmsg, code := verify(r.Context(), verifier, opts, r.Header.Get("Authorization"))
tokenInfo, errmsg, code := verify(r, verifier, opts)
if code != 0 {
if code == http.StatusUnauthorized || code == http.StatusForbidden {
if opts != nil && opts.ResourceMetadataURL != "" {
Expand All @@ -78,24 +81,23 @@ func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions)
}
}

func verify(ctx context.Context, verifier TokenVerifier, opts *RequireBearerTokenOptions, authHeader string) (_ *TokenInfo, errmsg string, code int) {
func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenOptions) (_ *TokenInfo, errmsg string, code int) {
// Extract bearer token.
authHeader := req.Header.Get("Authorization")
fields := strings.Fields(authHeader)
if len(fields) != 2 || strings.ToLower(fields[0]) != "bearer" {
return nil, "no bearer token", http.StatusUnauthorized
}

// Verify the token and get information from it.
tokenInfo, err := verifier(ctx, fields[1])
tokenInfo, err := verifier(req.Context(), fields[1], req)
if err != nil {
if errors.Is(err, ErrInvalidToken) {
return nil, err.Error(), http.StatusUnauthorized
}
if errors.Is(err, ErrOAuth) {
return nil, err.Error(), http.StatusBadRequest
}
// Investigate how that works.
// See typescript-sdk/src/server/auth/middleware/bearerAuth.ts.
return nil, err.Error(), http.StatusInternalServerError
}

Expand Down
8 changes: 5 additions & 3 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ package auth
import (
"context"
"errors"
"net/http"
"testing"
"time"
)

func TestVerify(t *testing.T) {
ctx := context.Background()
verifier := func(_ context.Context, token string) (*TokenInfo, error) {
verifier := func(_ context.Context, token string, _ *http.Request) (*TokenInfo, error) {
switch token {
case "valid":
return &TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil
Expand Down Expand Up @@ -67,7 +67,9 @@ func TestVerify(t *testing.T) {
},
} {
t.Run(tt.name, func(t *testing.T) {
_, gotMsg, gotCode := verify(ctx, verifier, tt.opts, tt.header)
_, gotMsg, gotCode := verify(&http.Request{
Header: http.Header{"Authorization": {tt.header}},
}, verifier, tt.opts)
if gotMsg != tt.wantMsg || gotCode != tt.wantCode {
t.Errorf("got (%q, %d), want (%q, %d)", gotMsg, gotCode, tt.wantMsg, tt.wantCode)
}
Expand Down
5 changes: 2 additions & 3 deletions examples/server/auth-middleware/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func generateToken(userID string, scopes []string, expiresIn time.Duration) (str

// verifyJWT verifies JWT tokens and returns TokenInfo for the auth middleware.
// This function implements the TokenVerifier interface required by auth.RequireBearerToken.
func verifyJWT(ctx context.Context, tokenString string) (*auth.TokenInfo, error) {
func verifyJWT(ctx context.Context, tokenString string, _ *http.Request) (*auth.TokenInfo, error) {
// Parse and validate the JWT token.
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
// Verify the signing method is HMAC.
Expand All @@ -92,7 +92,6 @@ func verifyJWT(ctx context.Context, tokenString string) (*auth.TokenInfo, error)
}
return jwtSecret, nil
})

if err != nil {
// Return standard error for invalid tokens.
return nil, fmt.Errorf("%w: %v", auth.ErrInvalidToken, err)
Expand All @@ -111,7 +110,7 @@ func verifyJWT(ctx context.Context, tokenString string) (*auth.TokenInfo, error)

// verifyAPIKey verifies API keys and returns TokenInfo for the auth middleware.
// This function implements the TokenVerifier interface required by auth.RequireBearerToken.
func verifyAPIKey(ctx context.Context, apiKey string) (*auth.TokenInfo, error) {
func verifyAPIKey(ctx context.Context, apiKey string, _ *http.Request) (*auth.TokenInfo, error) {
// Look up the API key in our storage.
key, exists := apiKeys[apiKey]
if !exists {
Expand Down
2 changes: 1 addition & 1 deletion mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1270,7 +1270,7 @@ func TestTokenInfo(t *testing.T) {
AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo)

streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
verifier := func(context.Context, string) (*auth.TokenInfo, error) {
verifier := func(context.Context, string, *http.Request) (*auth.TokenInfo, error) {
return &auth.TokenInfo{
Scopes: []string{"scope"},
// Expiration is far, far in the future.
Expand Down