Skip to content

Commit 4be8dbc

Browse files
committed
feat: delegated middleware
1 parent 8fb73f3 commit 4be8dbc

File tree

2 files changed

+191
-13
lines changed

2 files changed

+191
-13
lines changed

mcpauth/__init__.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,57 @@
1-
from .exceptions import (
2-
MCPAuthException as MCPAuthException,
3-
MCPAuthConfigException as MCPAuthConfigException,
4-
AuthServerExceptionCode as AuthServerExceptionCode,
5-
MCPAuthAuthServerException as MCPAuthAuthServerException,
6-
BearerAuthExceptionCode as BearerAuthExceptionCode,
7-
MCPAuthBearerAuthExceptionDetails as MCPAuthBearerAuthExceptionDetails,
8-
MCPAuthBearerAuthException as MCPAuthBearerAuthException,
9-
MCPAuthJwtVerificationExceptionCode as MCPAuthJwtVerificationExceptionCode,
10-
MCPAuthJwtVerificationException as MCPAuthJwtVerificationException,
11-
)
1+
import logging
2+
3+
from .utils.fetch_server_config import ServerMetadataPaths
4+
from .config import MCPAuthConfig
5+
from .exceptions import MCPAuthAuthServerException, AuthServerExceptionCode
6+
from .utils.validate_server_config import validate_server_config
7+
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
8+
from starlette.requests import Request
9+
from starlette.responses import Response, JSONResponse
1210

1311

1412
class MCPAuth:
15-
def __init__(self):
16-
self.config = None
13+
def __init__(self, config: MCPAuthConfig):
14+
result = validate_server_config(config.server)
15+
16+
if not result.is_valid:
17+
raise MCPAuthAuthServerException(
18+
AuthServerExceptionCode.INVALID_SERVER_CONFIG, cause=result
19+
)
20+
21+
if len(result.warnings) > 0:
22+
logging.warning("The authorization server configuration has warnings:\n")
23+
for warning in result.warnings:
24+
logging.warning(f"- {warning}")
25+
26+
self.config = config
27+
28+
def delegated_middleware(self) -> type[BaseHTTPMiddleware]:
29+
"""
30+
Returns a middleware that handles OAuth 2.0 Authorization Metadata endpoint
31+
(`/.well-known/oauth-authorization-server`) with CORS support (delegated mode).
32+
33+
:return: A middleware class that can be used in a Starlette or FastAPI application.
34+
"""
35+
server_config = self.config.server
36+
37+
class DelegatedMiddleware(BaseHTTPMiddleware):
38+
async def dispatch(
39+
self, request: Request, call_next: RequestResponseEndpoint
40+
) -> Response:
41+
path = request.url.path
42+
if path == ServerMetadataPaths.OAUTH:
43+
response = JSONResponse(
44+
{
45+
k: v
46+
for k, v in server_config.metadata.model_dump().items()
47+
if v is not None
48+
},
49+
status_code=200,
50+
)
51+
response.headers["Access-Control-Allow-Origin"] = "*"
52+
response.headers["Access-Control-Allow-Methods"] = "GET, OPTIONS"
53+
return response
54+
else:
55+
return await call_next(request)
56+
57+
return DelegatedMiddleware

