Skip to content

Commit f4ff346

Browse files
committed
fix
1 parent 2fa6d19 commit f4ff346

File tree

3 files changed

+95
-43
lines changed

3 files changed

+95
-43
lines changed

tests/issues/test_1027_win_unreachable_cleanup.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,19 @@
1212
import tempfile
1313
import textwrap
1414
from pathlib import Path
15+
from typing import TYPE_CHECKING
1516

1617
import anyio
1718
import pytest
1819

1920
from mcp import ClientSession, StdioServerParameters
2021
from mcp.client.stdio import _create_platform_compatible_process, stdio_client
2122

22-
from ..shared.test_win32_utils import escape_path_for_python
23+
# TODO(Marcelo): This doesn't seem to be the right path. We should fix this.
24+
if TYPE_CHECKING:
25+
from ..shared.test_win32_utils import escape_path_for_python
26+
else:
27+
from tests.shared.test_win32_utils import escape_path_for_python
2328

2429

2530
@pytest.mark.anyio

tests/server/auth/test_error_handling.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import unittest.mock
6-
from typing import Any
6+
from typing import TYPE_CHECKING, Any
77
from urllib.parse import parse_qs, urlparse
88

99
import httpx
@@ -15,7 +15,11 @@
1515
from mcp.server.auth.provider import AuthorizeError, RegistrationError, TokenError
1616
from mcp.server.auth.routes import create_auth_routes
1717

18-
from ...server.fastmcp.auth.test_auth_integration import MockOAuthProvider
18+
# TODO(Marcelo): This TYPE_CHECKING shouldn't be here, but pytest doesn't seem to get the module correctly.
19+
if TYPE_CHECKING:
20+
from ...server.fastmcp.auth.test_auth_integration import MockOAuthProvider
21+
else:
22+
from tests.server.fastmcp.auth.test_auth_integration import MockOAuthProvider
1923

2024

2125
@pytest.fixture

tests/shared/test_streamable_http.py

Lines changed: 83 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def json_server_port() -> int:
337337
return s.getsockname()[1]
338338

339339

340-
@pytest.fixture(autouse=True)
340+
@pytest.fixture
341341
def basic_server(basic_server_port: int) -> Generator[None, None, None]:
342342
"""Start a basic server."""
343343
proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True)
@@ -455,7 +455,7 @@ def json_server_url(json_server_port: int) -> str:
455455

456456

