Skip to content

Commit 97b1b55

Browse files
committed
refactor: reorg test files
1 parent 4be8dbc commit 97b1b55

File tree

9 files changed

+205
-9
lines changed

9 files changed

+205
-9
lines changed

β€Žmcpauth/__init__.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import logging
2+
from typing import Any, Literal, Union
23

4+
from .middleware.create_bearer_auth import BaseBearerAuthConfig, BearerAuthConfig
5+
from .types import VerifyAccessTokenFunction
36
from .utils.fetch_server_config import ServerMetadataPaths
47
from .config import MCPAuthConfig
58
from .exceptions import MCPAuthAuthServerException, AuthServerExceptionCode
@@ -14,6 +17,10 @@ def __init__(self, config: MCPAuthConfig):
1417
result = validate_server_config(config.server)
1518

1619
if not result.is_valid:
20+
logging.error(
21+
"The authorization server configuration is invalid:\n"
22+
f"{result.errors}\n"
23+
)
1724
raise MCPAuthAuthServerException(
1825
AuthServerExceptionCode.INVALID_SERVER_CONFIG, cause=result
1926
)
@@ -55,3 +62,47 @@ async def dispatch(
5562
return await call_next(request)
5663

5764
return DelegatedMiddleware
65+
66+
def bearer_auth_middleware(
67+
self,
68+
mode_or_verify: Union[Literal["jwt"], VerifyAccessTokenFunction],
69+
config: BaseBearerAuthConfig,
70+
jwt_options: dict[str, Any] = {},
71+
) -> type[BaseHTTPMiddleware]:
72+
"""
73+
Creates a middleware that handles bearer token authentication.
74+
75+
:param mode_or_verify: If "jwt", uses built-in JWT verification; or a custom function that
76+
takes a string token and returns an `AuthInfo` object.
77+
:param config: Configuration for the Bearer auth handler, including audience, required
78+
scopes, etc.
79+
:param jwt_options: Optional dictionary of additional options for JWT verification
80+
(`jwt.decode`). Not used if a custom function is provided.
81+
:return: A middleware class that can be used in a Starlette or FastAPI application.
82+
"""
83+
84+
metadata = self.config.server.metadata
85+
if isinstance(mode_or_verify, str) and mode_or_verify == "jwt":
86+
from .utils.create_verify_jwt import create_verify_jwt
87+
88+
if not metadata.jwks_uri:
89+
raise MCPAuthAuthServerException(
90+
AuthServerExceptionCode.MISSING_JWKS_URI
91+
)
92+
93+
verify = create_verify_jwt(
94+
metadata.jwks_uri,
95+
options=jwt_options,
96+
)
97+
elif callable(mode_or_verify):
98+
verify = mode_or_verify
99+
else:
100+
raise ValueError(
101+
"mode_or_verify must be 'jwt' or a callable function that verifies tokens."
102+
)
103+
104+
from .middleware.create_bearer_auth import create_bearer_auth
105+
106+
return create_bearer_auth(
107+
verify, BearerAuthConfig(issuer=metadata.issuer, **config.model_dump())
108+
)

β€Žmcpauth/middleware/create_bearer_auth.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,39 @@
1818
from ..types import VerifyAccessTokenFunction, Record
1919

2020

21-
class BearerAuthConfig(BaseModel):
21+
class BaseBearerAuthConfig(BaseModel):
2222
"""
23-
Configuration for the Bearer auth handler.
24-
25-
Attributes:
26-
issuer: The expected issuer of the access token.
27-
audience: The expected audience of the access token.
28-
required_scopes: An array of required scopes that the access token must have.
29-
show_error_details: Whether to show detailed error information in the response.
23+
Base configuration for the Bearer auth handler.
3024
"""
3125

32-
issuer: str
3326
audience: Optional[str] = None
27+
"""
28+
The expected audience of the access token. If not provided, no audience check is performed.
29+
"""
30+
3431
required_scopes: Optional[List[str]] = None
32+
"""
33+
An array of required scopes that the access token must have. If not provided, no scope check is
34+
performed.
35+
"""
36+
3537
show_error_details: bool = False
38+
"""
39+
Whether to show detailed error information in the response. Defaults to False.
40+
If True, detailed error information will be included in the response body for debugging
41+
purposes.
42+
"""
43+
44+
45+
class BearerAuthConfig(BaseBearerAuthConfig):
46+
"""
47+
Configuration for the Bearer auth handler.
48+
"""
49+
50+
issuer: str
51+
"""
52+
The expected issuer of the access token. This should be a valid URL.
53+
"""
3654

3755

3856
def get_bearer_token_from_headers(headers: Headers) -> str:

β€Žpytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[pytest]
2+
pythonpath = .

β€Žmcpauth/__init__test.py renamed to β€Žtests/__init__test.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
from mcpauth.config import MCPAuthConfig
77
from mcpauth.models.auth_server import AuthServerConfig, AuthServerType
88
from mcpauth.models.oauth import AuthorizationServerMetadata
9+
from mcpauth.middleware.create_bearer_auth import BaseBearerAuthConfig
10+
from mcpauth.middleware.create_bearer_auth import BaseBearerAuthConfig
11+
from mcpauth.middleware.create_bearer_auth import BaseBearerAuthConfig
12+
from mcpauth.middleware.create_bearer_auth import BaseBearerAuthConfig
913

1014

1115
class TestMCPAuth:
@@ -72,6 +76,8 @@ def test_init_with_warnings(self, mock_warning: MagicMock):
7276
# Verify
7377
assert mock_warning.called
7478

79+
80+
class TestDelegatedMiddleware:
7581
@pytest.mark.asyncio
7682
async def test_delegated_middleware_oauth_endpoint(self):
7783
# Setup
@@ -135,3 +141,122 @@ async def test_delegated_middleware_other_endpoint(self):
135141
# Verify
136142
assert mock_call_next.called
137143
assert response == mock_response
144+
145+
146+
class TestBearerAuthMiddleware:
147+
def test_bearer_auth_middleware_jwt_mode(self):
148+
# Setup
149+
server_config = AuthServerConfig(
150+
type=AuthServerType.OAUTH,
151+
metadata=AuthorizationServerMetadata(
152+
issuer="https://example.com",
153+
authorization_endpoint="https://example.com/oauth/authorize",
154+
token_endpoint="https://example.com/oauth/token",
155+
jwks_uri="https://example.com/.well-known/jwks.json",
156+
response_types_supported=["code"],
157+
grant_types_supported=["authorization_code"],
158+
code_challenge_methods_supported=["S256"],
159+
),
160+
)
161+
config = MCPAuthConfig(server=server_config)
162+
auth = MCPAuth(config)
163+
164+
# Exercise
165+
with patch(
166+
"mcpauth.utils.create_verify_jwt.create_verify_jwt"
167+
) as mock_create_verify_jwt:
168+
mock_create_verify_jwt.return_value = MagicMock()
169+
middleware_class = auth.bearer_auth_middleware(
170+
"jwt", BaseBearerAuthConfig(required_scopes=["profile"])
171+
)
172+
173+
# Verify
174+
assert middleware_class is not None
175+
mock_create_verify_jwt.assert_called_once_with(
176+
"https://example.com/.well-known/jwks.json", options={}
177+
)
178+
179+
def test_bearer_auth_middleware_custom_verify(self):
180+
# Setup
181+
server_config = AuthServerConfig(
182+
type=AuthServerType.OAUTH,
183+
metadata=AuthorizationServerMetadata(
184+
issuer="https://example.com",
185+
authorization_endpoint="https://example.com/oauth/authorize",
186+
token_endpoint="https://example.com/oauth/token",
187+
response_types_supported=["code"],
188+
grant_types_supported=["authorization_code"],
189+
code_challenge_methods_supported=["S256"],
190+
),
191+
)
192+
config = MCPAuthConfig(server=server_config)
193+
auth = MCPAuth(config)
194+
195+
custom_verify = MagicMock()
196+
197+
# Exercise
198+
with patch(
199+
"mcpauth.middleware.create_bearer_auth.create_bearer_auth"
200+
) as mock_create_bearer_auth:
201+
middleware_class = auth.bearer_auth_middleware(
202+
custom_verify, BaseBearerAuthConfig(required_scopes=["profile"])
203+
)
204+
205+
# Verify
206+
assert middleware_class is not None
207+
mock_create_bearer_auth.assert_called_once()
208+
args, kwargs = mock_create_bearer_auth.call_args
209+
assert args[0] == custom_verify
210+
assert kwargs == {}
211+
212+
def test_bearer_auth_middleware_jwt_without_jwks_uri(self):
213+
# Setup
214+
server_config = AuthServerConfig(
215+
type=AuthServerType.OAUTH,
216+
metadata=AuthorizationServerMetadata(
217+
issuer="https://example.com",
218+
authorization_endpoint="https://example.com/oauth/authorize",
219+
token_endpoint="https://example.com/oauth/token",
220+
# No jwks_uri
221+
response_types_supported=["code"],
222+
grant_types_supported=["authorization_code"],
223+
code_challenge_methods_supported=["S256"],
224+
),
225+
)
226+
config = MCPAuthConfig(server=server_config)
227+
auth = MCPAuth(config)
228+
229+
# Exercise & Verify
230+
with pytest.raises(MCPAuthAuthServerException) as exc_info:
231+
auth.bearer_auth_middleware(
232+
"jwt", BaseBearerAuthConfig(required_scopes=["profile"])
233+
)
234+
235+
assert exc_info.value.code == AuthServerExceptionCode.MISSING_JWKS_URI
236+
237+
def test_bearer_auth_middleware_invalid_mode(self):
238+
# Setup
239+
server_config = AuthServerConfig(
240+
type=AuthServerType.OAUTH,
241+
metadata=AuthorizationServerMetadata(
242+
issuer="https://example.com",
243+
authorization_endpoint="https://example.com/oauth/authorize",
244+
token_endpoint="https://example.com/oauth/token",
245+
response_types_supported=["code"],
246+
grant_types_supported=["authorization_code"],
247+
code_challenge_methods_supported=["S256"],
248+
),
249+
)
250+
config = MCPAuthConfig(server=server_config)
251+
auth = MCPAuth(config)
252+
253+
# Exercise & Verify
254+
with pytest.raises(ValueError) as exc_info:
255+
auth.bearer_auth_middleware(
256+
"invalid_mode", # type: ignore
257+
BaseBearerAuthConfig(required_scopes=["profile"]),
258+
)
259+
260+
assert "mode_or_verify must be 'jwt' or a callable function" in str(
261+
exc_info.value
262+
)
File renamed without changes.

0 commit comments

Comments
Β (0)