Skip to content

Commit a1161ab

Browse files
committed
Get tests passing
1 parent 00145de commit a1161ab

File tree

5 files changed

+297
-70
lines changed

5 files changed

+297
-70
lines changed

src/mcp/server/auth/middleware/bearer_auth.py

Lines changed: 15 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010
from starlette.requests import HTTPConnection, Request
1111
from starlette.exceptions import HTTPException
12-
from starlette.authentication import AuthCredentials, AuthenticationBackend, AuthenticationError, BaseUser, SimpleUser, UnauthenticatedUser
12+
from starlette.authentication import AuthCredentials, AuthenticationBackend, AuthenticationError, BaseUser, SimpleUser, UnauthenticatedUser, has_required_scope
1313
from starlette.middleware.authentication import AuthenticationMiddleware
14+
from starlette.types import Scope
1415

1516
from mcp.server.auth.errors import InsufficientScopeError, InvalidTokenError, OAuthError
1617
from mcp.server.auth.provider import OAuthServerProvider
@@ -34,22 +35,12 @@ class BearerAuthBackend(AuthenticationBackend):
3435
def __init__(
3536
self,
3637
provider: OAuthServerProvider,
37-
required_scopes: Optional[List[str]] = None
3838
):
39-
"""
40-
Initialize the backend.
41-
42-
Args:
43-
provider: Authentication provider to validate tokens
44-
required_scopes: Optional list of scopes that the token must have
45-
"""
4639
self.provider = provider
47-
self.required_scopes = required_scopes or []
4840

4941
async def authenticate(self, conn: HTTPConnection):
5042

5143
if "Authorization" not in conn.headers:
52-
raise AuthenticationError()
5344
return None
5445

5546
auth_header = conn.headers["Authorization"]
@@ -61,14 +52,7 @@ async def authenticate(self, conn: HTTPConnection):
6152
try:
6253
# Validate the token with the provider
6354
auth_info = await self.provider.verify_access_token(token)
64-
65-
# Check if the token has all required scopes
66-
if self.required_scopes:
67-
has_all_scopes = all(scope in auth_info.scopes for scope in self.required_scopes)
68-
if not has_all_scopes:
69-
raise InsufficientScopeError("Insufficient scope")
70-
71-
# Check if the token is expired
55+
7256
if auth_info.expires_at and auth_info.expires_at < int(time.time()):
7357
raise InvalidTokenError("Token has expired")
7458

@@ -79,7 +63,7 @@ async def authenticate(self, conn: HTTPConnection):
7963
return None
8064

8165

