Skip to content

Commit 87571d8

Browse files
committed
Return 401 on missing auth, not 403
1 parent 10e00e7 commit 87571d8

File tree

4 files changed

+381
-9
lines changed

4 files changed

+381
-9
lines changed

src/mcp/server/auth/middleware/bearer_auth.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ def __init__(self, app: Any, required_scopes: list[str]):
7373
self.required_scopes = required_scopes
7474

7575
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
76+
auth_user = scope.get("user")
77+
if not isinstance(auth_user, AuthenticatedUser):
78+
raise HTTPException(status_code=401, detail="Unauthorized")
7679
auth_credentials = scope.get("auth")
7780

7881
for required_scope in self.required_scopes:

src/mcp/server/fastmcp/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from starlette.middleware.authentication import AuthenticationMiddleware
2727
from starlette.requests import Request
2828
from starlette.responses import Response
29-
from starlette.routing import Mount, Route
29+
from starlette.routing import Mount, Route, request_response
3030

3131
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
3232
from mcp.server.auth.middleware.bearer_auth import (
@@ -586,7 +586,7 @@ async def handle_sse(request: Request) -> EventSourceResponse:
586586
routes.append(
587587
Route(
588588
self.settings.sse_path,
589-
endpoint=requires(required_scopes)(handle_sse),
589+
endpoint=RequireAuthMiddleware(request_response(handle_sse), required_scopes),
590590
methods=["GET"],
591591
)
592592
)
Lines changed: 371 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,371 @@
1+
"""
2+
Tests for the BearerAuth middleware components.
3+
"""
4+
5+
import time
6+
from typing import Any, Dict, List, Optional, cast
7+
8+
import pytest
9+
from starlette.authentication import AuthCredentials
10+
from starlette.exceptions import HTTPException
11+
from starlette.requests import Request
12+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
13+
14+
from mcp.server.auth.middleware.bearer_auth import (
15+
AuthenticatedUser,
16+
BearerAuthBackend,
17+
RequireAuthMiddleware,
18+
)
19+
from mcp.server.auth.provider import (
20+
AccessToken,
21+
OAuthServerProvider,
22+
)
23+
24+
25+
class MockOAuthProvider:
26+
"""Mock OAuth provider for testing.
27+
28+
This is a simplified version that only implements the methods needed for testing
29+
the BearerAuthMiddleware components.
30+
"""
31+
32+
def __init__(self):
33+
self.tokens = {} # token -> AccessToken
34+
35+
def add_token(self, token: str, access_token: AccessToken) -> None:
36+
"""Add a token to the provider."""
37+
self.tokens[token] = access_token
38+
39+
async def load_access_token(self, token: str) -> Optional[AccessToken]:
40+
"""Load an access token."""
41+
return self.tokens.get(token)
42+
43+
44+
def add_token_to_provider(provider: OAuthServerProvider[Any, Any, Any], token: str, access_token: AccessToken) -> None:
45+
"""Helper function to add a token to a provider.
46+
47+
This is used to work around type checking issues with our mock provider.
48+
"""
49+
# We know this is actually a MockOAuthProvider
50+
mock_provider = cast(MockOAuthProvider, provider)
51+
mock_provider.add_token(token, access_token)
52+
53+
54+
class MockApp:
55+
"""Mock ASGI app for testing."""
56+
57+
def __init__(self):
58+
self.called = False
59+
self.scope: Optional[Scope] = None
60+
self.receive: Optional[Receive] = None
61+
self.send: Optional[Send] = None
62+
63+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
64+
self.called = True
65+
self.scope = scope
66+
self.receive = receive
67+
self.send = send
68+
69+
70+
@pytest.fixture
71+
def mock_oauth_provider() -> OAuthServerProvider[Any, Any, Any]:
72+
"""Create a mock OAuth provider."""
73+
# Use type casting to satisfy the type checker
74+
return cast(OAuthServerProvider[Any, Any, Any], MockOAuthProvider())
75+
76+
77+
@pytest.fixture
78+
def valid_access_token() -> AccessToken:
79+
"""Create a valid access token."""
80+
return AccessToken(
81+
token="valid_token",
82+
client_id="test_client",
83+
scopes=["read", "write"],
84+
expires_at=int(time.time()) + 3600, # 1 hour from now
85+
)
86+
87+
88+
@pytest.fixture
89+
def expired_access_token() -> AccessToken:
90+
"""Create an expired access token."""
91+
return AccessToken(
92+
token="expired_token",
93+
client_id="test_client",
94+
scopes=["read"],
95+
expires_at=int(time.time()) - 3600, # 1 hour ago
96+
)
97+
98+
99+
@pytest.fixture
100+
def no_expiry_access_token() -> AccessToken:
101+
"""Create an access token with no expiry."""
102+
return AccessToken(
103+
token="no_expiry_token",
104+
client_id="test_client",
105+
scopes=["read", "write"],
106+
expires_at=None,
107+
)
108+
109+
110+
@pytest.mark.anyio
111+
class TestBearerAuthBackend:
112+
"""Tests for the BearerAuthBackend class."""
113+
114+
async def test_no_auth_header(self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any]):
115+
"""Test authentication with no Authorization header."""
116+
backend = BearerAuthBackend(provider=mock_oauth_provider)
117+
request = Request({"type": "http", "headers": []})
118+
result = await backend.authenticate(request)
119+
assert result is None
120+
121+
async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any]):
122+
"""Test authentication with non-Bearer Authorization header."""
123+
backend = BearerAuthBackend(provider=mock_oauth_provider)
124+
request = Request(
125+
{
126+
"type": "http",
127+
"headers": [(b"authorization", b"Basic dXNlcjpwYXNz")],
128+
}
129+
)
130+
result = await backend.authenticate(request)
131+
assert result is None
132+
133+
async def test_invalid_token(self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any]):
134+
"""Test authentication with invalid token."""
135+
backend = BearerAuthBackend(provider=mock_oauth_provider)
136+
request = Request(
137+
{
138+
"type": "http",
139+
"headers": [(b"authorization", b"Bearer invalid_token")],
140+
}
141+
)
142+
result = await backend.authenticate(request)
143+
assert result is None
144+
145+
async def test_expired_token(
146+
self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any], expired_access_token: AccessToken
147+
):
148+
"""Test authentication with expired token."""
149+
backend = BearerAuthBackend(provider=mock_oauth_provider)
150+
add_token_to_provider(mock_oauth_provider, "expired_token", expired_access_token)
151+
request = Request(
152+
{
153+
"type": "http",
154+
"headers": [(b"authorization", b"Bearer expired_token")],
155+
}
156+
)
157+
result = await backend.authenticate(request)
158+
assert result is None
159+
160+
async def test_valid_token(
161+
self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any], valid_access_token: AccessToken
162+
):
163+
"""Test authentication with valid token."""
164+
backend = BearerAuthBackend(provider=mock_oauth_provider)
165+
add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token)
166+
request = Request(
167+
{
168+
"type": "http",
169+
"headers": [(b"authorization", b"Bearer valid_token")],
170+
}
171+
)
172+
result = await backend.authenticate(request)
173+
assert result is not None
174+
credentials, user = result
175+
assert isinstance(credentials, AuthCredentials)
176+
assert isinstance(user, AuthenticatedUser)
177+
assert credentials.scopes == ["read", "write"]
178+
assert user.display_name == "test_client"
179+
assert user.access_token == valid_access_token
180+
assert user.scopes == ["read", "write"]
181+
182+
async def test_token_without_expiry(
183+
self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any], no_expiry_access_token: AccessToken
184+
):
185+
"""Test authentication with token that has no expiry."""
186+
backend = BearerAuthBackend(provider=mock_oauth_provider)
187+
add_token_to_provider(mock_oauth_provider, "no_expiry_token", no_expiry_access_token)
188+
request = Request(
189+
{
190+
"type": "http",
191+
"headers": [(b"authorization", b"Bearer no_expiry_token")],
192+
}
193+
)
194+
result = await backend.authenticate(request)
195+
assert result is not None
196+
credentials, user = result
197+
assert isinstance(credentials, AuthCredentials)
198+
assert isinstance(user, AuthenticatedUser)
199+
assert credentials.scopes == ["read", "write"]
200+
assert user.display_name == "test_client"
201+
assert user.access_token == no_expiry_access_token
202+
assert user.scopes == ["read", "write"]
203+
204+
205+
@pytest.mark.anyio
206+
class TestRequireAuthMiddleware:
207+
"""Tests for the RequireAuthMiddleware class."""
208+
209+
async def test_no_user(self):
210+
"""Test middleware with no user in scope."""
211+
app = MockApp()
212+
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
213+
scope: Scope = {"type": "http"}
214+
215+
# Create dummy async functions for receive and send
216+
async def receive() -> Message:
217+
return {"type": "http.request"}
218+
219+
async def send(message: Message) -> None:
220+
pass
221+
222+
with pytest.raises(HTTPException) as excinfo:
223+
await middleware(scope, receive, send)
224+
225+
assert excinfo.value.status_code == 401
226+
assert excinfo.value.detail == "Unauthorized"
227+
assert not app.called
228+
229+
async def test_non_authenticated_user(self):
230+
"""Test middleware with non-authenticated user in scope."""
231+
app = MockApp()
232+
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
233+
scope: Scope = {"type": "http", "user": object()}
234+
235+
# Create dummy async functions for receive and send
236+
async def receive() -> Message:
237+
return {"type": "http.request"}
238+
239+
async def send(message: Message) -> None:
240+
pass
241+
242+
with pytest.raises(HTTPException) as excinfo:
243+
await middleware(scope, receive, send)
244+
245+
assert excinfo.value.status_code == 401
246+
assert excinfo.value.detail == "Unauthorized"
247+
assert not app.called
248+
249+
async def test_missing_required_scope(self, valid_access_token: AccessToken):
250+
"""Test middleware with user missing required scope."""
251+
app = MockApp()
252+
middleware = RequireAuthMiddleware(app, required_scopes=["admin"])
253+
254+
# Create a user with read/write scopes but not admin
255+
user = AuthenticatedUser(valid_access_token)
256+
auth = AuthCredentials(["read", "write"])
257+
258+
scope: Scope = {"type": "http", "user": user, "auth": auth}
259+
260+
# Create dummy async functions for receive and send
261+
async def receive() -> Message:
262+
return {"type": "http.request"}
263+
264+
async def send(message: Message) -> None:
265+
pass
266+
267+
with pytest.raises(HTTPException) as excinfo:
268+
await middleware(scope, receive, send)
269+
270+
assert excinfo.value.status_code == 403
271+
assert excinfo.value.detail == "Insufficient scope"
272+
assert not app.called
273+
274+
async def test_no_auth_credentials(self, valid_access_token: AccessToken):
275+
"""Test middleware with no auth credentials in scope."""
276+
app = MockApp()
277+
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
278+
279+
# Create a user with read/write scopes
280+
user = AuthenticatedUser(valid_access_token)
281+
282+
scope: Scope = {"type": "http", "user": user} # No auth credentials
283+
284+
# Create dummy async functions for receive and send
285+
async def receive() -> Message:
286+
return {"type": "http.request"}
287+
288+
async def send(message: Message) -> None:
289+
pass
290+
291+
with pytest.raises(HTTPException) as excinfo:
292+
await middleware(scope, receive, send)
293+
294+
assert excinfo.value.status_code == 403
295+
assert excinfo.value.detail == "Insufficient scope"
296+
assert not app.called
297+
298+
async def test_has_required_scopes(self, valid_access_token: AccessToken):
299+
"""Test middleware with user having all required scopes."""
300+
app = MockApp()
301+
middleware = RequireAuthMiddleware(app, required_scopes=["read"])
302+
303+
# Create a user with read/write scopes
304+
user = AuthenticatedUser(valid_access_token)
305+
auth = AuthCredentials(["read", "write"])
306+
307+
scope: Scope = {"type": "http", "user": user, "auth": auth}
308+
309+
# Create dummy async functions for receive and send
310+
async def receive() -> Message:
311+
return {"type": "http.request"}
312+
313+
async def send(message: Message) -> None:
314+
pass
315+
316+
await middleware(scope, receive, send)
317+
318+
assert app.called
319+
assert app.scope == scope
320+
assert app.receive == receive
321+
assert app.send == send
322+
323+
async def test_multiple_required_scopes(self, valid_access_token: AccessToken):
324+
"""Test middleware with multiple required scopes."""
325+
app = MockApp()
326+
middleware = RequireAuthMiddleware(app, required_scopes=["read", "write"])
327+
328+
# Create a user with read/write scopes
329+
user = AuthenticatedUser(valid_access_token)
330+
auth = AuthCredentials(["read", "write"])
331+
332+
scope: Scope = {"type": "http", "user": user, "auth": auth}
333+
334+
# Create dummy async functions for receive and send
335+
async def receive() -> Message:
336+
return {"type": "http.request"}
337+
338+
async def send(message: Message) -> None:
339+
pass
340+
341+
await middleware(scope, receive, send)
342+
343+
assert app.called
344+
assert app.scope == scope
345+
assert app.receive == receive
346+
assert app.send == send
347+
348+
async def test_no_required_scopes(self, valid_access_token: AccessToken):
349+
"""Test middleware with no required scopes."""
350+
app = MockApp()
351+
middleware = RequireAuthMiddleware(app, required_scopes=[])
352+
353+
# Create a user with read/write scopes
354+
user = AuthenticatedUser(valid_access_token)
355+
auth = AuthCredentials(["read", "write"])
356+
357+
scope: Scope = {"type": "http", "user": user, "auth": auth}
358+
359+
# Create dummy async functions for receive and send
360+
async def receive() -> Message:
361+
return {"type": "http.request"}
362+
363+
async def send(message: Message) -> None:
364+
pass
365+
366+
await middleware(scope, receive, send)
367+
368+
assert app.called
369+
assert app.scope == scope
370+
assert app.receive == receive
371+
assert app.send == send

0 commit comments

Comments
 (0)