Skip to content

Commit 7e70971

Browse files
praboud-antdsp-ant
authored andcommitted
Get tests passing
1 parent f8ac479 commit 7e70971

File tree

5 files changed

+297
-70
lines changed

5 files changed

+297
-70
lines changed

src/mcp/server/auth/middleware/bearer_auth.py

Lines changed: 15 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010
from starlette.requests import HTTPConnection, Request
1111
from starlette.exceptions import HTTPException
12-
from starlette.authentication import AuthCredentials, AuthenticationBackend, AuthenticationError, BaseUser, SimpleUser, UnauthenticatedUser
12+
from starlette.authentication import AuthCredentials, AuthenticationBackend, AuthenticationError, BaseUser, SimpleUser, UnauthenticatedUser, has_required_scope
1313
from starlette.middleware.authentication import AuthenticationMiddleware
14+
from starlette.types import Scope
1415

1516
from mcp.server.auth.errors import InsufficientScopeError, InvalidTokenError, OAuthError
1617
from mcp.server.auth.provider import OAuthServerProvider
@@ -34,22 +35,12 @@ class BearerAuthBackend(AuthenticationBackend):
3435
def __init__(
3536
self,
3637
provider: OAuthServerProvider,
37-
required_scopes: Optional[List[str]] = None
3838
):
39-
"""
40-
Initialize the backend.
41-
42-
Args:
43-
provider: Authentication provider to validate tokens
44-
required_scopes: Optional list of scopes that the token must have
45-
"""
4639
self.provider = provider
47-
self.required_scopes = required_scopes or []
4840

4941
async def authenticate(self, conn: HTTPConnection):
5042

5143
if "Authorization" not in conn.headers:
52-
raise AuthenticationError()
5344
return None
5445

5546
auth_header = conn.headers["Authorization"]
@@ -61,14 +52,7 @@ async def authenticate(self, conn: HTTPConnection):
6152
try:
6253
# Validate the token with the provider
6354
auth_info = await self.provider.verify_access_token(token)
64-
65-
# Check if the token has all required scopes
66-
if self.required_scopes:
67-
has_all_scopes = all(scope in auth_info.scopes for scope in self.required_scopes)
68-
if not has_all_scopes:
69-
raise InsufficientScopeError("Insufficient scope")
70-
71-
# Check if the token is expired
55+
7256
if auth_info.expires_at and auth_info.expires_at < int(time.time()):
7357
raise InvalidTokenError("Token has expired")
7458

@@ -79,7 +63,7 @@ async def authenticate(self, conn: HTTPConnection):
7963
return None
8064

8165

