Skip to content

Commit 43d9b70

Browse files
authored
refactor: use context var and init whoami sample (#8)
* refactor: use context var and init whoami sample * refactor: update tests
1 parent 992d93c commit 43d9b70

17 files changed

+699
-306
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jobs:
1010
strategy:
1111
fail-fast: false
1212
matrix:
13-
python-version: ["3.9", "3.10", "3.11", "3.12"]
13+
python-version: ["3.10", "3.11", "3.12", "3.13"]
1414
os: [ubuntu-latest]
1515
runs-on: ${{ matrix.os }}
1616

.python-version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3.9
1+
3.10

mcpauth/__init__.py

Lines changed: 92 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1+
from contextvars import ContextVar
12
import logging
2-
from typing import List, Literal, Optional, Union
3+
from typing import Any, Callable, List, Literal, Optional, Union
34

45
from .middleware.create_bearer_auth import BearerAuthConfig
5-
from .types import VerifyAccessTokenFunction
6-
from .config import AuthServerConfig
6+
from .types import AuthInfo, VerifyAccessTokenFunction
7+
from .config import AuthServerConfig, ServerMetadataPaths
78
from .exceptions import MCPAuthAuthServerException, AuthServerExceptionCode
89
from .utils import validate_server_config
910
from starlette.middleware.base import BaseHTTPMiddleware
10-
from starlette.responses import JSONResponse
11+
from starlette.responses import Response, JSONResponse
12+
from starlette.requests import Request
13+
from starlette.routing import Route
14+
15+
_context_var_name = "mcp_auth_context"
1116

1217

1318
class MCPAuth:
@@ -18,9 +23,22 @@ class MCPAuth:
1823
See Also: https://mcp-auth.dev for more information about the library and its usage.
1924
"""
2025

21-
def __init__(self, server: AuthServerConfig):
26+
server: AuthServerConfig
27+
"""
28+
The configuration for the remote authorization server.
29+
"""
30+
31+
def __init__(
32+
self,
33+
server: AuthServerConfig,
34+
context_var: ContextVar[Optional[AuthInfo]] = ContextVar(
35+
_context_var_name, default=None
36+
),
37+
):
2238
"""
2339
:param server: Configuration for the remote authorization server.
40+
:param context_var: Context variable to store the `AuthInfo` object for the current request.
41+
By default, it will be created with the name "mcp_auth_context".
2442
"""
2543

2644
result = validate_server_config(server)
@@ -40,20 +58,78 @@ def __init__(self, server: AuthServerConfig):
4058
logging.warning(f"- {warning}")
4159

4260
self.server = server
61+
self._context_var = context_var
62+
63+
@property
64+
def auth_info(self) -> Optional[AuthInfo]:
65+
"""
66+
The current `AuthInfo` object from the context variable.
67+
68+
This is useful for accessing the authenticated user's information in later middleware or
69+
route handlers.
70+
:return: The current `AuthInfo` object, or `None` if not set.
71+
"""
72+
73+
return self._context_var.get()
4374

44-
def metadata_response(self) -> JSONResponse:
75+
def metadata_endpoint(self) -> Callable[[Request], Any]:
4576
"""
46-
Returns a response containing the server metadata in JSON format with CORS support.
77+
Returns a Starlette endpoint function that handles the OAuth 2.0 Authorization Metadata
78+
endpoint (`/.well-known/oauth-authorization-server`) with CORS support.
79+
80+
Example:
81+
```python
82+
from starlette.applications import Starlette
83+
from mcpauth import MCPAuth
84+
from mcpauth.config import ServerMetadataPaths
85+
86+
mcp_auth = MCPAuth(server=your_server_config)
87+
app = Starlette(routes=[
88+
Route(
89+
ServerMetadataPaths.OAUTH.value,
90+
mcp_auth.metadata_endpoint(),
91+
methods=["GET", "OPTIONS"] # Ensure to handle both GET and OPTIONS methods
92+
)
93+
])
94+
```
95+
"""
96+
97+
async def endpoint(request: Request) -> Response:
98+
if request.method == "OPTIONS":
99+
response = Response(status_code=204)
100+
else:
101+
server_config = self.server
102+
response = JSONResponse(
103+
server_config.metadata.model_dump(exclude_none=True),
104+
status_code=200,
105+
)
106+
response.headers["Access-Control-Allow-Origin"] = "*"
107+
response.headers["Access-Control-Allow-Methods"] = "GET, OPTIONS"
108+
response.headers["Access-Control-Allow-Headers"] = "*"
109+
return response
110+
111+
return endpoint
112+
113+
def metadata_route(self) -> Route:
114+
"""
115+
Returns a Starlette route that handles the OAuth 2.0 Authorization Metadata endpoint
116+
(`/.well-known/oauth-authorization-server`) with CORS support.
117+
118+
Example:
119+
```python
120+
from starlette.applications import Starlette
121+
from mcpauth import MCPAuth
122+
123+
mcp_auth = MCPAuth(server=your_server_config)
124+
app = Starlette(routes=[mcp_auth.metadata_route()])
125+
```
47126
"""
48-
server_config = self.server
49127

50-
response = JSONResponse(
51-
server_config.metadata.model_dump(exclude_none=True),
52-
status_code=200,
128+
return Route(
129+
ServerMetadataPaths.OAUTH.value,
130+
self.metadata_endpoint(),
131+
methods=["GET", "OPTIONS"],
53132
)
54-
response.headers["Access-Control-Allow-Origin"] = "*"
55-
response.headers["Access-Control-Allow-Methods"] = "GET, OPTIONS"
56-
return response
57133

58134
def bearer_auth_middleware(
59135
self,
@@ -101,10 +177,11 @@ def bearer_auth_middleware(
101177

102178
return create_bearer_auth(
103179
verify,
104-
BearerAuthConfig(
180+
config=BearerAuthConfig(
105181
issuer=metadata.issuer,
106182
audience=audience,
107183
required_scopes=required_scopes,
108184
show_error_details=show_error_details,
109185
),
186+
context_var=self._context_var,
110187
)

mcpauth/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ class AuthorizationServerMetadata(BaseModel):
100100
code challenge methods supported by this authorization server.
101101
"""
102102

103+
userinfo_endpoint: Optional[str] = None
104+
"""
105+
URL of the authorization server's UserInfo endpoint [[OpenID Connect](https://openid.net/specs/openid-connect-core-1_0.html#UserInfo)].
106+
"""
107+
103108

104109
class AuthServerType(str, Enum):
105110
"""

mcpauth/exceptions.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -143,31 +143,31 @@ def to_json(self, show_cause: bool = False) -> Dict[str, Optional[str]]:
143143
return {k: v for k, v in data.items() if v is not None}
144144

145145

146-
class MCPAuthJwtVerificationExceptionCode(str, Enum):
147-
INVALID_JWT = "invalid_jwt"
148-
JWT_VERIFICATION_FAILED = "jwt_verification_failed"
146+
class MCPAuthTokenVerificationExceptionCode(str, Enum):
147+
INVALID_TOKEN = "invalid_token"
148+
TOKEN_VERIFICATION_FAILED = "token_verification_failed"
149149

150150

151-
jwt_verification_exception_description: Dict[
152-
MCPAuthJwtVerificationExceptionCode, str
151+
token_verification_exception_description: Dict[
152+
MCPAuthTokenVerificationExceptionCode, str
153153
] = {
154-
MCPAuthJwtVerificationExceptionCode.INVALID_JWT: "The provided JWT is invalid or malformed.",
155-
MCPAuthJwtVerificationExceptionCode.JWT_VERIFICATION_FAILED: "JWT verification failed. The token could not be verified.",
154+
MCPAuthTokenVerificationExceptionCode.INVALID_TOKEN: "The provided token is invalid or malformed.",
155+
MCPAuthTokenVerificationExceptionCode.TOKEN_VERIFICATION_FAILED: "The token verification failed due to an error in the verification process.",
156156
}
157157

158158

159-
class MCPAuthJwtVerificationException(MCPAuthException):
159+
class MCPAuthTokenVerificationException(MCPAuthException):
160160
"""
161-
Exception thrown when there is an issue when verifying JWT tokens.
161+
Exception thrown when there is an issue when verifying access tokens.
162162
"""
163163

164164
def __init__(
165-
self, code: MCPAuthJwtVerificationExceptionCode, cause: ExceptionCause = None
165+
self, code: MCPAuthTokenVerificationExceptionCode, cause: ExceptionCause = None
166166
):
167167
super().__init__(
168168
code.value,
169-
jwt_verification_exception_description.get(
170-
code, "An exception occurred while verifying the JWT."
169+
token_verification_exception_description.get(
170+
code, "An exception occurred while verifying the token."
171171
),
172172
)
173173
self.code = code

mcpauth/middleware/create_bearer_auth.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from contextvars import ContextVar
12
from typing import Any, Dict, List, Optional
23
from urllib.parse import urlparse
34
import logging
@@ -9,13 +10,13 @@
910

1011
from ..exceptions import (
1112
MCPAuthBearerAuthException,
12-
MCPAuthJwtVerificationException,
13+
MCPAuthTokenVerificationException,
1314
MCPAuthAuthServerException,
1415
MCPAuthConfigException,
1516
BearerAuthExceptionCode,
1617
MCPAuthBearerAuthExceptionDetails,
1718
)
18-
from ..types import VerifyAccessTokenFunction, Record
19+
from ..types import AuthInfo, VerifyAccessTokenFunction, Record
1920

2021

2122
class BearerAuthConfig(BaseModel):
@@ -92,7 +93,7 @@ def _handle_error(
9293
Returns:
9394
A tuple of (status_code, response_body).
9495
"""
95-
if isinstance(error, MCPAuthJwtVerificationException):
96+
if isinstance(error, MCPAuthTokenVerificationException):
9697
return 401, error.to_json(show_error_details)
9798

9899
if isinstance(error, MCPAuthBearerAuthException):
@@ -114,20 +115,22 @@ def _handle_error(
114115

115116

116117
def create_bearer_auth(
117-
verify_access_token: VerifyAccessTokenFunction, config: BearerAuthConfig
118+
verify_access_token: VerifyAccessTokenFunction,
119+
config: BearerAuthConfig,
120+
context_var: ContextVar[Optional[AuthInfo]],
118121
) -> type[BaseHTTPMiddleware]:
119122
"""
120123
Creates a middleware function for handling Bearer auth.
121124
122125
This middleware extracts the Bearer token from the `Authorization` header, verifies it using the
123126
provided `verify_access_token` function, and checks the issuer, audience, and required scopes.
124127
125-
Args:
126-
verify_access_token: A function that takes a Bearer token and returns an `AuthInfo` object.
127-
config: Configuration for the Bearer auth handler.
128+
:param verify_access_token: A function that takes a Bearer token and returns an `AuthInfo` object.
129+
:param config: Configuration for the Bearer auth handler.
130+
:param context_var: Context variable to store the `AuthInfo` object for the current request.
131+
This allows access to the authenticated user's information in later middleware or route handlers.
128132
129-
Returns:
130-
A middleware class that handles Bearer auth.
133+
:return: A middleware class that handles Bearer auth.
131134
"""
132135

133136
if not callable(verify_access_token):
@@ -206,8 +209,12 @@ async def dispatch(
206209
cause=details,
207210
)
208211

209-
# Attach auth info to the request
210-
request.state.auth = auth_info
212+
if context_var.get() is not None:
213+
logging.warning(
214+
"Overwriting existing auth info in context variable."
215+
)
216+
217+
context_var.set(auth_info)
211218

212219
# Call the next middleware or route handler
213220
response = await call_next(request)

0 commit comments

Comments
 (0)