@@ -337,7 +337,7 @@ def json_server_port() -> int:
337337 return s .getsockname ()[1 ]
338338
339339
340- @pytest .fixture ( autouse = True )
340+ @pytest .fixture
341341def basic_server (basic_server_port : int ) -> Generator [None , None , None ]:
342342 """Start a basic server."""
343343 proc = multiprocessing .Process (target = run_server , kwargs = {"port" : basic_server_port }, daemon = True )
@@ -455,7 +455,7 @@ def json_server_url(json_server_port: int) -> str:
455455
456456
457457# Basic request validation tests
458- def test_accept_header_validation (basic_server_url : str ):
458+ def test_accept_header_validation (basic_server : None , basic_server_url : str ):
459459 """Test that Accept header is properly validated."""
460460 # Test without Accept header
461461 response = requests .post (
@@ -467,7 +467,7 @@ def test_accept_header_validation(basic_server_url: str):
467467 assert "Not Acceptable" in response .text
468468
469469
470- def test_content_type_validation (basic_server_url : str ):
470+ def test_content_type_validation (basic_server : None , basic_server_url : str ):
471471 """Test that Content-Type header is properly validated."""
472472 # Test with incorrect Content-Type
473473 response = requests .post (
@@ -483,7 +483,7 @@ def test_content_type_validation(basic_server_url: str):
483483 assert "Invalid Content-Type" in response .text
484484
485485
486- def test_json_validation (basic_server_url : str ):
486+ def test_json_validation (basic_server : None , basic_server_url : str ):
487487 """Test that JSON content is properly validated."""
488488 # Test with invalid JSON
489489 response = requests .post (
@@ -498,7 +498,7 @@ def test_json_validation(basic_server_url: str):
498498 assert "Parse error" in response .text
499499
500500
501- def test_json_parsing (basic_server_url : str ):
501+ def test_json_parsing (basic_server : None , basic_server_url : str ):
502502 """Test that JSON content is properly parse."""
503503 # Test with valid JSON but invalid JSON-RPC
504504 response = requests .post (
@@ -513,7 +513,7 @@ def test_json_parsing(basic_server_url: str):
513513 assert "Validation error" in response .text
514514
515515
516- def test_method_not_allowed (basic_server_url : str ):
516+ def test_method_not_allowed (basic_server : None , basic_server_url : str ):
517517 """Test that unsupported HTTP methods are rejected."""
518518 # Test with unsupported method (PUT)
519519 response = requests .put (
@@ -528,7 +528,7 @@ def test_method_not_allowed(basic_server_url: str):
528528 assert "Method Not Allowed" in response .text
529529
530530
531- def test_session_validation (basic_server_url : str ):
531+ def test_session_validation (basic_server : None , basic_server_url : str ):
532532 """Test session ID validation."""
533533 # session_id not used directly in this test
534534
@@ -603,7 +603,7 @@ def test_streamable_http_transport_init_validation():
603603 StreamableHTTPServerTransport (mcp_session_id = "test\n " )
604604
605605
606- def test_session_termination (basic_server_url : str ):
606+ def test_session_termination (basic_server : None , basic_server_url : str ):
607607 """Test session termination via DELETE and subsequent request handling."""
608608 response = requests .post (
609609 f"{ basic_server_url } /mcp" ,
@@ -643,7 +643,7 @@ def test_session_termination(basic_server_url: str):
643643 assert "Session has been terminated" in response .text
644644
645645
646- def test_response (basic_server_url : str ):
646+ def test_response (basic_server : None , basic_server_url : str ):
647647 """Test response handling for a valid request."""
648648 mcp_url = f"{ basic_server_url } /mcp"
649649 response = requests .post (
@@ -693,7 +693,7 @@ def test_json_response(json_response_server: None, json_server_url: str):
693693 assert response .headers .get ("Content-Type" ) == "application/json"
694694
695695
696- def test_get_sse_stream (basic_server_url : str ):
696+ def test_get_sse_stream (basic_server : None , basic_server_url : str ):
697697 """Test establishing an SSE stream via GET request."""
698698 # First, we need to initialize a session
699699 mcp_url = f"{ basic_server_url } /mcp"
@@ -753,7 +753,7 @@ def test_get_sse_stream(basic_server_url: str):
753753 assert second_get .status_code == 409
754754
755755
756- def test_get_validation (basic_server_url : str ):
756+ def test_get_validation (basic_server : None , basic_server_url : str ):
757757 """Test validation for GET requests."""
758758 # First, we need to initialize a session
759759 mcp_url = f"{ basic_server_url } /mcp"
@@ -808,14 +808,14 @@ def test_get_validation(basic_server_url: str):
808808
809809# Client-specific fixtures
810810@pytest .fixture
811- async def http_client (basic_server_url : str ):
811+ async def http_client (basic_server : None , basic_server_url : str ):
812812 """Create test client matching the SSE test pattern."""
813813 async with httpx .AsyncClient (base_url = basic_server_url ) as client :
814814 yield client
815815
816816
817817@pytest .fixture
818- async def initialized_client_session (basic_server_url : str ):
818+ async def initialized_client_session (basic_server : None , basic_server_url : str ):
819819 """Create initialized StreamableHTTP client session."""
820820 async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (
821821 read_stream ,
@@ -831,7 +831,7 @@ async def initialized_client_session(basic_server_url: str):
831831
832832
833833@pytest .mark .anyio
834- async def test_streamablehttp_client_basic_connection (basic_server_url : str ):
834+ async def test_streamablehttp_client_basic_connection (basic_server : None , basic_server_url : str ):
835835 """Test basic client connection with initialization."""
836836 async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (
837837 read_stream ,
@@ -883,9 +883,13 @@ async def test_streamablehttp_client_error_handling(initialized_client_session:
883883
884884
885885@pytest .mark .anyio
886- async def test_streamablehttp_client_session_persistence (basic_server_url : str ):
886+ async def test_streamablehttp_client_session_persistence (basic_server : None , basic_server_url : str ):
887887 """Test that session ID persists across requests."""
888- async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (read_stream , write_stream , _ ):
888+ async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (
889+ read_stream ,
890+ write_stream ,
891+ _ ,
892+ ):
889893 async with ClientSession (
890894 read_stream ,
891895 write_stream ,
@@ -909,7 +913,11 @@ async def test_streamablehttp_client_session_persistence(basic_server_url: str):
909913@pytest .mark .anyio
910914async def test_streamablehttp_client_json_response (json_response_server : None , json_server_url : str ):
911915 """Test client with JSON response mode."""
912- async with streamablehttp_client (f"{ json_server_url } /mcp" ) as (read_stream , write_stream , _ ):
916+ async with streamablehttp_client (f"{ json_server_url } /mcp" ) as (
917+ read_stream ,
918+ write_stream ,
919+ _ ,
920+ ):
913921 async with ClientSession (
914922 read_stream ,
915923 write_stream ,
@@ -931,7 +939,7 @@ async def test_streamablehttp_client_json_response(json_response_server: None, j
931939
932940
933941@pytest .mark .anyio
934- async def test_streamablehttp_client_get_stream (basic_server_url : str ):
942+ async def test_streamablehttp_client_get_stream (basic_server : None , basic_server_url : str ):
935943 """Test GET stream functionality for server-initiated messages."""
936944 import mcp .types as types
937945 from mcp .shared .session import RequestResponder
@@ -972,13 +980,17 @@ async def message_handler(
972980
973981
974982@pytest .mark .anyio
975- async def test_streamablehttp_client_session_termination (basic_server_url : str ):
983+ async def test_streamablehttp_client_session_termination (basic_server : None , basic_server_url : str ):
976984 """Test client session termination functionality."""
977985
978986 captured_session_id = None
979987
980988 # Create the streamablehttp_client with a custom httpx client to capture headers
981- async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (read_stream , write_stream , get_session_id ):
989+ async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (
990+ read_stream ,
991+ write_stream ,
992+ get_session_id ,
993+ ):
982994 async with ClientSession (read_stream , write_stream ) as session :
983995 # Initialize the session
984996 result = await session .initialize ()
@@ -1009,7 +1021,9 @@ async def test_streamablehttp_client_session_termination(basic_server_url: str):
10091021
10101022
10111023@pytest .mark .anyio
1012- async def test_streamablehttp_client_session_termination_204 (basic_server_url : str , monkeypatch : pytest .MonkeyPatch ):
1024+ async def test_streamablehttp_client_session_termination_204 (
1025+ basic_server : None , basic_server_url : str , monkeypatch : pytest .MonkeyPatch
1026+ ):
10131027 """Test client session termination functionality with a 204 response.
10141028
10151029 This test patches the httpx client to return a 204 response for DELETEs.
@@ -1192,12 +1206,13 @@ async def run_tool():
11921206
11931207 # We should have received the remaining notifications
11941208 assert len (captured_notifications ) == 1
1209+
11951210 assert isinstance (captured_notifications [0 ].root , types .LoggingMessageNotification )
11961211 assert captured_notifications [0 ].root .params .data == "Second notification after lock"
11971212
11981213
11991214@pytest .mark .anyio
1200- async def test_streamablehttp_server_sampling (basic_server_url : str ):
1215+ async def test_streamablehttp_server_sampling (basic_server : None , basic_server_url : str ):
12011216 """Test server-initiated sampling request through streamable HTTP transport."""
12021217 # Variable to track if sampling callback was invoked
12031218 sampling_callback_invoked = False
@@ -1224,7 +1239,11 @@ async def sampling_callback(
12241239 )
12251240
12261241 # Create client with sampling callback
1227- async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (read_stream , write_stream , _ ):
1242+ async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (
1243+ read_stream ,
1244+ write_stream ,
1245+ _ ,
1246+ ):
12281247 async with ClientSession (
12291248 read_stream ,
12301249 write_stream ,
@@ -1284,12 +1303,7 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]
12841303 headers_info = {}
12851304 if ctx .request and isinstance (ctx .request , Request ):
12861305 headers_info = dict (ctx .request .headers )
1287- return [
1288- TextContent (
1289- type = "text" ,
1290- text = json .dumps (headers_info ),
1291- )
1292- ]
1306+ return [TextContent (type = "text" , text = json .dumps (headers_info ))]
12931307
12941308 elif name == "echo_context" :
12951309 # Return full context information
@@ -1304,7 +1318,12 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]
13041318 context_data ["headers" ] = dict (request .headers )
13051319 context_data ["method" ] = request .method
13061320 context_data ["path" ] = request .url .path
1307- return [TextContent (type = "text" , text = json .dumps (context_data ))]
1321+ return [
1322+ TextContent (
1323+ type = "text" ,
1324+ text = json .dumps (context_data ),
1325+ )
1326+ ]
13081327
13091328 return [TextContent (type = "text" , text = f"Unknown tool: { name } " )]
13101329
@@ -1314,16 +1333,28 @@ def run_context_aware_server(port: int):
13141333 """Run the context-aware test server."""
13151334 server = ContextAwareServerTest ()
13161335
1317- session_manager = StreamableHTTPSessionManager (app = server , event_store = None , json_response = False )
1336+ session_manager = StreamableHTTPSessionManager (
1337+ app = server ,
1338+ event_store = None ,
1339+ json_response = False ,
1340+ )
13181341
13191342 app = Starlette (
13201343 debug = True ,
1321- routes = [Mount ("/mcp" , app = session_manager .handle_request )],
1344+ routes = [
1345+ Mount ("/mcp" , app = session_manager .handle_request ),
1346+ ],
13221347 lifespan = lambda app : session_manager .run (),
13231348 )
13241349
1325- config = uvicorn .Config (app = app , host = "127.0.0.1" , port = port , log_level = "error" )
1326- server_instance = uvicorn .Server (config = config )
1350+ server_instance = uvicorn .Server (
1351+ config = uvicorn .Config (
1352+ app = app ,
1353+ host = "127.0.0.1" ,
1354+ port = port ,
1355+ log_level = "error" ,
1356+ )
1357+ )
13271358 server_instance .run ()
13281359
13291360
@@ -1425,7 +1456,11 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No
14251456@pytest .mark .anyio
14261457async def test_client_includes_protocol_version_header_after_init (context_aware_server : None , basic_server_url : str ):
14271458 """Test that client includes mcp-protocol-version header after initialization."""
1428- async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (read_stream , write_stream , _ ):
1459+ async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (
1460+ read_stream ,
1461+ write_stream ,
1462+ _ ,
1463+ ):
14291464 async with ClientSession (read_stream , write_stream ) as session :
14301465 # Initialize and get the negotiated version
14311466 init_result = await session .initialize ()
@@ -1443,7 +1478,7 @@ async def test_client_includes_protocol_version_header_after_init(context_aware_
14431478 assert headers_data [MCP_PROTOCOL_VERSION_HEADER ] == negotiated_version
14441479
14451480
1446- def test_server_validates_protocol_version_header (basic_server_url : str ):
1481+ def test_server_validates_protocol_version_header (basic_server : None , basic_server_url : str ):
14471482 """Test that server returns 400 Bad Request version if header unsupported or invalid."""
14481483 # First initialize a session to get a valid session ID
14491484 init_response = requests .post (
@@ -1501,7 +1536,7 @@ def test_server_validates_protocol_version_header(basic_server_url: str):
15011536 assert response .status_code == 200
15021537
15031538
1504- def test_server_backwards_compatibility_no_protocol_version (basic_server_url : str ):
1539+ def test_server_backwards_compatibility_no_protocol_version (basic_server : None , basic_server_url : str ):
15051540 """Test server accepts requests without protocol version header."""
15061541 # First initialize a session to get a valid session ID
15071542 init_response = requests .post (
@@ -1531,13 +1566,17 @@ def test_server_backwards_compatibility_no_protocol_version(basic_server_url: st
15311566
15321567
15331568@pytest .mark .anyio
1534- async def test_client_crash_handled (basic_server_url : str ):
1569+ async def test_client_crash_handled (basic_server : None , basic_server_url : str ):
15351570 """Test that cases where the client crashes are handled gracefully."""
15361571
15371572 # Simulate bad client that crashes after init
15381573 async def bad_client ():
15391574 """Client that triggers ClosedResourceError"""
1540- async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (read_stream , write_stream , _ ):
1575+ async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (
1576+ read_stream ,
1577+ write_stream ,
1578+ _ ,
1579+ ):
15411580 async with ClientSession (read_stream , write_stream ) as session :
15421581 await session .initialize ()
15431582 raise Exception ("client crash" )
@@ -1551,7 +1590,11 @@ async def bad_client():
15511590 await anyio .sleep (0.1 )
15521591
15531592 # Try a good client, it should still be able to connect and list tools
1554- async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (read_stream , write_stream , _ ):
1593+ async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (
1594+ read_stream ,
1595+ write_stream ,
1596+ _ ,
1597+ ):
15551598 async with ClientSession (read_stream , write_stream ) as session :
15561599 result = await session .initialize ()
15571600 assert isinstance (result , InitializeResult )
0 commit comments