|
65 | 65 | } |
66 | 66 |
|
67 | 67 |
|
| 68 | +# Helper functions |
| 69 | +def extract_protocol_version_from_sse(response: requests.Response) -> str: |
| 70 | + """Extract the negotiated protocol version from an SSE initialization response.""" |
| 71 | + assert response.headers.get("Content-Type") == "text/event-stream" |
| 72 | + for line in response.text.splitlines(): |
| 73 | + if line.startswith("data: "): |
| 74 | + init_data = json.loads(line[6:]) |
| 75 | + return init_data["result"]["protocolVersion"] |
| 76 | + raise ValueError("Could not extract protocol version from SSE response") |
| 77 | + |
| 78 | + |
68 | 79 | # Simple in-memory event store for testing |
69 | 80 | class SimpleEventStore(EventStore): |
70 | 81 | """Simple in-memory event store for testing.""" |
@@ -578,14 +589,7 @@ def test_session_termination(basic_server, basic_server_url): |
578 | 589 | assert response.status_code == 200 |
579 | 590 |
|
580 | 591 | # Extract negotiated protocol version from SSE response |
581 | | - init_data = None |
582 | | - assert response.headers.get("Content-Type") == "text/event-stream" |
583 | | - for line in response.text.splitlines(): |
584 | | - if line.startswith("data: "): |
585 | | - init_data = json.loads(line[6:]) |
586 | | - break |
587 | | - assert init_data is not None |
588 | | - negotiated_version = init_data["result"]["protocolVersion"] |
| 592 | + negotiated_version = extract_protocol_version_from_sse(response) |
589 | 593 |
|
590 | 594 | # Now terminate the session |
591 | 595 | session_id = response.headers.get(MCP_SESSION_ID_HEADER) |
@@ -626,14 +630,7 @@ def test_response(basic_server, basic_server_url): |
626 | 630 | assert response.status_code == 200 |
627 | 631 |
|
628 | 632 | # Extract negotiated protocol version from SSE response |
629 | | - init_data = None |
630 | | - assert response.headers.get("Content-Type") == "text/event-stream" |
631 | | - for line in response.text.splitlines(): |
632 | | - if line.startswith("data: "): |
633 | | - init_data = json.loads(line[6:]) |
634 | | - break |
635 | | - assert init_data is not None |
636 | | - negotiated_version = init_data["result"]["protocolVersion"] |
| 633 | + negotiated_version = extract_protocol_version_from_sse(response) |
637 | 634 |
|
638 | 635 | # Now get the session ID |
639 | 636 | session_id = response.headers.get(MCP_SESSION_ID_HEADER) |
@@ -1574,15 +1571,7 @@ async def test_server_validates_protocol_version_header(basic_server, basic_serv |
1574 | 1571 | ) |
1575 | 1572 |
|
1576 | 1573 | # Test request with valid protocol version (should succeed) |
1577 | | - init_data = None |
1578 | | - assert init_response.headers.get("Content-Type") == "text/event-stream" |
1579 | | - for line in init_response.text.splitlines(): |
1580 | | - if line.startswith("data: "): |
1581 | | - init_data = json.loads(line[6:]) |
1582 | | - break |
1583 | | - |
1584 | | - assert init_data is not None |
1585 | | - negotiated_version = init_data["result"]["protocolVersion"] |
| 1574 | + negotiated_version = extract_protocol_version_from_sse(init_response) |
1586 | 1575 |
|
1587 | 1576 | response = requests.post( |
1588 | 1577 | f"{basic_server_url}/mcp", |
|
0 commit comments