Skip to content

Commit 3e28910

Browse files
authored
feat: create bearer auth middleware (#1)
1 parent 04b5b58 commit 3e28910

File tree

7 files changed

+978
-9
lines changed

7 files changed

+978
-9
lines changed

mcpauth/exceptions.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,15 @@ def to_json(self, show_cause: bool = False) -> Record:
3434
data: Record = {
3535
"error": self.code.value if isinstance(self.code, Enum) else self.code,
3636
"error_description": self.message,
37-
"cause": self.cause if show_cause and hasattr(self, "cause") else None,
37+
"cause": (
38+
(
39+
{k: v for k, v in self.cause.model_dump().items() if v is not None}
40+
if isinstance(self.cause, BaseModel)
41+
else str(self.cause)
42+
)
43+
if show_cause and hasattr(self, "cause")
44+
else None
45+
),
3846
}
3947
return {k: v for k, v in data.items() if v is not None}
4048

@@ -99,8 +107,8 @@ class MCPAuthBearerAuthExceptionDetails(BaseModel):
99107
cause: Any = None
100108
uri: Optional[str] = None
101109
missing_scopes: Optional[List[str]] = None
102-
expected: Optional[Union[str, Record]] = None
103-
actual: Optional[Union[str, Record]] = None
110+
expected: Any = None
111+
actual: Any = None
104112

105113

106114
class MCPAuthBearerAuthException(MCPAuthException):
@@ -124,15 +132,15 @@ def __init__(
124132

125133
def to_json(self, show_cause: bool = False) -> Dict[str, Optional[str]]:
126134
# Matches the OAuth 2.0 exception response format at best effort
127-
result = super().to_json(show_cause)
135+
data = super().to_json(show_cause)
128136
if self.cause:
129-
result.update(
137+
data.update(
130138
{
131139
"error_uri": self.cause.uri,
132140
"missing_scopes": self.cause.missing_scopes,
133141
}
134142
)
135-
return result
143+
return {k: v for k, v in data.items() if v is not None}
136144

137145

138146
class MCPAuthJwtVerificationExceptionCode(str, Enum):

mcpauth/exceptioins_test.py renamed to mcpauth/exceptions_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_to_json(self):
2828
assert mcp_exception.to_json(show_cause=True) == {
2929
"error": "test_code",
3030
"error_description": "Test message",
31-
"cause": exception,
31+
"cause": str(exception),
3232
}
3333

3434
def test_properties(self):
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
from typing import Any, Dict, List, Optional
2+
from urllib.parse import urlparse
3+
import logging
4+
from pydantic import BaseModel
5+
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
6+
from starlette.requests import Request
7+
from starlette.responses import Response, JSONResponse
8+
from starlette.datastructures import Headers
9+
10+
from ..exceptions import (
11+
MCPAuthBearerAuthException,
12+
MCPAuthJwtVerificationException,
13+
MCPAuthAuthServerException,
14+
MCPAuthConfigException,
15+
BearerAuthExceptionCode,
16+
MCPAuthBearerAuthExceptionDetails,
17+
)
18+
from ..types import VerifyAccessTokenFunction, Record
19+
20+
21+
class BearerAuthConfig(BaseModel):
22+
"""
23+
Configuration for the Bearer auth handler.
24+
25+
Attributes:
26+
issuer: The expected issuer of the access token.
27+
audience: The expected audience of the access token.
28+
required_scopes: An array of required scopes that the access token must have.
29+
show_error_details: Whether to show detailed error information in the response.
30+
"""
31+
32+
issuer: str
33+
audience: Optional[str] = None
34+
required_scopes: Optional[List[str]] = None
35+
show_error_details: bool = False
36+
37+
38+
def get_bearer_token_from_headers(headers: Headers) -> str:
39+
"""
40+
Extract the Bearer token from the request headers.
41+
42+
Args:
43+
headers: The HTTP request headers.
44+
45+
Returns:
46+
The Bearer token.
47+
48+
Raises:
49+
MCPAuthBearerAuthException: If the Authorization header is missing or invalid.
50+
"""
51+
52+
auth_header = headers.get("authorization") or headers.get("Authorization")
53+
54+
print(f"Authorization header: {auth_header}")
55+
56+
if not auth_header:
57+
raise MCPAuthBearerAuthException(BearerAuthExceptionCode.MISSING_AUTH_HEADER)
58+
59+
parts = auth_header.split(" ")
60+
if len(parts) != 2 or parts[0].lower() != "bearer":
61+
raise MCPAuthBearerAuthException(
62+
BearerAuthExceptionCode.INVALID_AUTH_HEADER_FORMAT
63+
)
64+
65+
token = parts[1]
66+
if not token:
67+
raise MCPAuthBearerAuthException(BearerAuthExceptionCode.MISSING_BEARER_TOKEN)
68+
69+
return token
70+
71+
72+
def _handle_error(
73+
error: Exception, show_error_details: bool = False
74+
) -> tuple[int, Dict[str, Any]]:
75+
"""
76+
Handle errors from the Bearer auth process.
77+
78+
Args:
79+
error: The exception that was caught.
80+
show_error_details: Whether to include detailed error information in the response.
81+
82+
Returns:
83+
A tuple of (status_code, response_body).
84+
"""
85+
if isinstance(error, MCPAuthJwtVerificationException):
86+
return 401, error.to_json(show_error_details)
87+
88+
if isinstance(error, MCPAuthBearerAuthException):
89+
if error.code == BearerAuthExceptionCode.MISSING_REQUIRED_SCOPES:
90+
return 403, error.to_json(show_error_details)
91+
return 401, error.to_json(show_error_details)
92+
93+
if isinstance(error, (MCPAuthAuthServerException, MCPAuthConfigException)):
94+
response: Record = {
95+
"error": "server_error",
96+
"error_description": "An error occurred with the authorization server.",
97+
}
98+
if show_error_details:
99+
response["cause"] = error.to_json()
100+
return 500, response
101+
102+
# Re-raise other errors
103+
raise error
104+
105+
106+
def create_bearer_auth(
107+
verify_access_token: VerifyAccessTokenFunction, config: BearerAuthConfig
108+
) -> type[BaseHTTPMiddleware]:
109+
"""
110+
Creates a middleware function for handling Bearer auth.
111+
112+
This middleware extracts the Bearer token from the `Authorization` header, verifies it using the
113+
provided `verify_access_token` function, and checks the issuer, audience, and required scopes.
114+
115+
Args:
116+
verify_access_token: A function that takes a Bearer token and returns an `AuthInfo` object.
117+
config: Configuration for the Bearer auth handler.
118+
119+
Returns:
120+
A middleware class that handles Bearer auth.
121+
"""
122+
123+
if not callable(verify_access_token):
124+
raise TypeError(
125+
"`verify_access_token` must be a function that takes a token and returns an `AuthInfo` object."
126+
)
127+
128+
try:
129+
result = urlparse(config.issuer)
130+
if not all([result.scheme, result.netloc]):
131+
raise ValueError("Invalid URL")
132+
except:
133+
raise TypeError("`issuer` must be a valid URL.")
134+
135+
class BearerAuthMiddleware(BaseHTTPMiddleware):
136+
"""
137+
Middleware class that handles Bearer auth.
138+
139+
This class is used to wrap the request handling process and apply Bearer auth checks.
140+
"""
141+
142+
async def dispatch(
143+
self, request: Request, call_next: RequestResponseEndpoint
144+
) -> Response:
145+
"""
146+
Dispatch method that processes the request and applies Bearer auth checks.
147+
148+
Args:
149+
request: The HTTP request.
150+
call_next: The next middleware or route handler to call.
151+
152+
Returns:
153+
The HTTP response after processing the request.
154+
"""
155+
try:
156+
token = get_bearer_token_from_headers(request.headers)
157+
auth_info = verify_access_token(token)
158+
159+
if auth_info.issuer != config.issuer:
160+
details = MCPAuthBearerAuthExceptionDetails(
161+
expected=config.issuer, actual=auth_info.issuer
162+
)
163+
raise MCPAuthBearerAuthException(
164+
BearerAuthExceptionCode.INVALID_ISSUER, cause=details
165+
)
166+
167+
if config.audience:
168+
audience_matches = (
169+
config.audience == auth_info.audience
170+
if isinstance(auth_info.audience, str)
171+
else (
172+
isinstance(auth_info.audience, list)
173+
and config.audience in auth_info.audience
174+
)
175+
)
176+
if not audience_matches:
177+
details = MCPAuthBearerAuthExceptionDetails(
178+
expected=config.audience, actual=auth_info.audience
179+
)
180+
raise MCPAuthBearerAuthException(
181+
BearerAuthExceptionCode.INVALID_AUDIENCE, cause=details
182+
)
183+
184+
if config.required_scopes:
185+
missing_scopes = [
186+
scope
187+
for scope in config.required_scopes
188+
if scope not in auth_info.scopes
189+
]
190+
if missing_scopes:
191+
details = MCPAuthBearerAuthExceptionDetails(
192+
missing_scopes=missing_scopes
193+
)
194+
raise MCPAuthBearerAuthException(
195+
BearerAuthExceptionCode.MISSING_REQUIRED_SCOPES,
196+
cause=details,
197+
)
198+
199+
# Attach auth info to the request
200+
request.state.auth = auth_info
201+
202+
# Call the next middleware or route handler
203+
response = await call_next(request)
204+
return response
205+
206+
except Exception as error:
207+
logging.error(f"Error during Bearer auth: {error}")
208+
status_code, response_body = _handle_error(
209+
error, config.show_error_details
210+
)
211+
return JSONResponse(status_code=status_code, content=response_body)
212+
213+
return BearerAuthMiddleware

0 commit comments

Comments
 (0)