|
4 | 4 |
|
5 | 5 | import base64 |
6 | 6 | import hashlib |
7 | | -import json |
8 | 7 | import secrets |
9 | 8 | import time |
10 | 9 | import unittest.mock |
11 | 10 | from urllib.parse import parse_qs, urlparse |
12 | 11 |
|
13 | | -import anyio |
14 | 12 | import httpx |
15 | 13 | import pytest |
16 | | -from httpx_sse import aconnect_sse |
17 | 14 | from pydantic import AnyHttpUrl |
18 | 15 | from starlette.applications import Starlette |
19 | 16 |
|
|
30 | 27 | RevocationOptions, |
31 | 28 | create_auth_routes, |
32 | 29 | ) |
33 | | -from mcp.server.auth.settings import AuthSettings |
34 | | -from mcp.server.fastmcp import FastMCP |
35 | | -from mcp.server.streaming_asgi_transport import StreamingASGITransport |
36 | 30 | from mcp.shared.auth import ( |
37 | 31 | OAuthClientInformationFull, |
38 | 32 | OAuthToken, |
39 | 33 | ) |
40 | | -from mcp.types import JSONRPCRequest |
41 | 34 |
|
42 | 35 |
|
43 | 36 | # Mock OAuth provider for testing |
@@ -230,10 +223,11 @@ def auth_app(mock_oauth_provider): |
230 | 223 |
|
231 | 224 |
|
232 | 225 | @pytest.fixture |
233 | | -def test_client(auth_app) -> httpx.AsyncClient: |
234 | | - return httpx.AsyncClient( |
| 226 | +async def test_client(auth_app): |
| 227 | + async with httpx.AsyncClient( |
235 | 228 | transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com" |
236 | | - ) |
| 229 | + ) as client: |
| 230 | + yield client |
237 | 231 |
|
238 | 232 |
|
239 | 233 | @pytest.fixture |
@@ -993,163 +987,7 @@ async def test_client_registration_invalid_grant_type( |
993 | 987 | ) |
994 | 988 |
|
995 | 989 |
|
996 | | -class TestFastMCPWithAuth: |
997 | | - """Test FastMCP server with authentication.""" |
998 | | - |
999 | | - @pytest.mark.anyio |
1000 | | - async def test_fastmcp_with_auth( |
1001 | | - self, mock_oauth_provider: MockOAuthProvider, pkce_challenge |
1002 | | - ): |
1003 | | - """Test creating a FastMCP server with authentication.""" |
1004 | | - # Create FastMCP server with auth provider |
1005 | | - mcp = FastMCP( |
1006 | | - auth_server_provider=mock_oauth_provider, |
1007 | | - require_auth=True, |
1008 | | - auth=AuthSettings( |
1009 | | - issuer_url=AnyHttpUrl("https://auth.example.com"), |
1010 | | - client_registration_options=ClientRegistrationOptions(enabled=True), |
1011 | | - revocation_options=RevocationOptions(enabled=True), |
1012 | | - required_scopes=["read", "write"], |
1013 | | - ), |
1014 | | - ) |
1015 | | - |
1016 | | - # Add a test tool |
1017 | | - @mcp.tool() |
1018 | | - def test_tool(x: int) -> str: |
1019 | | - return f"Result: {x}" |
1020 | | - |
1021 | | - async with anyio.create_task_group() as task_group: |
1022 | | - transport = StreamingASGITransport( |
1023 | | - app=mcp.sse_app(), |
1024 | | - task_group=task_group, |
1025 | | - ) |
1026 | | - test_client = httpx.AsyncClient( |
1027 | | - transport=transport, base_url="http://mcptest.com" |
1028 | | - ) |
1029 | | - |
1030 | | - # Test metadata endpoint |
1031 | | - response = await test_client.get("/.well-known/oauth-authorization-server") |
1032 | | - assert response.status_code == 200 |
1033 | 990 |
|
1034 | | - # Test that auth is required for protected endpoints |
1035 | | - response = await test_client.get("/sse") |
1036 | | - assert response.status_code == 401 |
1037 | | - |
1038 | | - response = await test_client.post("/messages/") |
1039 | | - assert response.status_code == 401, response.content |
1040 | | - |
1041 | | - response = await test_client.post( |
1042 | | - "/messages/", |
1043 | | - headers={"Authorization": "invalid"}, |
1044 | | - ) |
1045 | | - assert response.status_code == 401 |
1046 | | - |
1047 | | - response = await test_client.post( |
1048 | | - "/messages/", |
1049 | | - headers={"Authorization": "Bearer invalid"}, |
1050 | | - ) |
1051 | | - assert response.status_code == 401 |
1052 | | - |
1053 | | - # now, become authenticated and try to go through the flow again |
1054 | | - client_metadata = { |
1055 | | - "redirect_uris": ["https://client.example.com/callback"], |
1056 | | - "client_name": "Test Client", |
1057 | | - } |
1058 | | - |
1059 | | - response = await test_client.post( |
1060 | | - "/register", |
1061 | | - json=client_metadata, |
1062 | | - ) |
1063 | | - assert response.status_code == 201 |
1064 | | - client_info = response.json() |
1065 | | - |
1066 | | - # Request authorization using POST with form-encoded data |
1067 | | - response = await test_client.post( |
1068 | | - "/authorize", |
1069 | | - data={ |
1070 | | - "response_type": "code", |
1071 | | - "client_id": client_info["client_id"], |
1072 | | - "redirect_uri": "https://client.example.com/callback", |
1073 | | - "code_challenge": pkce_challenge["code_challenge"], |
1074 | | - "code_challenge_method": "S256", |
1075 | | - "state": "test_state", |
1076 | | - }, |
1077 | | - ) |
1078 | | - assert response.status_code == 302 |
1079 | | - |
1080 | | - # Extract the authorization code from the redirect URL |
1081 | | - redirect_url = response.headers["location"] |
1082 | | - parsed_url = urlparse(redirect_url) |
1083 | | - query_params = parse_qs(parsed_url.query) |
1084 | | - |
1085 | | - assert "code" in query_params |
1086 | | - auth_code = query_params["code"][0] |
1087 | | - |
1088 | | - # Exchange the authorization code for tokens |
1089 | | - response = await test_client.post( |
1090 | | - "/token", |
1091 | | - data={ |
1092 | | - "grant_type": "authorization_code", |
1093 | | - "client_id": client_info["client_id"], |
1094 | | - "client_secret": client_info["client_secret"], |
1095 | | - "code": auth_code, |
1096 | | - "code_verifier": pkce_challenge["code_verifier"], |
1097 | | - "redirect_uri": "https://client.example.com/callback", |
1098 | | - }, |
1099 | | - ) |
1100 | | - assert response.status_code == 200 |
1101 | | - |
1102 | | - token_response = response.json() |
1103 | | - assert "access_token" in token_response |
1104 | | - authorization = f"Bearer {token_response['access_token']}" |
1105 | | - |
1106 | | - # Test the authenticated endpoint with valid token |
1107 | | - async with aconnect_sse( |
1108 | | - test_client, "GET", "/sse", headers={"Authorization": authorization} |
1109 | | - ) as event_source: |
1110 | | - assert event_source.response.status_code == 200 |
1111 | | - events = event_source.aiter_sse() |
1112 | | - sse = await events.__anext__() |
1113 | | - assert sse.event == "endpoint" |
1114 | | - assert sse.data.startswith("/messages/?session_id=") |
1115 | | - messages_uri = sse.data |
1116 | | - |
1117 | | - # verify that we can now post to the /messages endpoint, |
1118 | | - # and get a response on the /sse endpoint |
1119 | | - response = await test_client.post( |
1120 | | - messages_uri, |
1121 | | - headers={"Authorization": authorization}, |
1122 | | - content=JSONRPCRequest( |
1123 | | - jsonrpc="2.0", |
1124 | | - id="123", |
1125 | | - method="initialize", |
1126 | | - params={ |
1127 | | - "protocolVersion": "2024-11-05", |
1128 | | - "capabilities": { |
1129 | | - "roots": {"listChanged": True}, |
1130 | | - "sampling": {}, |
1131 | | - }, |
1132 | | - "clientInfo": {"name": "ExampleClient", "version": "1.0.0"}, |
1133 | | - }, |
1134 | | - ).model_dump_json(), |
1135 | | - ) |
1136 | | - assert response.status_code == 202 |
1137 | | - assert response.content == b"Accepted" |
1138 | | - |
1139 | | - sse = await events.__anext__() |
1140 | | - assert sse.event == "message" |
1141 | | - sse_data = json.loads(sse.data) |
1142 | | - assert sse_data["id"] == "123" |
1143 | | - assert set(sse_data["result"]["capabilities"].keys()) == { |
1144 | | - "experimental", |
1145 | | - "prompts", |
1146 | | - "resources", |
1147 | | - "tools", |
1148 | | - } |
1149 | | - # the /sse endpoint will never finish; normally, the client could just |
1150 | | - # disconnect, but in tests the easiest way to do this is to cancel the |
1151 | | - # task group |
1152 | | - task_group.cancel_scope.cancel() |
1153 | 991 |
|
1154 | 992 |
|
1155 | 993 | class TestAuthorizeEndpointErrors: |
|
0 commit comments