Skip to content

Commit c6f991b

Browse files
committed
Convert AuthContextMiddleware to plain ASGI middleware & add tests
1 parent 87571d8 commit c6f991b

File tree

3 files changed

+84
-72
lines changed

3 files changed

+84
-72
lines changed

src/mcp/server/auth/middleware/auth_context.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import contextvars
22

3-
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
4-
from starlette.requests import Request
5-
from starlette.responses import Response
3+
from starlette.types import ASGIApp, Receive, Scope, Send
64

75
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
86
from mcp.server.auth.provider import AccessToken
@@ -25,7 +23,7 @@ def get_access_token() -> AccessToken | None:
2523
return auth_user.access_token if auth_user else None
2624

2725

28-
class AuthContextMiddleware(BaseHTTPMiddleware):
26+
class AuthContextMiddleware:
2927
"""
3028
Middleware that extracts the authenticated user from the request
3129
and sets it in a contextvar for easy access throughout the request lifecycle.
@@ -35,23 +33,18 @@ class AuthContextMiddleware(BaseHTTPMiddleware):
3533
being stored in the context.
3634
"""
3735

38-
async def dispatch(
39-
self, request: Request, call_next: RequestResponseEndpoint
40-
) -> Response:
41-
# Get the authenticated user from the request if it exists
42-
user = getattr(request, "user", None)
36+
def __init__(self, app: ASGIApp):
37+
self.app = app
4338

44-
# Only set the context var if the user is an AuthenticatedUser
39+
async def __call__(self, scope: Scope, receive: Receive, send: Send):
40+
user = scope.get("user")
4541
if isinstance(user, AuthenticatedUser):
4642
# Set the authenticated user in the contextvar
4743
token = auth_context_var.set(user)
4844
try:
49-
# Process the request
50-
response = await call_next(request)
51-
return response
45+
await self.app(scope, receive, send)
5246
finally:
53-
# Reset the contextvar after the request is processed
5447
auth_context_var.reset(token)
5548
else:
5649
# No authenticated user, just process the request
57-
return await call_next(request)
50+
await self.app(scope, receive, send)

src/mcp/server/fastmcp/server.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from pydantic_settings import BaseSettings, SettingsConfigDict
2222
from sse_starlette import EventSourceResponse
2323
from starlette.applications import Starlette
24-
from starlette.authentication import requires
2524
from starlette.middleware import Middleware
2625
from starlette.middleware.authentication import AuthenticationMiddleware
2726
from starlette.requests import Request
@@ -586,7 +585,9 @@ async def handle_sse(request: Request) -> EventSourceResponse:
586585
routes.append(
587586
Route(
588587
self.settings.sse_path,
589-
endpoint=RequireAuthMiddleware(request_response(handle_sse), required_scopes),
588+
endpoint=RequireAuthMiddleware(
589+
request_response(handle_sse), required_scopes
590+
),
590591
methods=["GET"],
591592
)
592593
)

tests/server/auth/middleware/test_bearer_auth.py

