|
11 | 11 | from typing import List, Optional |
12 | 12 | from urllib.parse import parse_qs, urlparse |
13 | 13 |
|
| 14 | +import anyio |
14 | 15 | import httpx |
15 | 16 | import pytest |
16 | 17 | from httpx_sse import aconnect_sse |
@@ -993,130 +994,132 @@ async def test_fastmcp_with_auth( |
993 | 994 | def test_tool(x: int) -> str: |
994 | 995 | return f"Result: {x}" |
995 | 996 |
|
996 | | - transport = StreamingASGITransport(app=mcp.starlette_app()) # pyright: ignore |
997 | | - test_client = httpx.AsyncClient( |
998 | | - transport=transport, base_url="http://mcptest.com" |
999 | | - ) |
1000 | | - # test_client = httpx.AsyncClient(app=mcp.starlette_app(), base_url="http://mcptest.com") |
| 997 | + async with anyio.create_task_group() as task_group: |
| 998 | + transport = StreamingASGITransport(app=mcp.starlette_app(), task_group=task_group) # pyright: ignore |
| 999 | + test_client = httpx.AsyncClient( |
| 1000 | + transport=transport, base_url="http://mcptest.com" |
| 1001 | + ) |
| 1002 | + # test_client = httpx.AsyncClient(app=mcp.starlette_app(), base_url="http://mcptest.com") |
1001 | 1003 |
|
1002 | | - # Test metadata endpoint |
1003 | | - response = await test_client.get("/.well-known/oauth-authorization-server") |
1004 | | - assert response.status_code == 200 |
| 1004 | + # Test metadata endpoint |
| 1005 | + response = await test_client.get("/.well-known/oauth-authorization-server") |
| 1006 | + assert response.status_code == 200 |
1005 | 1007 |
|
1006 | | - # Test that auth is required for protected endpoints |
1007 | | - response = await test_client.get("/sse") |
1008 | | - # TODO: we should return 401/403 depending on whether authn or authz fails |
1009 | | - assert response.status_code == 403 |
| 1008 | + # Test that auth is required for protected endpoints |
| 1009 | + response = await test_client.get("/sse") |
| 1010 | + # TODO: we should return 401/403 depending on whether authn or authz fails |
| 1011 | + assert response.status_code == 403 |
1010 | 1012 |
|
1011 | | - response = await test_client.post("/messages/") |
1012 | | - # TODO: we should return 401/403 depending on whether authn or authz fails |
1013 | | - assert response.status_code == 403, response.content |
| 1013 | + response = await test_client.post("/messages/") |
| 1014 | + # TODO: we should return 401/403 depending on whether authn or authz fails |
| 1015 | + assert response.status_code == 403, response.content |
1014 | 1016 |
|
1015 | | - response = await test_client.post( |
1016 | | - "/messages/", |
1017 | | - headers={"Authorization": "invalid"}, |
1018 | | - ) |
1019 | | - assert response.status_code == 403 |
1020 | | - |
1021 | | - response = await test_client.post( |
1022 | | - "/messages/", |
1023 | | - headers={"Authorization": "Bearer invalid"}, |
1024 | | - ) |
1025 | | - assert response.status_code == 403 |
| 1017 | + response = await test_client.post( |
| 1018 | + "/messages/", |
| 1019 | + headers={"Authorization": "invalid"}, |
| 1020 | + ) |
| 1021 | + assert response.status_code == 403 |
1026 | 1022 |
|
1027 | | - # now, become authenticated and try to go through the flow again |
1028 | | - client_metadata = { |
1029 | | - "redirect_uris": ["https://client.example.com/callback"], |
1030 | | - "client_name": "Test Client", |
1031 | | - } |
| 1023 | + response = await test_client.post( |
| 1024 | + "/messages/", |
| 1025 | + headers={"Authorization": "Bearer invalid"}, |
| 1026 | + ) |
| 1027 | + assert response.status_code == 403 |
1032 | 1028 |
|
1033 | | - response = await test_client.post( |
1034 | | - "/register", |
1035 | | - json=client_metadata, |
1036 | | - ) |
1037 | | - assert response.status_code == 201 |
1038 | | - client_info = response.json() |
| 1029 | + # now, become authenticated and try to go through the flow again |
| 1030 | + client_metadata = { |
| 1031 | + "redirect_uris": ["https://client.example.com/callback"], |
| 1032 | + "client_name": "Test Client", |
| 1033 | + } |
1039 | 1034 |
|
1040 | | - # Request authorization using POST with form-encoded data |
1041 | | - response = await test_client.post( |
1042 | | - "/authorize", |
1043 | | - data={ |
1044 | | - "response_type": "code", |
1045 | | - "client_id": client_info["client_id"], |
1046 | | - "redirect_uri": "https://client.example.com/callback", |
1047 | | - "code_challenge": pkce_challenge["code_challenge"], |
1048 | | - "code_challenge_method": "S256", |
1049 | | - "state": "test_state", |
1050 | | - }, |
1051 | | - ) |
1052 | | - assert response.status_code == 302 |
| 1035 | + response = await test_client.post( |
| 1036 | + "/register", |
| 1037 | + json=client_metadata, |
| 1038 | + ) |
| 1039 | + assert response.status_code == 201 |
| 1040 | + client_info = response.json() |
1053 | 1041 |
|
1054 | | - # Extract the authorization code from the redirect URL |
1055 | | - redirect_url = response.headers["location"] |
1056 | | - parsed_url = urlparse(redirect_url) |
1057 | | - query_params = parse_qs(parsed_url.query) |
| 1042 | + # Request authorization using POST with form-encoded data |
| 1043 | + response = await test_client.post( |
| 1044 | + "/authorize", |
| 1045 | + data={ |
| 1046 | + "response_type": "code", |
| 1047 | + "client_id": client_info["client_id"], |
| 1048 | + "redirect_uri": "https://client.example.com/callback", |
| 1049 | + "code_challenge": pkce_challenge["code_challenge"], |
| 1050 | + "code_challenge_method": "S256", |
| 1051 | + "state": "test_state", |
| 1052 | + }, |
| 1053 | + ) |
| 1054 | + assert response.status_code == 302 |
1058 | 1055 |
|
1059 | | - assert "code" in query_params |
1060 | | - auth_code = query_params["code"][0] |
| 1056 | + # Extract the authorization code from the redirect URL |
| 1057 | + redirect_url = response.headers["location"] |
| 1058 | + parsed_url = urlparse(redirect_url) |
| 1059 | + query_params = parse_qs(parsed_url.query) |
1061 | 1060 |
|
1062 | | - # Exchange the authorization code for tokens |
1063 | | - response = await test_client.post( |
1064 | | - "/token", |
1065 | | - data={ |
1066 | | - "grant_type": "authorization_code", |
1067 | | - "client_id": client_info["client_id"], |
1068 | | - "client_secret": client_info["client_secret"], |
1069 | | - "code": auth_code, |
1070 | | - "code_verifier": pkce_challenge["code_verifier"], |
1071 | | - "redirect_uri": "https://client.example.com/callback", |
1072 | | - }, |
1073 | | - ) |
1074 | | - assert response.status_code == 200 |
| 1061 | + assert "code" in query_params |
| 1062 | + auth_code = query_params["code"][0] |
1075 | 1063 |
|
1076 | | - token_response = response.json() |
1077 | | - assert "access_token" in token_response |
1078 | | - authorization = f"Bearer {token_response['access_token']}" |
1079 | | - |
1080 | | - # Test the authenticated endpoint with valid token |
1081 | | - async with aconnect_sse( |
1082 | | - test_client, "GET", "/sse", headers={"Authorization": authorization} |
1083 | | - ) as event_source: |
1084 | | - assert event_source.response.status_code == 200 |
1085 | | - events = event_source.aiter_sse() |
1086 | | - sse = await events.__anext__() |
1087 | | - assert sse.event == "endpoint" |
1088 | | - assert sse.data.startswith("/messages/?session_id=") |
1089 | | - messages_uri = sse.data |
1090 | | - |
1091 | | - # verify that we can now post to the /messages endpoint, and get a response |
1092 | | - # on the /sse endpoint |
| 1064 | + # Exchange the authorization code for tokens |
1093 | 1065 | response = await test_client.post( |
1094 | | - messages_uri, |
1095 | | - headers={"Authorization": authorization}, |
1096 | | - content=JSONRPCRequest( |
1097 | | - jsonrpc="2.0", |
1098 | | - id="123", |
1099 | | - method="initialize", |
1100 | | - params={ |
1101 | | - "protocolVersion": "2024-11-05", |
1102 | | - "capabilities": { |
1103 | | - "roots": {"listChanged": True}, |
1104 | | - "sampling": {}, |
1105 | | - }, |
1106 | | - "clientInfo": {"name": "ExampleClient", "version": "1.0.0"}, |
1107 | | - }, |
1108 | | - ).model_dump_json(), |
1109 | | - ) |
1110 | | - assert response.status_code == 202 |
1111 | | - assert response.content == b"Accepted" |
1112 | | - |
1113 | | - sse = await events.__anext__() |
1114 | | - assert sse.event == "message" |
1115 | | - sse_data = json.loads(sse.data) |
1116 | | - assert sse_data["id"] == "123" |
1117 | | - assert set(sse_data["result"]["capabilities"].keys()) == set( |
1118 | | - ("experimental", "prompts", "resources", "tools") |
| 1066 | + "/token", |
| 1067 | + data={ |
| 1068 | + "grant_type": "authorization_code", |
| 1069 | + "client_id": client_info["client_id"], |
| 1070 | + "client_secret": client_info["client_secret"], |
| 1071 | + "code": auth_code, |
| 1072 | + "code_verifier": pkce_challenge["code_verifier"], |
| 1073 | + "redirect_uri": "https://client.example.com/callback", |
| 1074 | + }, |
1119 | 1075 | ) |
| 1076 | + assert response.status_code == 200 |
| 1077 | + |
| 1078 | + token_response = response.json() |
| 1079 | + assert "access_token" in token_response |
| 1080 | + authorization = f"Bearer {token_response['access_token']}" |
| 1081 | + |
| 1082 | + # Test the authenticated endpoint with valid token |
| 1083 | + async with aconnect_sse( |
| 1084 | + test_client, "GET", "/sse", headers={"Authorization": authorization} |
| 1085 | + ) as event_source: |
| 1086 | + assert event_source.response.status_code == 200 |
| 1087 | + events = event_source.aiter_sse() |
| 1088 | + sse = await events.__anext__() |
| 1089 | + assert sse.event == "endpoint" |
| 1090 | + assert sse.data.startswith("/messages/?session_id=") |
| 1091 | + messages_uri = sse.data |
| 1092 | + |
| 1093 | + # verify that we can now post to the /messages endpoint, and get a response |
| 1094 | + # on the /sse endpoint |
| 1095 | + response = await test_client.post( |
| 1096 | + messages_uri, |
| 1097 | + headers={"Authorization": authorization}, |
| 1098 | + content=JSONRPCRequest( |
| 1099 | + jsonrpc="2.0", |
| 1100 | + id="123", |
| 1101 | + method="initialize", |
| 1102 | + params={ |
| 1103 | + "protocolVersion": "2024-11-05", |
| 1104 | + "capabilities": { |
| 1105 | + "roots": {"listChanged": True}, |
| 1106 | + "sampling": {}, |
| 1107 | + }, |
| 1108 | + "clientInfo": {"name": "ExampleClient", "version": "1.0.0"}, |
| 1109 | + }, |
| 1110 | + ).model_dump_json(), |
| 1111 | + ) |
| 1112 | + assert response.status_code == 202 |
| 1113 | + assert response.content == b"Accepted" |
| 1114 | + |
| 1115 | + sse = await events.__anext__() |
| 1116 | + assert sse.event == "message" |
| 1117 | + sse_data = json.loads(sse.data) |
| 1118 | + assert sse_data["id"] == "123" |
| 1119 | + assert set(sse_data["result"]["capabilities"].keys()) == set( |
| 1120 | + ("experimental", "prompts", "resources", "tools") |
| 1121 | + ) |
| 1122 | + task_group.cancel_scope.cancel() |
1120 | 1123 |
|
1121 | 1124 |
|
1122 | 1125 | class TestAuthorizeEndpointErrors: |
|
0 commit comments