9
9
10
10
from starlette .requests import HTTPConnection , Request
11
11
from starlette .exceptions import HTTPException
12
- from starlette .authentication import AuthCredentials , AuthenticationBackend , AuthenticationError , BaseUser , SimpleUser , UnauthenticatedUser
12
+ from starlette .authentication import AuthCredentials , AuthenticationBackend , AuthenticationError , BaseUser , SimpleUser , UnauthenticatedUser , has_required_scope
13
13
from starlette .middleware .authentication import AuthenticationMiddleware
14
+ from starlette .types import Scope
14
15
15
16
from mcp .server .auth .errors import InsufficientScopeError , InvalidTokenError , OAuthError
16
17
from mcp .server .auth .provider import OAuthServerProvider
@@ -34,22 +35,12 @@ class BearerAuthBackend(AuthenticationBackend):
34
35
def __init__ (
35
36
self ,
36
37
provider : OAuthServerProvider ,
37
- required_scopes : Optional [List [str ]] = None
38
38
):
39
- """
40
- Initialize the backend.
41
-
42
- Args:
43
- provider: Authentication provider to validate tokens
44
- required_scopes: Optional list of scopes that the token must have
45
- """
46
39
self .provider = provider
47
- self .required_scopes = required_scopes or []
48
40
49
41
async def authenticate (self , conn : HTTPConnection ):
50
42
51
43
if "Authorization" not in conn .headers :
52
- raise AuthenticationError ()
53
44
return None
54
45
55
46
auth_header = conn .headers ["Authorization" ]
@@ -61,14 +52,7 @@ async def authenticate(self, conn: HTTPConnection):
61
52
try :
62
53
# Validate the token with the provider
63
54
auth_info = await self .provider .verify_access_token (token )
64
-
65
- # Check if the token has all required scopes
66
- if self .required_scopes :
67
- has_all_scopes = all (scope in auth_info .scopes for scope in self .required_scopes )
68
- if not has_all_scopes :
69
- raise InsufficientScopeError ("Insufficient scope" )
70
-
71
- # Check if the token is expired
55
+
72
56
if auth_info .expires_at and auth_info .expires_at < int (time .time ()):
73
57
raise InvalidTokenError ("Token has expired" )
74
58
@@ -79,7 +63,7 @@ async def authenticate(self, conn: HTTPConnection):
79
63
return None
80
64
81
65
82
- class BearerAuthMiddleware :
66
+ class RequireAuthMiddleware :
83
67
"""
84
68
Middleware that requires a valid Bearer token in the Authorization header.
85
69
@@ -92,8 +76,7 @@ class BearerAuthMiddleware:
92
76
def __init__ (
93
77
self ,
94
78
app : Any ,
95
- provider : OAuthServerProvider ,
96
- required_scopes : Optional [List [str ]] = None
79
+ required_scopes : list [str ]
97
80
):
98
81
"""
99
82
Initialize the middleware.
@@ -103,18 +86,15 @@ def __init__(
103
86
provider: Authentication provider to validate tokens
104
87
required_scopes: Optional list of scopes that the token must have
105
88
"""
106
- self .app = AuthenticationMiddleware (
107
- app ,
108
- backend = BearerAuthBackend (provider , required_scopes )
109
- )
110
-
111
- async def __call__ (self , scope : Dict , receive : Callable , send : Callable ) -> None :
112
- """
113
- Process the request and validate the bearer token.
89
+ self .app = app
90
+ self .required_scopes = required_scopes
91
+
92
+ async def __call__ (self , scope : Scope , receive : Callable , send : Callable ) -> None :
93
+ auth_credentials = scope .get ('auth' )
114
94
115
- Args :
116
- scope: ASGI scope
117
- receive: ASGI receive function
118
- send: ASGI send function
119
- """
95
+ for required_scope in self . required_scopes :
96
+ # auth_credentials should always be provided; this is just paranoia
97
+ if auth_credentials is None or required_scope not in auth_credentials . scopes :
98
+ raise HTTPException ( status_code = 403 , detail = "Insufficient scope" )
99
+
120
100
await self .app (scope , receive , send )
0 commit comments