Skip to content

Commit 067f241

Browse files
committed
add HTTP Request to TokenVerifier
The request might be needed to verify the token. Fixes #403.
1 parent 1c20560 commit 067f241

File tree

4 files changed

+16
-11
lines changed

4 files changed

+16
-11
lines changed

auth/auth.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ var ErrInvalidToken = errors.New("invalid token")
2626

2727
// A TokenVerifier checks the validity of a bearer token, and extracts information
2828
// from it. If verification fails, it should return an error that unwraps to ErrInvalidToken.
29-
type TokenVerifier func(ctx context.Context, token string) (*TokenInfo, error)
29+
// The HTTP request is provided in case verifying the token involves checking it.
30+
type TokenVerifier func(ctx context.Context, token string, req *http.Request) (*TokenInfo, error)
3031

3132
// RequireBearerTokenOptions are options for [RequireBearerToken].
3233
type RequireBearerTokenOptions struct {
@@ -52,14 +53,16 @@ func TokenInfoFromContext(ctx context.Context) *TokenInfo {
5253
// If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds.
5354
// If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header
5455
// is populated to enable [protected resource metadata].
56+
//
57+
5558
//
5659
// [protected resource metadata]: https://datatracker.ietf.org/doc/rfc9728
5760
func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions) func(http.Handler) http.Handler {
5861
// Based on typescript-sdk/src/server/auth/middleware/bearerAuth.ts.
5962

6063
return func(handler http.Handler) http.Handler {
6164
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
62-
tokenInfo, errmsg, code := verify(r.Context(), verifier, opts, r.Header.Get("Authorization"))
65+
tokenInfo, errmsg, code := verify(r, verifier, opts)
6366
if code != 0 {
6467
if code == http.StatusUnauthorized || code == http.StatusForbidden {
6568
if opts != nil && opts.ResourceMetadataURL != "" {
@@ -75,15 +78,16 @@ func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions)
7578
}
7679
}
7780

78-
func verify(ctx context.Context, verifier TokenVerifier, opts *RequireBearerTokenOptions, authHeader string) (_ *TokenInfo, errmsg string, code int) {
81+
func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenOptions) (_ *TokenInfo, errmsg string, code int) {
7982
// Extract bearer token.
83+
authHeader := req.Header.Get("Authorization")
8084
fields := strings.Fields(authHeader)
8185
if len(fields) != 2 || strings.ToLower(fields[0]) != "bearer" {
8286
return nil, "no bearer token", http.StatusUnauthorized
8387
}
8488

8589
// Verify the token and get information from it.
86-
tokenInfo, err := verifier(ctx, fields[1])
90+
tokenInfo, err := verifier(req.Context(), fields[1], req)
8791
if err != nil {
8892
if errors.Is(err, ErrInvalidToken) {
8993
return nil, err.Error(), http.StatusUnauthorized

auth/auth_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ package auth
77
import (
88
"context"
99
"errors"
10+
"net/http"
1011
"testing"
1112
"time"
1213
)
1314

1415
func TestVerify(t *testing.T) {
15-
ctx := context.Background()
16-
verifier := func(_ context.Context, token string) (*TokenInfo, error) {
16+
verifier := func(_ context.Context, token string, _ *http.Request) (*TokenInfo, error) {
1717
switch token {
1818
case "valid":
1919
return &TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil
@@ -61,7 +61,9 @@ func TestVerify(t *testing.T) {
6161
},
6262
} {
6363
t.Run(tt.name, func(t *testing.T) {
64-
_, gotMsg, gotCode := verify(ctx, verifier, tt.opts, tt.header)
64+
_, gotMsg, gotCode := verify(&http.Request{
65+
Header: http.Header{"Authorization": {tt.header}},
66+
}, verifier, tt.opts)
6567
if gotMsg != tt.wantMsg || gotCode != tt.wantCode {
6668
t.Errorf("got (%q, %d), want (%q, %d)", gotMsg, gotCode, tt.wantMsg, tt.wantCode)
6769
}

examples/server/auth-middleware/main.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func generateToken(userID string, scopes []string, expiresIn time.Duration) (str
8383

8484
// verifyJWT verifies JWT tokens and returns TokenInfo for the auth middleware.
8585
// This function implements the TokenVerifier interface required by auth.RequireBearerToken.
86-
func verifyJWT(ctx context.Context, tokenString string) (*auth.TokenInfo, error) {
86+
func verifyJWT(ctx context.Context, tokenString string, _ *http.Request) (*auth.TokenInfo, error) {
8787
// Parse and validate the JWT token.
8888
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
8989
// Verify the signing method is HMAC.
@@ -92,7 +92,6 @@ func verifyJWT(ctx context.Context, tokenString string) (*auth.TokenInfo, error)
9292
}
9393
return jwtSecret, nil
9494
})
95-
9695
if err != nil {
9796
// Return standard error for invalid tokens.
9897
return nil, fmt.Errorf("%w: %v", auth.ErrInvalidToken, err)
@@ -111,7 +110,7 @@ func verifyJWT(ctx context.Context, tokenString string) (*auth.TokenInfo, error)
111110

112111
// verifyAPIKey verifies API keys and returns TokenInfo for the auth middleware.
113112
// This function implements the TokenVerifier interface required by auth.RequireBearerToken.
114-
func verifyAPIKey(ctx context.Context, apiKey string) (*auth.TokenInfo, error) {
113+
func verifyAPIKey(ctx context.Context, apiKey string, _ *http.Request) (*auth.TokenInfo, error) {
115114
// Look up the API key in our storage.
116115
key, exists := apiKeys[apiKey]
117116
if !exists {

mcp/streamable_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1270,7 +1270,7 @@ func TestTokenInfo(t *testing.T) {
12701270
AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo)
12711271

12721272
streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
1273-
verifier := func(context.Context, string) (*auth.TokenInfo, error) {
1273+
verifier := func(context.Context, string, *http.Request) (*auth.TokenInfo, error) {
12741274
return &auth.TokenInfo{
12751275
Scopes: []string{"scope"},
12761276
// Expiration is far, far in the future.

0 commit comments

Comments
 (0)