82-
class BearerAuthMiddleware:
66+
class RequireAuthMiddleware:
8367
"""
8468
Middleware that requires a valid Bearer token in the Authorization header.
8569
@@ -92,8 +76,7 @@ class BearerAuthMiddleware:
9276
def __init__(
9377
self,
9478
app: Any,
95-
provider: OAuthServerProvider,
96-
required_scopes: Optional[List[str]] = None
79+
required_scopes: list[str]
9780
):
9881
"""
9982
Initialize the middleware.
@@ -103,18 +86,15 @@ def __init__(
10386
provider: Authentication provider to validate tokens
10487
required_scopes: Optional list of scopes that the token must have
10588
"""
106-
self.app = AuthenticationMiddleware(
107-
app,
108-
backend=BearerAuthBackend(provider, required_scopes)
109-
)
110-
111-
async def __call__(self, scope: Dict, receive: Callable, send: Callable) -> None:
112-
"""
113-
Process the request and validate the bearer token.
89+
self.app = app
90+
self.required_scopes = required_scopes
91+
92+
async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None:
93+
auth_credentials = scope.get('auth')
11494

115-
Args:
116-
scope: ASGI scope
117-
receive: ASGI receive function
118-
send: ASGI send function
119-
"""
95+
for required_scope in self.required_scopes:
96+
# auth_credentials should always be provided; this is just paranoia
97+
if auth_credentials is None or required_scope not in auth_credentials.scopes:
98+
raise HTTPException(status_code=403, detail="Insufficient scope")
99+
120100
await self.app(scope, receive, send)

src/mcp/server/fastmcp/server.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
from starlette.applications import Starlette
1717
from starlette.authentication import requires
1818
from starlette.middleware.authentication import AuthenticationMiddleware
19+
from sse_starlette import EventSourceResponse
1920
import uvicorn
2021
from pydantic import BaseModel, Field
2122
from pydantic.networks import AnyUrl
2223
from pydantic_settings import BaseSettings, SettingsConfigDict
2324

24-
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
25+
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware
2526
from mcp.server.auth.provider import OAuthServerProvider
2627
from mcp.server.auth.router import ClientRegistrationOptions, RevocationOptions
2728
from mcp.server.auth.types import AuthInfo
@@ -501,7 +502,7 @@ def starlette_app(self) -> Starlette:
501502
# Set up auth context and dependencies
502503

503504
sse = SseServerTransport("/messages/")
504-
async def handle_sse(request):
505+
async def handle_sse(request) -> EventSourceResponse:
505506
# Add client ID from auth context into request context if available
506507
request_meta = {}
507508

@@ -513,39 +514,41 @@ async def handle_sse(request):
513514
streams[1],
514515
self._mcp_server.create_initialization_options(),
515516
)
517+
return streams[2]
516518

517519
# Create routes
518520
routes = []
519521
middleware = []
520522
required_scopes = self.settings.auth_required_scopes or []
523+
auth_router = None
521524

522525
# Add auth endpoints if auth provider is configured
523526
if self._auth_provider and self.settings.auth_issuer_url:
524527
from mcp.server.auth.router import create_auth_router
525-
if "authenticated" not in required_scopes:
526-
required_scopes.append("authenticated")
527528

528529
# Set up bearer auth middleware if auth is required
529530
middleware = [
530531
Middleware(
531532
AuthenticationMiddleware,
532533
backend=BearerAuthBackend(
533534
provider=self._auth_provider,
534-
required_scopes=self.settings.auth_required_scopes
535535
)
536536
)
537537
]
538538
auth_router = create_auth_router(
539-
self._auth_provider,
540-
self.settings.auth_issuer_url,
541-
self.settings.auth_service_documentation_url
539+
provider=self._auth_provider,
540+
issuer_url=self.settings.auth_issuer_url,
541+
service_documentation_url=self.settings.auth_service_documentation_url,
542+
client_registration_options=self.settings.auth_client_registration_options,
543+
revocation_options=self.settings.auth_revocation_options
542544
)
543545

544546
# Add the auth router as a mount
545-
routes.append(Mount("/", app=auth_router))
546547

547548
routes.append(Route("/sse", endpoint=requires(required_scopes)(handle_sse), methods=["GET"]))
548-
routes.append(Mount("/messages/", app=requires(required_scopes)(sse.handle_post_message)))
549+
routes.append(Mount("/messages/", app=RequireAuthMiddleware(sse.handle_post_message, required_scopes)))
550+
if auth_router:
551+
routes.append(Mount("/", app=auth_router))
549552

550553
# Create Starlette app with routes and middleware
551554
return Starlette(

src/mcp/server/sse.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ async def handle_sse(request):
3434
import logging
3535
from contextlib import asynccontextmanager
3636
from typing import Any
37+
from typing_extensions import deprecated
3738
from urllib.parse import quote
3839
from uuid import UUID, uuid4
3940

@@ -44,6 +45,7 @@ async def handle_sse(request):
4445
from starlette.requests import Request
4546
from starlette.responses import Response
4647
from starlette.types import Receive, Scope, Send
48+
from sse_starlette import EventSourceResponse
4749

4850
import mcp.types as types
4951

@@ -78,6 +80,7 @@ def __init__(self, endpoint: str) -> None:
7880
self._read_stream_writers = {}
7981
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
8082

83+
@deprecated("use connect_sse_v2 instead")
8184
@asynccontextmanager
8285
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
8386
if scope["type"] != "http":
@@ -128,7 +131,11 @@ async def sse_writer():
128131
tg.start_soon(response, scope, receive, send)
129132

130133
logger.debug("Yielding read and write streams")
131-
yield (read_stream, write_stream)
134+
# TODO: hold on; shouldn't we be returning the EventSourceResponse?
135+
# I think this is why the tests hang
136+
# TODO: we probably shouldn't return response here, since it's a breaking change
137+
# this is just to test
138+
yield (read_stream, write_stream, response)
132139

133140
async def handle_post_message(
134141
self, scope: Scope, receive: Receive, send: Send

0 commit comments

Comments
 (0)