33"""
44
55import time
6- from typing import Any , Dict , List , Optional , cast
6+ from typing import Any , cast
77
88import pytest
99from starlette .authentication import AuthCredentials
1010from starlette .exceptions import HTTPException
1111from starlette .requests import Request
12- from starlette .types import ASGIApp , Message , Receive , Scope , Send
12+ from starlette .types import Message , Receive , Scope , Send
1313
1414from mcp .server .auth .middleware .bearer_auth import (
1515 AuthenticatedUser ,
2424
2525class MockOAuthProvider :
2626 """Mock OAuth provider for testing.
27-
27+
2828 This is a simplified version that only implements the methods needed for testing
2929 the BearerAuthMiddleware components.
3030 """
@@ -36,14 +36,16 @@ def add_token(self, token: str, access_token: AccessToken) -> None:
3636 """Add a token to the provider."""
3737 self .tokens [token ] = access_token
3838
39- async def load_access_token (self , token : str ) -> Optional [ AccessToken ] :
39+ async def load_access_token (self , token : str ) -> AccessToken | None :
4040 """Load an access token."""
4141 return self .tokens .get (token )
4242
4343
44- def add_token_to_provider (provider : OAuthServerProvider [Any , Any , Any ], token : str , access_token : AccessToken ) -> None :
44+ def add_token_to_provider (
45+ provider : OAuthServerProvider [Any , Any , Any ], token : str , access_token : AccessToken
46+ ) -> None :
4547 """Helper function to add a token to a provider.
46-
48+
4749 This is used to work around type checking issues with our mock provider.
4850 """
4951 # We know this is actually a MockOAuthProvider
@@ -56,9 +58,9 @@ class MockApp:
5658
5759 def __init__ (self ):
5860 self .called = False
59- self .scope : Optional [ Scope ] = None
60- self .receive : Optional [ Receive ] = None
61- self .send : Optional [ Send ] = None
61+ self .scope : Scope | None = None
62+ self .receive : Receive | None = None
63+ self .send : Send | None = None
6264
6365 async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
6466 self .called = True
@@ -111,14 +113,18 @@ def no_expiry_access_token() -> AccessToken:
111113class TestBearerAuthBackend :
112114 """Tests for the BearerAuthBackend class."""
113115
114- async def test_no_auth_header (self , mock_oauth_provider : OAuthServerProvider [Any , Any , Any ]):
116+ async def test_no_auth_header (
117+ self , mock_oauth_provider : OAuthServerProvider [Any , Any , Any ]
118+ ):
115119 """Test authentication with no Authorization header."""
116120 backend = BearerAuthBackend (provider = mock_oauth_provider )
117121 request = Request ({"type" : "http" , "headers" : []})
118122 result = await backend .authenticate (request )
119123 assert result is None
120124
121- async def test_non_bearer_auth_header (self , mock_oauth_provider : OAuthServerProvider [Any , Any , Any ]):
125+ async def test_non_bearer_auth_header (
126+ self , mock_oauth_provider : OAuthServerProvider [Any , Any , Any ]
127+ ):
122128 """Test authentication with non-Bearer Authorization header."""
123129 backend = BearerAuthBackend (provider = mock_oauth_provider )
124130 request = Request (
@@ -130,7 +136,9 @@ async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthServerProv
130136 result = await backend .authenticate (request )
131137 assert result is None
132138
133- async def test_invalid_token (self , mock_oauth_provider : OAuthServerProvider [Any , Any , Any ]):
139+ async def test_invalid_token (
140+ self , mock_oauth_provider : OAuthServerProvider [Any , Any , Any ]
141+ ):
134142 """Test authentication with invalid token."""
135143 backend = BearerAuthBackend (provider = mock_oauth_provider )
136144 request = Request (
@@ -143,11 +151,15 @@ async def test_invalid_token(self, mock_oauth_provider: OAuthServerProvider[Any,
143151 assert result is None
144152
145153 async def test_expired_token (
146- self , mock_oauth_provider : OAuthServerProvider [Any , Any , Any ], expired_access_token : AccessToken
154+ self ,
155+ mock_oauth_provider : OAuthServerProvider [Any , Any , Any ],
156+ expired_access_token : AccessToken ,
147157 ):
148158 """Test authentication with expired token."""
149159 backend = BearerAuthBackend (provider = mock_oauth_provider )
150- add_token_to_provider (mock_oauth_provider , "expired_token" , expired_access_token )
160+ add_token_to_provider (
161+ mock_oauth_provider , "expired_token" , expired_access_token
162+ )
151163 request = Request (
152164 {
153165 "type" : "http" ,
@@ -158,7 +170,9 @@ async def test_expired_token(
158170 assert result is None
159171
160172 async def test_valid_token (
161- self , mock_oauth_provider : OAuthServerProvider [Any , Any , Any ], valid_access_token : AccessToken
173+ self ,
174+ mock_oauth_provider : OAuthServerProvider [Any , Any , Any ],
175+ valid_access_token : AccessToken ,
162176 ):
163177 """Test authentication with valid token."""
164178 backend = BearerAuthBackend (provider = mock_oauth_provider )
@@ -180,11 +194,15 @@ async def test_valid_token(
180194 assert user .scopes == ["read" , "write" ]
181195
182196 async def test_token_without_expiry (
183- self , mock_oauth_provider : OAuthServerProvider [Any , Any , Any ], no_expiry_access_token : AccessToken
197+ self ,
198+ mock_oauth_provider : OAuthServerProvider [Any , Any , Any ],
199+ no_expiry_access_token : AccessToken ,
184200 ):
185201 """Test authentication with token that has no expiry."""
186202 backend = BearerAuthBackend (provider = mock_oauth_provider )
187- add_token_to_provider (mock_oauth_provider , "no_expiry_token" , no_expiry_access_token )
203+ add_token_to_provider (
204+ mock_oauth_provider , "no_expiry_token" , no_expiry_access_token
205+ )
188206 request = Request (
189207 {
190208 "type" : "http" ,
@@ -211,17 +229,17 @@ async def test_no_user(self):
211229 app = MockApp ()
212230 middleware = RequireAuthMiddleware (app , required_scopes = ["read" ])
213231 scope : Scope = {"type" : "http" }
214-
232+
215233 # Create dummy async functions for receive and send
216234 async def receive () -> Message :
217235 return {"type" : "http.request" }
218-
236+
219237 async def send (message : Message ) -> None :
220238 pass
221-
239+
222240 with pytest .raises (HTTPException ) as excinfo :
223241 await middleware (scope , receive , send )
224-
242+
225243 assert excinfo .value .status_code == 401
226244 assert excinfo .value .detail == "Unauthorized"
227245 assert not app .called
@@ -231,17 +249,17 @@ async def test_non_authenticated_user(self):
231249 app = MockApp ()
232250 middleware = RequireAuthMiddleware (app , required_scopes = ["read" ])
233251 scope : Scope = {"type" : "http" , "user" : object ()}
234-
252+
235253 # Create dummy async functions for receive and send
236254 async def receive () -> Message :
237255 return {"type" : "http.request" }
238-
256+
239257 async def send (message : Message ) -> None :
240258 pass
241-
259+
242260 with pytest .raises (HTTPException ) as excinfo :
243261 await middleware (scope , receive , send )
244-
262+
245263 assert excinfo .value .status_code == 401
246264 assert excinfo .value .detail == "Unauthorized"
247265 assert not app .called
@@ -250,23 +268,23 @@ async def test_missing_required_scope(self, valid_access_token: AccessToken):
250268 """Test middleware with user missing required scope."""
251269 app = MockApp ()
252270 middleware = RequireAuthMiddleware (app , required_scopes = ["admin" ])
253-
271+
254272 # Create a user with read/write scopes but not admin
255273 user = AuthenticatedUser (valid_access_token )
256274 auth = AuthCredentials (["read" , "write" ])
257-
275+
258276 scope : Scope = {"type" : "http" , "user" : user , "auth" : auth }
259-
277+
260278 # Create dummy async functions for receive and send
261279 async def receive () -> Message :
262280 return {"type" : "http.request" }
263-
281+
264282 async def send (message : Message ) -> None :
265283 pass
266-
284+
267285 with pytest .raises (HTTPException ) as excinfo :
268286 await middleware (scope , receive , send )
269-
287+
270288 assert excinfo .value .status_code == 403
271289 assert excinfo .value .detail == "Insufficient scope"
272290 assert not app .called
@@ -275,22 +293,22 @@ async def test_no_auth_credentials(self, valid_access_token: AccessToken):
275293 """Test middleware with no auth credentials in scope."""
276294 app = MockApp ()
277295 middleware = RequireAuthMiddleware (app , required_scopes = ["read" ])
278-
296+
279297 # Create a user with read/write scopes
280298 user = AuthenticatedUser (valid_access_token )
281-
299+
282300 scope : Scope = {"type" : "http" , "user" : user } # No auth credentials
283-
301+
284302 # Create dummy async functions for receive and send
285303 async def receive () -> Message :
286304 return {"type" : "http.request" }
287-
305+
288306 async def send (message : Message ) -> None :
289307 pass
290-
308+
291309 with pytest .raises (HTTPException ) as excinfo :
292310 await middleware (scope , receive , send )
293-
311+
294312 assert excinfo .value .status_code == 403
295313 assert excinfo .value .detail == "Insufficient scope"
296314 assert not app .called
@@ -299,22 +317,22 @@ async def test_has_required_scopes(self, valid_access_token: AccessToken):
299317 """Test middleware with user having all required scopes."""
300318 app = MockApp ()
301319 middleware = RequireAuthMiddleware (app , required_scopes = ["read" ])
302-
320+
303321 # Create a user with read/write scopes
304322 user = AuthenticatedUser (valid_access_token )
305323 auth = AuthCredentials (["read" , "write" ])
306-
324+
307325 scope : Scope = {"type" : "http" , "user" : user , "auth" : auth }
308-
326+
309327 # Create dummy async functions for receive and send
310328 async def receive () -> Message :
311329 return {"type" : "http.request" }
312-
330+
313331 async def send (message : Message ) -> None :
314332 pass
315-
333+
316334 await middleware (scope , receive , send )
317-
335+
318336 assert app .called
319337 assert app .scope == scope
320338 assert app .receive == receive
@@ -324,22 +342,22 @@ async def test_multiple_required_scopes(self, valid_access_token: AccessToken):
324342 """Test middleware with multiple required scopes."""
325343 app = MockApp ()
326344 middleware = RequireAuthMiddleware (app , required_scopes = ["read" , "write" ])
327-
345+
328346 # Create a user with read/write scopes
329347 user = AuthenticatedUser (valid_access_token )
330348 auth = AuthCredentials (["read" , "write" ])
331-
349+
332350 scope : Scope = {"type" : "http" , "user" : user , "auth" : auth }
333-
351+
334352 # Create dummy async functions for receive and send
335353 async def receive () -> Message :
336354 return {"type" : "http.request" }
337-
355+
338356 async def send (message : Message ) -> None :
339357 pass
340-
358+
341359 await middleware (scope , receive , send )
342-
360+
343361 assert app .called
344362 assert app .scope == scope
345363 assert app .receive == receive
@@ -349,22 +367,22 @@ async def test_no_required_scopes(self, valid_access_token: AccessToken):
349367 """Test middleware with no required scopes."""
350368 app = MockApp ()
351369 middleware = RequireAuthMiddleware (app , required_scopes = [])
352-
370+
353371 # Create a user with read/write scopes
354372 user = AuthenticatedUser (valid_access_token )
355373 auth = AuthCredentials (["read" , "write" ])
356-
374+
357375 scope : Scope = {"type" : "http" , "user" : user , "auth" : auth }
358-
376+
359377 # Create dummy async functions for receive and send
360378 async def receive () -> Message :
361379 return {"type" : "http.request" }
362-
380+
363381 async def send (message : Message ) -> None :
364382 pass
365-
383+
366384 await middleware (scope , receive , send )
367-
385+
368386 assert app .called
369387 assert app .scope == scope
370388 assert app .receive == receive
0 commit comments