|
65 | 65 | } |
66 | 66 |
|
67 | 67 |
|
| 68 | +# Helper functions |
| 69 | +def extract_protocol_version_from_sse(response: requests.Response) -> str: |
| 70 | + """Extract the negotiated protocol version from an SSE initialization response.""" |
| 71 | + assert response.headers.get("Content-Type") == "text/event-stream" |
| 72 | + for line in response.text.splitlines(): |
| 73 | + if line.startswith("data: "): |
| 74 | + init_data = json.loads(line[6:]) |
| 75 | + return init_data["result"]["protocolVersion"] |
| 76 | + raise ValueError("Could not extract protocol version from SSE response") |
| 77 | + |
| 78 | + |
68 | 79 | # Simple in-memory event store for testing |
69 | 80 | class SimpleEventStore(EventStore): |
70 | 81 | """Simple in-memory event store for testing.""" |
@@ -562,14 +573,7 @@ def test_session_termination(basic_server, basic_server_url): |
562 | 573 | assert response.status_code == 200 |
563 | 574 |
|
564 | 575 | # Extract negotiated protocol version from SSE response |
565 | | - init_data = None |
566 | | - assert response.headers.get("Content-Type") == "text/event-stream" |
567 | | - for line in response.text.splitlines(): |
568 | | - if line.startswith("data: "): |
569 | | - init_data = json.loads(line[6:]) |
570 | | - break |
571 | | - assert init_data is not None |
572 | | - negotiated_version = init_data["result"]["protocolVersion"] |
| 576 | + negotiated_version = extract_protocol_version_from_sse(response) |
573 | 577 |
|
574 | 578 | # Now terminate the session |
575 | 579 | session_id = response.headers.get(MCP_SESSION_ID_HEADER) |
@@ -610,14 +614,7 @@ def test_response(basic_server, basic_server_url): |
610 | 614 | assert response.status_code == 200 |
611 | 615 |
|
612 | 616 | # Extract negotiated protocol version from SSE response |
613 | | - init_data = None |
614 | | - assert response.headers.get("Content-Type") == "text/event-stream" |
615 | | - for line in response.text.splitlines(): |
616 | | - if line.startswith("data: "): |
617 | | - init_data = json.loads(line[6:]) |
618 | | - break |
619 | | - assert init_data is not None |
620 | | - negotiated_version = init_data["result"]["protocolVersion"] |
| 617 | + negotiated_version = extract_protocol_version_from_sse(response) |
621 | 618 |
|
622 | 619 | # Now get the session ID |
623 | 620 | session_id = response.headers.get(MCP_SESSION_ID_HEADER) |
@@ -1506,15 +1503,7 @@ def test_server_validates_protocol_version_header(basic_server, basic_server_url |
1506 | 1503 | ) |
1507 | 1504 |
|
1508 | 1505 | # Test request with valid protocol version (should succeed) |
1509 | | - init_data = None |
1510 | | - assert init_response.headers.get("Content-Type") == "text/event-stream" |
1511 | | - for line in init_response.text.splitlines(): |
1512 | | - if line.startswith("data: "): |
1513 | | - init_data = json.loads(line[6:]) |
1514 | | - break |
1515 | | - |
1516 | | - assert init_data is not None |
1517 | | - negotiated_version = init_data["result"]["protocolVersion"] |
| 1506 | + negotiated_version = extract_protocol_version_from_sse(init_response) |
1518 | 1507 |
|
1519 | 1508 | response = requests.post( |
1520 | 1509 | f"{basic_server_url}/mcp", |
|
0 commit comments