Lines changed: 73 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
"""
44

55
import time
6-
from typing import Any, Dict, List, Optional, cast
6+
from typing import Any, cast
77

88
import pytest
99
from starlette.authentication import AuthCredentials
1010
from starlette.exceptions import HTTPException
1111
from starlette.requests import Request
12-
from starlette.types import ASGIApp, Message, Receive, Scope, Send
12+
from starlette.types import Message, Receive, Scope, Send
1313

1414
from mcp.server.auth.middleware.bearer_auth import (
1515
AuthenticatedUser,
@@ -24,7 +24,7 @@
2424

2525
class MockOAuthProvider:
2626
"""Mock OAuth provider for testing.
27-
27+
2828
This is a simplified version that only implements the methods needed for testing
2929
the BearerAuthMiddleware components.
3030
"""
@@ -36,14 +36,16 @@ def add_token(self, token: str, access_token: AccessToken) -> None:
3636
"""Add a token to the provider."""
3737
self.tokens[token] = access_token
3838

39-
async def load_access_token(self, token: str) -> Optional[AccessToken]:
39+
async def load_access_token(self, token: str) -> AccessToken | None:
4040
"""Load an access token."""
4141
return self.tokens.get(token)
4242

4343

44-
def add_token_to_provider(provider: OAuthServerProvider[Any, Any, Any], token: str, access_token: AccessToken) -> None:
44+
def add_token_to_provider(
45+
provider: OAuthServerProvider[Any, Any, Any], token: str, access_token: AccessToken
46+
) -> None:
4547
"""Helper function to add a token to a provider.
46-
48+
4749
This is used to work around type checking issues with our mock provider.
4850
"""
4951
# We know this is actually a MockOAuthProvider
@@ -56,9 +58,9 @@ class MockApp:
5658

5759
def __init__(self):
5860
self.called = False
59-
self.scope: Optional[Scope] = None
60-
self.receive: Optional[Receive] = None
61-
self.send: Optional[Send] = None
61+
self.scope: Scope | None = None
62+
self.receive: Receive | None = None
63+
self.send: Send | None = None
6264

6365
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6466
self.called = True
@@ -111,14 +113,18 @@ def no_expiry_access_token() -> AccessToken:
111113
class TestBearerAuthBackend:
112114
"""Tests for the BearerAuthBackend class."""
113115

114-
async def test_no_auth_header(self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any]):
116+
async def test_no_auth_header(
117+
self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any]
118+
):
115119
"""Test authentication with no Authorization header."""
116120
backend = BearerAuthBackend(provider=mock_oauth_provider)
117121
request = Request({"type": "http", "headers": []})
118122
result = await backend.authenticate(request)
119123
assert result is None
120124

121-
async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any]):
125+
async def test_non_bearer_auth_header(
126+
self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any]
127+
):
122128
"""Test authentication with non-Bearer Authorization header."""
123129
backend = BearerAuthBackend(provider=mock_oauth_provider)
124130
request = Request(
@@ -130,7 +136,9 @@ async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthServerProv
130136
result = await backend.authenticate(request)
131137
assert result is None
132138

133-
async def test_invalid_token(self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any]):
139+
async def test_invalid_token(
140+
self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any]
141+
):
134142
"""Test authentication with invalid token."""
135143
backend = BearerAuthBackend(provider=mock_oauth_provider)
136144
request = Request(
@@ -143,11 +151,15 @@ async def test_invalid_token(self, mock_oauth_provider: OAuthServerProvider[Any,
143151
assert result is None
144152

145153
async def test_expired_token(
146-
self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any], expired_access_token: AccessToken
154+
self,
155+
mock_oauth_provider: OAuthServerProvider[Any, Any, Any],
156+
expired_access_token: AccessToken,
147157
):
148158
"""Test authentication with expired token."""
149159
backend = BearerAuthBackend(provider=mock_oauth_provider)
150-
add_token_to_provider(mock_oauth_provider, "expired_token", expired_access_token)
160+
add_token_to_provider(
161+
mock_oauth_provider, "expired_token", expired_access_token
162+
)
151163
request = Request(
152164
{
153165
"type": "http",
@@ -158,7 +170,9 @@ async def test_expired_token(
158170
assert result is None
159171

160172
async def test_valid_token(
161-
self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any], valid_access_token: AccessToken
173+
self,
174+
mock_oauth_provider: OAuthServerProvider[Any, Any, Any],
175+
valid_access_token: AccessToken,
162176
):
163177
"""Test authentication with valid token."""
164178
backend = BearerAuthBackend(provider=mock_oauth_provider)
@@ -180,11 +194,15 @@ async def test_valid_token(
180194
assert user.scopes == ["read", "write"]
181195

182196
async def test_token_without_expiry(
183-
self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any], no_expiry_access_token: AccessToken
197+
self,
198+
mock_oauth_provider: OAuthServerProvider[Any, Any, Any],
199+
no_expiry_access_token: AccessToken,
184200
):
185201
"""Test authentication with token that has no expiry."""
186202
backend = BearerAuthBackend(provider=mock_oauth_provider)
187-
add_token_to_provider(mock_oauth_provider, "no_expiry_token", no_expiry_access_token)
203+
add_token_to_provider(
204+
mock_oauth_provider, "no_expiry_token", no_expiry_access_token
205+
)
188206
request = Request(
189207
{
190208
"type": "http",
@@ -211,17 +229,17 @@ async def test_no_user(self):
211229
app = MockApp()
212230
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
213231
scope: Scope = {"type": "http"}
214-
232+
215233
# Create dummy async functions for receive and send
216234
async def receive() -> Message:
217235
return {"type": "http.request"}
218-
236+
219237
async def send(message: Message) -> None:
220238
pass
221-
239+
222240
with pytest.raises(HTTPException) as excinfo:
223241
await middleware(scope, receive, send)
224-
242+
225243
assert excinfo.value.status_code == 401
226244
assert excinfo.value.detail == "Unauthorized"
227245
assert not app.called
@@ -231,17 +249,17 @@ async def test_non_authenticated_user(self):
231249
app = MockApp()
232250
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
233251
scope: Scope = {"type": "http", "user": object()}
234-
252+
235253
# Create dummy async functions for receive and send
236254
async def receive() -> Message:
237255
return {"type": "http.request"}
238-
256+
239257
async def send(message: Message) -> None:
240258
pass
241-
259+
242260
with pytest.raises(HTTPException) as excinfo:
243261
await middleware(scope, receive, send)
244-
262+
245263
assert excinfo.value.status_code == 401
246264
assert excinfo.value.detail == "Unauthorized"
247265
assert not app.called
@@ -250,23 +268,23 @@ async def test_missing_required_scope(self, valid_access_token: AccessToken):
250268
"""Test middleware with user missing required scope."""
251269
app = MockApp()
252270
middleware = RequireAuthMiddleware(app, required_scopes=["admin"])
253-
271+
254272
# Create a user with read/write scopes but not admin
255273
user = AuthenticatedUser(valid_access_token)
256274
auth = AuthCredentials(["read", "write"])
257-
275+
258276
scope: Scope = {"type": "http", "user": user, "auth": auth}
259-
277+
260278
# Create dummy async functions for receive and send
261279
async def receive() -> Message:
262280
return {"type": "http.request"}
263-
281+
264282
async def send(message: Message) -> None:
265283
pass
266-
284+
267285
with pytest.raises(HTTPException) as excinfo:
268286
await middleware(scope, receive, send)
269-
287+
270288
assert excinfo.value.status_code == 403
271289
assert excinfo.value.detail == "Insufficient scope"
272290
assert not app.called
@@ -275,22 +293,22 @@ async def test_no_auth_credentials(self, valid_access_token: AccessToken):
275293
"""Test middleware with no auth credentials in scope."""
276294
app = MockApp()
277295
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
278-
296+
279297
# Create a user with read/write scopes
280298
user = AuthenticatedUser(valid_access_token)
281-
299+
282300
scope: Scope = {"type": "http", "user": user} # No auth credentials
283-
301+
284302
# Create dummy async functions for receive and send
285303
async def receive() -> Message:
286304
return {"type": "http.request"}
287-
305+
288306
async def send(message: Message) -> None:
289307
pass
290-
308+
291309
with pytest.raises(HTTPException) as excinfo:
292310
await middleware(scope, receive, send)
293-
311+
294312
assert excinfo.value.status_code == 403
295313
assert excinfo.value.detail == "Insufficient scope"
296314
assert not app.called
@@ -299,22 +317,22 @@ async def test_has_required_scopes(self, valid_access_token: AccessToken):
299317
"""Test middleware with user having all required scopes."""
300318
app = MockApp()
301319
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
302-
320+
303321
# Create a user with read/write scopes
304322
user = AuthenticatedUser(valid_access_token)
305323
auth = AuthCredentials(["read", "write"])
306-
324+
307325
scope: Scope = {"type": "http", "user": user, "auth": auth}
308-
326+
309327
# Create dummy async functions for receive and send
310328
async def receive() -> Message:
311329
return {"type": "http.request"}
312-
330+
313331
async def send(message: Message) -> None:
314332
pass
315-
333+
316334
await middleware(scope, receive, send)
317-
335+
318336
assert app.called
319337
assert app.scope == scope
320338
assert app.receive == receive
@@ -324,22 +342,22 @@ async def test_multiple_required_scopes(self, valid_access_token: AccessToken):
324342
"""Test middleware with multiple required scopes."""
325343
app = MockApp()
326344
middleware = RequireAuthMiddleware(app, required_scopes=["read", "write"])
327-
345+
328346
# Create a user with read/write scopes
329347
user = AuthenticatedUser(valid_access_token)
330348
auth = AuthCredentials(["read", "write"])
331-
349+
332350
scope: Scope = {"type": "http", "user": user, "auth": auth}
333-
351+
334352
# Create dummy async functions for receive and send
335353
async def receive() -> Message:
336354
return {"type": "http.request"}
337-
355+
338356
async def send(message: Message) -> None:
339357
pass
340-
358+
341359
await middleware(scope, receive, send)
342-
360+
343361
assert app.called
344362
assert app.scope == scope
345363
assert app.receive == receive
@@ -349,22 +367,22 @@ async def test_no_required_scopes(self, valid_access_token: AccessToken):
349367
"""Test middleware with no required scopes."""
350368
app = MockApp()
351369
middleware = RequireAuthMiddleware(app, required_scopes=[])
352-
370+
353371
# Create a user with read/write scopes
354372
user = AuthenticatedUser(valid_access_token)
355373
auth = AuthCredentials(["read", "write"])
356-
374+
357375
scope: Scope = {"type": "http", "user": user, "auth": auth}
358-
376+
359377
# Create dummy async functions for receive and send
360378
async def receive() -> Message:
361379
return {"type": "http.request"}
362-
380+
363381
async def send(message: Message) -> None:
364382
pass
365-
383+
366384
await middleware(scope, receive, send)
367-
385+
368386
assert app.called
369387
assert app.scope == scope
370388
assert app.receive == receive

0 commit comments

Comments
 (0)