82-
class BearerAuthMiddleware:
66+
class RequireAuthMiddleware:
8367
"""
8468
Middleware that requires a valid Bearer token in the Authorization header.
8569
@@ -92,8 +76,7 @@ class BearerAuthMiddleware:
9276
def __init__(
9377
self,
9478
app: Any,
95-
provider: OAuthServerProvider,
96-
required_scopes: Optional[List[str]] = None
79+
required_scopes: list[str]
9780
):
9881
"""
9982
Initialize the middleware.
@@ -103,18 +86,15 @@ def __init__(
10386
provider: Authentication provider to validate tokens
10487
required_scopes: Optional list of scopes that the token must have
10588
"""
106-
self.app = AuthenticationMiddleware(
107-
app,
108-
backend=BearerAuthBackend(provider, required_scopes)
109-
)
110-
111-
async def __call__(self, scope: Dict, receive: Callable, send: Callable) -> None:
112-
"""
113-
Process the request and validate the bearer token.
89+
self.app = app
90+
self.required_scopes = required_scopes
91+
92+
async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None:
93+
auth_credentials = scope.get('auth')
11494

115-
Args:
116-
scope: ASGI scope
117-
receive: ASGI receive function
118-
send: ASGI send function
119-
"""
95+
for required_scope in self.required_scopes:
96+
# auth_credentials should always be provided; this is just paranoia
97+
if auth_credentials is None or required_scope not in auth_credentials.scopes:
98+
raise HTTPException(status_code=403, detail="Insufficient scope")
99+
120100
await self.app(scope, receive, send)

src/mcp/server/fastmcp/server.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
from starlette.applications import Starlette
1919
from starlette.authentication import requires
2020
from starlette.middleware.authentication import AuthenticationMiddleware
21+
from sse_starlette import EventSourceResponse
2122
import uvicorn
2223
from pydantic import BaseModel, Field
2324
from pydantic.networks import AnyUrl
2425
from pydantic_settings import BaseSettings, SettingsConfigDict
2526

26-
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
27+
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware
2728
from mcp.server.auth.provider import OAuthServerProvider
2829
from mcp.server.auth.router import ClientRegistrationOptions, RevocationOptions
2930
from mcp.server.auth.types import AuthInfo
@@ -487,7 +488,7 @@ def starlette_app(self) -> Starlette:
487488
# Set up auth context and dependencies
488489

489490
sse = SseServerTransport("/messages/")
490-
async def handle_sse(request):
491+
async def handle_sse(request) -> EventSourceResponse:
491492
# Add client ID from auth context into request context if available
492493
request_meta = {}
493494

@@ -499,39 +500,41 @@ async def handle_sse(request):
499500
streams[1],
500501
self._mcp_server.create_initialization_options(),
501502
)
503+
return streams[2]
502504

503505
# Create routes
504506
routes = []
505507
middleware = []
506508
required_scopes = self.settings.auth_required_scopes or []
509+
auth_router = None
507510

508511
# Add auth endpoints if auth provider is configured
509512
if self._auth_provider and self.settings.auth_issuer_url:
510513
from mcp.server.auth.router import create_auth_router
511-
if "authenticated" not in required_scopes:
512-
required_scopes.append("authenticated")
513514

514515
# Set up bearer auth middleware if auth is required
515516
middleware = [
516517
Middleware(
517518
AuthenticationMiddleware,
518519
backend=BearerAuthBackend(
519520
provider=self._auth_provider,
520-
required_scopes=self.settings.auth_required_scopes
521521
)
522522
)
523523
]
524524
auth_router = create_auth_router(
525-
self._auth_provider,
526-
self.settings.auth_issuer_url,
527-
self.settings.auth_service_documentation_url
525+
provider=self._auth_provider,
526+
issuer_url=self.settings.auth_issuer_url,
527+
service_documentation_url=self.settings.auth_service_documentation_url,
528+
client_registration_options=self.settings.auth_client_registration_options,
529+
revocation_options=self.settings.auth_revocation_options
528530
)
529531

530532
# Add the auth router as a mount
531-
routes.append(Mount("/", app=auth_router))
532533

533534
routes.append(Route("/sse", endpoint=requires(required_scopes)(handle_sse), methods=["GET"]))
534-
routes.append(Mount("/messages/", app=requires(required_scopes)(sse.handle_post_message)))
535+
routes.append(Mount("/messages/", app=RequireAuthMiddleware(sse.handle_post_message, required_scopes)))
536+
if auth_router:
537+
routes.append(Mount("/", app=auth_router))
535538

536539
# Create Starlette app with routes and middleware
537540
return Starlette(

src/mcp/server/sse.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ async def handle_sse(request):
3434
import logging
3535
from contextlib import asynccontextmanager
3636
from typing import Any
37+
from typing_extensions import deprecated
3738
from urllib.parse import quote
3839
from uuid import UUID, uuid4
3940

@@ -44,6 +45,7 @@ async def handle_sse(request):
4445
from starlette.requests import Request
4546
from starlette.responses import Response
4647
from starlette.types import Receive, Scope, Send
48+
from sse_starlette import EventSourceResponse
4749

4850
import mcp.types as types
4951

@@ -78,6 +80,7 @@ def __init__(self, endpoint: str) -> None:
7880
self._read_stream_writers = {}
7981
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
8082

83+
@deprecated("use connect_sse_v2 instead")
8184
@asynccontextmanager
8285
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
8386
if scope["type"] != "http":
@@ -128,7 +131,11 @@ async def sse_writer():
128131
tg.start_soon(response, scope, receive, send)
129132

130133
logger.debug("Yielding read and write streams")
131-
yield (read_stream, write_stream)
134+
# TODO: hold on; shouldn't we be returning the EventSourceResponse?
135+
# I think this is why the tests hang
136+
# TODO: we probably shouldn't return response here, since it's a breaking change
137+
# this is just to test
138+
yield (read_stream, write_stream, response)
132139

133140
async def handle_post_message(
134141
self, scope: Scope, receive: Receive, send: Send

0 commit comments

Comments
 (0)