457457
# Basic request validation tests
458-
def test_accept_header_validation(basic_server_url: str):
458+
def test_accept_header_validation(basic_server: None, basic_server_url: str):
459459
"""Test that Accept header is properly validated."""
460460
# Test without Accept header
461461
response = requests.post(
@@ -467,7 +467,7 @@ def test_accept_header_validation(basic_server_url: str):
467467
assert "Not Acceptable" in response.text
468468

469469

470-
def test_content_type_validation(basic_server_url: str):
470+
def test_content_type_validation(basic_server: None, basic_server_url: str):
471471
"""Test that Content-Type header is properly validated."""
472472
# Test with incorrect Content-Type
473473
response = requests.post(
@@ -483,7 +483,7 @@ def test_content_type_validation(basic_server_url: str):
483483
assert "Invalid Content-Type" in response.text
484484

485485

486-
def test_json_validation(basic_server_url: str):
486+
def test_json_validation(basic_server: None, basic_server_url: str):
487487
"""Test that JSON content is properly validated."""
488488
# Test with invalid JSON
489489
response = requests.post(
@@ -498,7 +498,7 @@ def test_json_validation(basic_server_url: str):
498498
assert "Parse error" in response.text
499499

500500

501-
def test_json_parsing(basic_server_url: str):
501+
def test_json_parsing(basic_server: None, basic_server_url: str):
502502
"""Test that JSON content is properly parse."""
503503
# Test with valid JSON but invalid JSON-RPC
504504
response = requests.post(
@@ -513,7 +513,7 @@ def test_json_parsing(basic_server_url: str):
513513
assert "Validation error" in response.text
514514

515515

516-
def test_method_not_allowed(basic_server_url: str):
516+
def test_method_not_allowed(basic_server: None, basic_server_url: str):
517517
"""Test that unsupported HTTP methods are rejected."""
518518
# Test with unsupported method (PUT)
519519
response = requests.put(
@@ -528,7 +528,7 @@ def test_method_not_allowed(basic_server_url: str):
528528
assert "Method Not Allowed" in response.text
529529

530530

531-
def test_session_validation(basic_server_url: str):
531+
def test_session_validation(basic_server: None, basic_server_url: str):
532532
"""Test session ID validation."""
533533
# session_id not used directly in this test
534534

@@ -603,7 +603,7 @@ def test_streamable_http_transport_init_validation():
603603
StreamableHTTPServerTransport(mcp_session_id="test\n")
604604

605605

606-
def test_session_termination(basic_server_url: str):
606+
def test_session_termination(basic_server: None, basic_server_url: str):
607607
"""Test session termination via DELETE and subsequent request handling."""
608608
response = requests.post(
609609
f"{basic_server_url}/mcp",
@@ -643,7 +643,7 @@ def test_session_termination(basic_server_url: str):
643643
assert "Session has been terminated" in response.text
644644

645645

646-
def test_response(basic_server_url: str):
646+
def test_response(basic_server: None, basic_server_url: str):
647647
"""Test response handling for a valid request."""
648648
mcp_url = f"{basic_server_url}/mcp"
649649
response = requests.post(
@@ -693,7 +693,7 @@ def test_json_response(json_response_server: None, json_server_url: str):
693693
assert response.headers.get("Content-Type") == "application/json"
694694

695695

696-
def test_get_sse_stream(basic_server_url: str):
696+
def test_get_sse_stream(basic_server: None, basic_server_url: str):
697697
"""Test establishing an SSE stream via GET request."""
698698
# First, we need to initialize a session
699699
mcp_url = f"{basic_server_url}/mcp"
@@ -753,7 +753,7 @@ def test_get_sse_stream(basic_server_url: str):
753753
assert second_get.status_code == 409
754754

755755

756-
def test_get_validation(basic_server_url: str):
756+
def test_get_validation(basic_server: None, basic_server_url: str):
757757
"""Test validation for GET requests."""
758758
# First, we need to initialize a session
759759
mcp_url = f"{basic_server_url}/mcp"
@@ -808,14 +808,14 @@ def test_get_validation(basic_server_url: str):
808808

809809
# Client-specific fixtures
810810
@pytest.fixture
811-
async def http_client(basic_server_url: str):
811+
async def http_client(basic_server: None, basic_server_url: str):
812812
"""Create test client matching the SSE test pattern."""
813813
async with httpx.AsyncClient(base_url=basic_server_url) as client:
814814
yield client
815815

816816

817817
@pytest.fixture
818-
async def initialized_client_session(basic_server_url: str):
818+
async def initialized_client_session(basic_server: None, basic_server_url: str):
819819
"""Create initialized StreamableHTTP client session."""
820820
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
821821
read_stream,
@@ -831,7 +831,7 @@ async def initialized_client_session(basic_server_url: str):
831831

832832

833833
@pytest.mark.anyio
834-
async def test_streamablehttp_client_basic_connection(basic_server_url: str):
834+
async def test_streamablehttp_client_basic_connection(basic_server: None, basic_server_url: str):
835835
"""Test basic client connection with initialization."""
836836
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
837837
read_stream,
@@ -883,9 +883,13 @@ async def test_streamablehttp_client_error_handling(initialized_client_session:
883883

884884

885885
@pytest.mark.anyio
886-
async def test_streamablehttp_client_session_persistence(basic_server_url: str):
886+
async def test_streamablehttp_client_session_persistence(basic_server: None, basic_server_url: str):
887887
"""Test that session ID persists across requests."""
888-
async with streamablehttp_client(f"{basic_server_url}/mcp") as (read_stream, write_stream, _):
888+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
889+
read_stream,
890+
write_stream,
891+
_,
892+
):
889893
async with ClientSession(
890894
read_stream,
891895
write_stream,
@@ -909,7 +913,11 @@ async def test_streamablehttp_client_session_persistence(basic_server_url: str):
909913
@pytest.mark.anyio
910914
async def test_streamablehttp_client_json_response(json_response_server: None, json_server_url: str):
911915
"""Test client with JSON response mode."""
912-
async with streamablehttp_client(f"{json_server_url}/mcp") as (read_stream, write_stream, _):
916+
async with streamablehttp_client(f"{json_server_url}/mcp") as (
917+
read_stream,
918+
write_stream,
919+
_,
920+
):
913921
async with ClientSession(
914922
read_stream,
915923
write_stream,
@@ -931,7 +939,7 @@ async def test_streamablehttp_client_json_response(json_response_server: None, j
931939

932940

933941
@pytest.mark.anyio
934-
async def test_streamablehttp_client_get_stream(basic_server_url: str):
942+
async def test_streamablehttp_client_get_stream(basic_server: None, basic_server_url: str):
935943
"""Test GET stream functionality for server-initiated messages."""
936944
import mcp.types as types
937945
from mcp.shared.session import RequestResponder
@@ -972,13 +980,17 @@ async def message_handler(
972980

973981

974982
@pytest.mark.anyio
975-
async def test_streamablehttp_client_session_termination(basic_server_url: str):
983+
async def test_streamablehttp_client_session_termination(basic_server: None, basic_server_url: str):
976984
"""Test client session termination functionality."""
977985

978986
captured_session_id = None
979987

980988
# Create the streamablehttp_client with a custom httpx client to capture headers
981-
async with streamablehttp_client(f"{basic_server_url}/mcp") as (read_stream, write_stream, get_session_id):
989+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
990+
read_stream,
991+
write_stream,
992+
get_session_id,
993+
):
982994
async with ClientSession(read_stream, write_stream) as session:
983995
# Initialize the session
984996
result = await session.initialize()
@@ -1009,7 +1021,9 @@ async def test_streamablehttp_client_session_termination(basic_server_url: str):
10091021

10101022

10111023
@pytest.mark.anyio
1012-
async def test_streamablehttp_client_session_termination_204(basic_server_url: str, monkeypatch: pytest.MonkeyPatch):
1024+
async def test_streamablehttp_client_session_termination_204(
1025+
basic_server: None, basic_server_url: str, monkeypatch: pytest.MonkeyPatch
1026+
):
10131027
"""Test client session termination functionality with a 204 response.
10141028
10151029
This test patches the httpx client to return a 204 response for DELETEs.
@@ -1192,12 +1206,13 @@ async def run_tool():
11921206

11931207
# We should have received the remaining notifications
11941208
assert len(captured_notifications) == 1
1209+
11951210
assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification)
11961211
assert captured_notifications[0].root.params.data == "Second notification after lock"
11971212

11981213

11991214
@pytest.mark.anyio
1200-
async def test_streamablehttp_server_sampling(basic_server_url: str):
1215+
async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str):
12011216
"""Test server-initiated sampling request through streamable HTTP transport."""
12021217
# Variable to track if sampling callback was invoked
12031218
sampling_callback_invoked = False
@@ -1224,7 +1239,11 @@ async def sampling_callback(
12241239
)
12251240

12261241
# Create client with sampling callback
1227-
async with streamablehttp_client(f"{basic_server_url}/mcp") as (read_stream, write_stream, _):
1242+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
1243+
read_stream,
1244+
write_stream,
1245+
_,
1246+
):
12281247
async with ClientSession(
12291248
read_stream,
12301249
write_stream,
@@ -1284,12 +1303,7 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]
12841303
headers_info = {}
12851304
if ctx.request and isinstance(ctx.request, Request):
12861305
headers_info = dict(ctx.request.headers)
1287-
return [
1288-
TextContent(
1289-
type="text",
1290-
text=json.dumps(headers_info),
1291-
)
1292-
]
1306+
return [TextContent(type="text", text=json.dumps(headers_info))]
12931307

12941308
elif name == "echo_context":
12951309
# Return full context information
@@ -1304,7 +1318,12 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]
13041318
context_data["headers"] = dict(request.headers)
13051319
context_data["method"] = request.method
13061320
context_data["path"] = request.url.path
1307-
return [TextContent(type="text", text=json.dumps(context_data))]
1321+
return [
1322+
TextContent(
1323+
type="text",
1324+
text=json.dumps(context_data),
1325+
)
1326+
]
13081327

13091328
return [TextContent(type="text", text=f"Unknown tool: {name}")]
13101329

@@ -1314,16 +1333,28 @@ def run_context_aware_server(port: int):
13141333
"""Run the context-aware test server."""
13151334
server = ContextAwareServerTest()
13161335

1317-
session_manager = StreamableHTTPSessionManager(app=server, event_store=None, json_response=False)
1336+
session_manager = StreamableHTTPSessionManager(
1337+
app=server,
1338+
event_store=None,
1339+
json_response=False,
1340+
)
13181341

13191342
app = Starlette(
13201343
debug=True,
1321-
routes=[Mount("/mcp", app=session_manager.handle_request)],
1344+
routes=[
1345+
Mount("/mcp", app=session_manager.handle_request),
1346+
],
13221347
lifespan=lambda app: session_manager.run(),
13231348
)
13241349

1325-
config = uvicorn.Config(app=app, host="127.0.0.1", port=port, log_level="error")
1326-
server_instance = uvicorn.Server(config=config)
1350+
server_instance = uvicorn.Server(
1351+
config=uvicorn.Config(
1352+
app=app,
1353+
host="127.0.0.1",
1354+
port=port,
1355+
log_level="error",
1356+
)
1357+
)
13271358
server_instance.run()
13281359

13291360

@@ -1425,7 +1456,11 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No
14251456
@pytest.mark.anyio
14261457
async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str):
14271458
"""Test that client includes mcp-protocol-version header after initialization."""
1428-
async with streamablehttp_client(f"{basic_server_url}/mcp") as (read_stream, write_stream, _):
1459+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
1460+
read_stream,
1461+
write_stream,
1462+
_,
1463+
):
14291464
async with ClientSession(read_stream, write_stream) as session:
14301465
# Initialize and get the negotiated version
14311466
init_result = await session.initialize()
@@ -1443,7 +1478,7 @@ async def test_client_includes_protocol_version_header_after_init(context_aware_
14431478
assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version
14441479

14451480

1446-
def test_server_validates_protocol_version_header(basic_server_url: str):
1481+
def test_server_validates_protocol_version_header(basic_server: None, basic_server_url: str):
14471482
"""Test that server returns 400 Bad Request version if header unsupported or invalid."""
14481483
# First initialize a session to get a valid session ID
14491484
init_response = requests.post(
@@ -1501,7 +1536,7 @@ def test_server_validates_protocol_version_header(basic_server_url: str):
15011536
assert response.status_code == 200
15021537

15031538

1504-
def test_server_backwards_compatibility_no_protocol_version(basic_server_url: str):
1539+
def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str):
15051540
"""Test server accepts requests without protocol version header."""
15061541
# First initialize a session to get a valid session ID
15071542
init_response = requests.post(
@@ -1531,13 +1566,17 @@ def test_server_backwards_compatibility_no_protocol_version(basic_server_url: st
15311566

15321567

15331568
@pytest.mark.anyio
1534-
async def test_client_crash_handled(basic_server_url: str):
1569+
async def test_client_crash_handled(basic_server: None, basic_server_url: str):
15351570
"""Test that cases where the client crashes are handled gracefully."""
15361571

15371572
# Simulate bad client that crashes after init
15381573
async def bad_client():
15391574
"""Client that triggers ClosedResourceError"""
1540-
async with streamablehttp_client(f"{basic_server_url}/mcp") as (read_stream, write_stream, _):
1575+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
1576+
read_stream,
1577+
write_stream,
1578+
_,
1579+
):
15411580
async with ClientSession(read_stream, write_stream) as session:
15421581
await session.initialize()
15431582
raise Exception("client crash")
@@ -1551,7 +1590,11 @@ async def bad_client():
15511590
await anyio.sleep(0.1)
15521591

15531592
# Try a good client, it should still be able to connect and list tools
1554-
async with streamablehttp_client(f"{basic_server_url}/mcp") as (read_stream, write_stream, _):
1593+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
1594+
read_stream,
1595+
write_stream,
1596+
_,
1597+
):
15551598
async with ClientSession(read_stream, write_stream) as session:
15561599
result = await session.initialize()
15571600
assert isinstance(result, InitializeResult)

0 commit comments

Comments
 (0)