Skip to content

Commit 370b993

Browse files
refactor: extract protocol version parsing to helper function
- Add extract_protocol_version_from_sse helper function to reduce code duplication - Replace repeated protocol version extraction logic in 5 test functions - Fix line length issues in docstrings to comply with 88 char limit This improves test maintainability by centralizing the SSE response parsing logic.
1 parent 4ce7e4c commit 370b993

File tree

1 file changed

+14
-25
lines changed

1 file changed

+14
-25
lines changed

tests/shared/test_streamable_http.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,17 @@
6565
}
6666

6767

68+
# Helper functions
69+
def extract_protocol_version_from_sse(response: requests.Response) -> str:
70+
"""Extract the negotiated protocol version from an SSE initialization response."""
71+
assert response.headers.get("Content-Type") == "text/event-stream"
72+
for line in response.text.splitlines():
73+
if line.startswith("data: "):
74+
init_data = json.loads(line[6:])
75+
return init_data["result"]["protocolVersion"]
76+
raise ValueError("Could not extract protocol version from SSE response")
77+
78+
6879
# Simple in-memory event store for testing
6980
class SimpleEventStore(EventStore):
7081
"""Simple in-memory event store for testing."""
@@ -562,14 +573,7 @@ def test_session_termination(basic_server, basic_server_url):
562573
assert response.status_code == 200
563574

564575
# Extract negotiated protocol version from SSE response
565-
init_data = None
566-
assert response.headers.get("Content-Type") == "text/event-stream"
567-
for line in response.text.splitlines():
568-
if line.startswith("data: "):
569-
init_data = json.loads(line[6:])
570-
break
571-
assert init_data is not None
572-
negotiated_version = init_data["result"]["protocolVersion"]
576+
negotiated_version = extract_protocol_version_from_sse(response)
573577

574578
# Now terminate the session
575579
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
@@ -610,14 +614,7 @@ def test_response(basic_server, basic_server_url):
610614
assert response.status_code == 200
611615

612616
# Extract negotiated protocol version from SSE response
613-
init_data = None
614-
assert response.headers.get("Content-Type") == "text/event-stream"
615-
for line in response.text.splitlines():
616-
if line.startswith("data: "):
617-
init_data = json.loads(line[6:])
618-
break
619-
assert init_data is not None
620-
negotiated_version = init_data["result"]["protocolVersion"]
617+
negotiated_version = extract_protocol_version_from_sse(response)
621618

622619
# Now get the session ID
623620
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
@@ -1506,15 +1503,7 @@ def test_server_validates_protocol_version_header(basic_server, basic_server_url
15061503
)
15071504

15081505
# Test request with valid protocol version (should succeed)
1509-
init_data = None
1510-
assert init_response.headers.get("Content-Type") == "text/event-stream"
1511-
for line in init_response.text.splitlines():
1512-
if line.startswith("data: "):
1513-
init_data = json.loads(line[6:])
1514-
break
1515-
1516-
assert init_data is not None
1517-
negotiated_version = init_data["result"]["protocolVersion"]
1506+
negotiated_version = extract_protocol_version_from_sse(init_response)
15181507

15191508
response = requests.post(
15201509
f"{basic_server_url}/mcp",

0 commit comments

Comments
 (0)