mcpauth/__init__test.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import pytest
2+
from unittest.mock import AsyncMock, patch, MagicMock
3+
from starlette.requests import Request
4+
from starlette.responses import Response
5+
from mcpauth import MCPAuth, MCPAuthAuthServerException, AuthServerExceptionCode
6+
from mcpauth.config import MCPAuthConfig
7+
from mcpauth.models.auth_server import AuthServerConfig, AuthServerType
8+
from mcpauth.models.oauth import AuthorizationServerMetadata
9+
10+
11+
class TestMCPAuth:
12+
def test_init_with_valid_config(self):
13+
# Setup
14+
server_config = AuthServerConfig(
15+
type=AuthServerType.OAUTH,
16+
metadata=AuthorizationServerMetadata(
17+
issuer="https://example.com",
18+
authorization_endpoint="https://example.com/oauth/authorize",
19+
token_endpoint="https://example.com/oauth/token",
20+
response_types_supported=["code"],
21+
grant_types_supported=["authorization_code"],
22+
code_challenge_methods_supported=["S256"],
23+
),
24+
)
25+
config = MCPAuthConfig(server=server_config)
26+
27+
# Exercise
28+
auth = MCPAuth(config)
29+
30+
# Verify
31+
assert auth.config == config
32+
33+
def test_init_with_invalid_config(self):
34+
# Setup
35+
server_config = AuthServerConfig(
36+
type=AuthServerType.OAUTH,
37+
metadata=AuthorizationServerMetadata(
38+
issuer="https://example.com",
39+
authorization_endpoint="https://example.com/oauth/authorize",
40+
token_endpoint="https://example.com/oauth/token",
41+
response_types_supported=["token"], # Invalid response type
42+
),
43+
)
44+
config = MCPAuthConfig(server=server_config)
45+
46+
# Exercise & Verify
47+
with pytest.raises(MCPAuthAuthServerException) as exc_info:
48+
MCPAuth(config)
49+
50+
assert exc_info.value.code == AuthServerExceptionCode.INVALID_SERVER_CONFIG
51+
52+
@patch("mcpauth.logging.warning")
53+
def test_init_with_warnings(self, mock_warning: MagicMock):
54+
# Setup
55+
server_config = AuthServerConfig(
56+
type=AuthServerType.OAUTH,
57+
metadata=AuthorizationServerMetadata(
58+
issuer="https://example.com",
59+
authorization_endpoint="https://example.com/oauth/authorize",
60+
token_endpoint="https://example.com/oauth/token",
61+
response_types_supported=["code"],
62+
grant_types_supported=["authorization_code"],
63+
code_challenge_methods_supported=["S256"],
64+
# Missing registration_endpoint will cause a warning
65+
),
66+
)
67+
config = MCPAuthConfig(server=server_config)
68+
69+
# Exercise
70+
MCPAuth(config)
71+
72+
# Verify
73+
assert mock_warning.called
74+
75+
@pytest.mark.asyncio
76+
async def test_delegated_middleware_oauth_endpoint(self):
77+
# Setup
78+
server_config = AuthServerConfig(
79+
type=AuthServerType.OAUTH,
80+
metadata=AuthorizationServerMetadata(
81+
issuer="https://example.com",
82+
authorization_endpoint="https://example.com/oauth/authorize",
83+
token_endpoint="https://example.com/oauth/token",
84+
response_types_supported=["code"],
85+
grant_types_supported=["authorization_code"],
86+
code_challenge_methods_supported=["S256"],
87+
),
88+
)
89+
config = MCPAuthConfig(server=server_config)
90+
auth = MCPAuth(config)
91+
92+
middleware_class = auth.delegated_middleware()
93+
middleware = middleware_class(app=MagicMock())
94+
95+
mock_request = MagicMock(spec=Request)
96+
mock_request.url.path = "/.well-known/oauth-authorization-server"
97+
98+
# Exercise
99+
response = await middleware.dispatch(mock_request, call_next=AsyncMock())
100+
101+
# Verify
102+
assert response.status_code == 200
103+
assert response.headers["Access-Control-Allow-Origin"] == "*"
104+
assert response.headers["Access-Control-Allow-Methods"] == "GET, OPTIONS"
105+
106+
@pytest.mark.asyncio
107+
async def test_delegated_middleware_other_endpoint(self):
108+
# Setup
109+
server_config = AuthServerConfig(
110+
type=AuthServerType.OAUTH,
111+
metadata=AuthorizationServerMetadata(
112+
issuer="https://example.com",
113+
authorization_endpoint="https://example.com/oauth/authorize",
114+
token_endpoint="https://example.com/oauth/token",
115+
response_types_supported=["code"],
116+
grant_types_supported=["authorization_code"],
117+
code_challenge_methods_supported=["S256"],
118+
),
119+
)
120+
config = MCPAuthConfig(server=server_config)
121+
auth = MCPAuth(config)
122+
123+
middleware_class = auth.delegated_middleware()
124+
middleware = middleware_class(app=MagicMock())
125+
126+
mock_request = MagicMock(spec=Request)
127+
mock_request.url.path = "/some-other-path"
128+
129+
mock_response = Response(content="Test response")
130+
mock_call_next = AsyncMock(return_value=mock_response)
131+
132+
# Exercise
133+
response = await middleware.dispatch(mock_request, call_next=mock_call_next)
134+
135+
# Verify
136+
assert mock_call_next.called
137+
assert response == mock_response

0 commit comments

Comments
 (0)