Skip to content

Commit 4ce7e4c

Browse files
feat: implement MCP-Protocol-Version header validation in server
- Add MCP_PROTOCOL_VERSION_HEADER constant - Add _validate_protocol_version method to check header presence and validity - Validate protocol version for all non-initialization requests (POST, GET, DELETE) - Return 400 Bad Request for missing or invalid protocol versions - Update tests to include MCP-Protocol-Version header in requests - Fix test_streamablehttp_client_resumption to pass protocol version when resuming This implements the server-side validation required by the spec change that mandates clients include the negotiated protocol version in all subsequent HTTP requests after initialization. Github-Issue: #548
1 parent c6cca13 commit 4ce7e4c

File tree

2 files changed

+101
-12
lines changed

2 files changed

+101
-12
lines changed

src/mcp/server/streamable_http.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from starlette.types import Receive, Scope, Send
2626

2727
from mcp.shared.message import ServerMessageMetadata, SessionMessage
28+
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
2829
from mcp.types import (
2930
INTERNAL_ERROR,
3031
INVALID_PARAMS,
@@ -45,6 +46,7 @@
4546

4647
# Header names
4748
MCP_SESSION_ID_HEADER = "mcp-session-id"
49+
MCP_PROTOCOL_VERSION_HEADER = "MCP-Protocol-Version"
4850
LAST_EVENT_ID_HEADER = "last-event-id"
4951

5052
# Content types
@@ -353,9 +355,10 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
353355
)
354356
await response(scope, receive, send)
355357
return
356-
# For non-initialization requests, validate the session
357358
elif not await self._validate_session(request, send):
358359
return
360+
elif not await self._validate_protocol_version(request, send):
361+
return
359362

360363
# For notifications and responses only, return 202 Accepted
361364
if not isinstance(message.root, JSONRPCRequest):
@@ -515,6 +518,9 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
515518

516519
if not await self._validate_session(request, send):
517520
return
521+
if not await self._validate_protocol_version(request, send):
522+
return
523+
518524
# Handle resumability: check for Last-Event-ID header
519525
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
520526
await self._replay_events(last_event_id, request, send)
@@ -595,6 +601,8 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
595601

596602
if not await self._validate_session(request, send):
597603
return
604+
if not await self._validate_protocol_version(request, send):
605+
return
598606

599607
await self._terminate_session()
600608

@@ -682,7 +690,34 @@ async def _validate_session(self, request: Request, send: Send) -> bool:
682690

683691
return True
684692

685-
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
693+
async def _validate_protocol_version(self, request: Request, send: Send) -> bool:
694+
"""Validate the protocol version header in the request."""
695+
# Get the protocol version from the request headers
696+
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
697+
698+
# If no protocol version provided, return error
699+
if not protocol_version:
700+
response = self._create_error_response(
701+
"Bad Request: Missing MCP-Protocol-Version header",
702+
HTTPStatus.BAD_REQUEST,
703+
)
704+
await response(request.scope, request.receive, send)
705+
return False
706+
707+
# Check if the protocol version is supported
708+
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
709+
response = self._create_error_response(
710+
f"Bad Request: Unsupported protocol version: {protocol_version}",
711+
HTTPStatus.BAD_REQUEST,
712+
)
713+
await response(request.scope, request.receive, send)
714+
return False
715+
716+
return True
717+
718+
async def _replay_events(
719+
self, last_event_id: str, request: Request, send: Send
720+
) -> None:
686721
"""
687722
Replays events that would have been sent after the specified event ID.
688723
Only used when resumability is enabled.

tests/shared/test_streamable_http.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from mcp.client.streamable_http import streamablehttp_client
2727
from mcp.server import Server
2828
from mcp.server.streamable_http import (
29+
MCP_PROTOCOL_VERSION_HEADER,
2930
MCP_SESSION_ID_HEADER,
3031
SESSION_ID_PATTERN,
3132
EventCallback,
@@ -560,11 +561,24 @@ def test_session_termination(basic_server, basic_server_url):
560561
)
561562
assert response.status_code == 200
562563

564+
# Extract negotiated protocol version from SSE response
565+
init_data = None
566+
assert response.headers.get("Content-Type") == "text/event-stream"
567+
for line in response.text.splitlines():
568+
if line.startswith("data: "):
569+
init_data = json.loads(line[6:])
570+
break
571+
assert init_data is not None
572+
negotiated_version = init_data["result"]["protocolVersion"]
573+
563574
# Now terminate the session
564575
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
565576
response = requests.delete(
566577
f"{basic_server_url}/mcp",
567-
headers={MCP_SESSION_ID_HEADER: session_id},
578+
headers={
579+
MCP_SESSION_ID_HEADER: session_id,
580+
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
581+
},
568582
)
569583
assert response.status_code == 200
570584

@@ -595,16 +609,27 @@ def test_response(basic_server, basic_server_url):
595609
)
596610
assert response.status_code == 200
597611

598-
# Now terminate the session
612+
# Extract negotiated protocol version from SSE response
613+
init_data = None
614+
assert response.headers.get("Content-Type") == "text/event-stream"
615+
for line in response.text.splitlines():
616+
if line.startswith("data: "):
617+
init_data = json.loads(line[6:])
618+
break
619+
assert init_data is not None
620+
negotiated_version = init_data["result"]["protocolVersion"]
621+
622+
# Now get the session ID
599623
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
600624

601-
# Try to use the terminated session
625+
# Try to use the session with proper headers
602626
tools_response = requests.post(
603627
mcp_url,
604628
headers={
605629
"Accept": "application/json, text/event-stream",
606630
"Content-Type": "application/json",
607631
MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier
632+
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
608633
},
609634
json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"},
610635
stream=True,
@@ -646,12 +671,23 @@ def test_get_sse_stream(basic_server, basic_server_url):
646671
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
647672
assert session_id is not None
648673

674+
# Extract negotiated protocol version from SSE response
675+
init_data = None
676+
assert init_response.headers.get("Content-Type") == "text/event-stream"
677+
for line in init_response.text.splitlines():
678+
if line.startswith("data: "):
679+
init_data = json.loads(line[6:])
680+
break
681+
assert init_data is not None
682+
negotiated_version = init_data["result"]["protocolVersion"]
683+
649684
# Now attempt to establish an SSE stream via GET
650685
get_response = requests.get(
651686
mcp_url,
652687
headers={
653688
"Accept": "text/event-stream",
654689
MCP_SESSION_ID_HEADER: session_id,
690+
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
655691
},
656692
stream=True,
657693
)
@@ -666,6 +702,7 @@ def test_get_sse_stream(basic_server, basic_server_url):
666702
headers={
667703
"Accept": "text/event-stream",
668704
MCP_SESSION_ID_HEADER: session_id,
705+
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
669706
},
670707
stream=True,
671708
)
@@ -694,11 +731,22 @@ def test_get_validation(basic_server, basic_server_url):
694731
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
695732
assert session_id is not None
696733

734+
# Extract negotiated protocol version from SSE response
735+
init_data = None
736+
assert init_response.headers.get("Content-Type") == "text/event-stream"
737+
for line in init_response.text.splitlines():
738+
if line.startswith("data: "):
739+
init_data = json.loads(line[6:])
740+
break
741+
assert init_data is not None
742+
negotiated_version = init_data["result"]["protocolVersion"]
743+
697744
# Test without Accept header
698745
response = requests.get(
699746
mcp_url,
700747
headers={
701748
MCP_SESSION_ID_HEADER: session_id,
749+
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
702750
},
703751
stream=True,
704752
)
@@ -711,6 +759,7 @@ def test_get_validation(basic_server, basic_server_url):
711759
headers={
712760
"Accept": "application/json",
713761
MCP_SESSION_ID_HEADER: session_id,
762+
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
714763
},
715764
)
716765
assert response.status_code == 406
@@ -1004,6 +1053,7 @@ async def test_streamablehttp_client_resumption(event_server):
10041053
captured_resumption_token = None
10051054
captured_notifications = []
10061055
tool_started = False
1056+
captured_protocol_version = None
10071057

10081058
async def message_handler(
10091059
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
@@ -1032,6 +1082,8 @@ async def on_resumption_token_update(token: str) -> None:
10321082
assert isinstance(result, InitializeResult)
10331083
captured_session_id = get_session_id()
10341084
assert captured_session_id is not None
1085+
# Capture the negotiated protocol version
1086+
captured_protocol_version = result.protocolVersion
10351087

10361088
# Start a long-running tool in a task
10371089
async with anyio.create_task_group() as tg:
@@ -1064,10 +1116,12 @@ async def run_tool():
10641116
captured_notifications_pre = captured_notifications.copy()
10651117
captured_notifications = []
10661118

1067-
# Now resume the session with the same mcp-session-id
1119+
# Now resume the session with the same mcp-session-id and protocol version
10681120
headers = {}
10691121
if captured_session_id:
10701122
headers[MCP_SESSION_ID_HEADER] = captured_session_id
1123+
if captured_protocol_version:
1124+
headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version
10711125

10721126
async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (
10731127
read_stream,
@@ -1413,7 +1467,7 @@ def test_server_validates_protocol_version_header(basic_server, basic_server_url
14131467
)
14141468
assert response.status_code == 400
14151469
assert (
1416-
"MCP-Protocol-Version" in response.text
1470+
MCP_PROTOCOL_VERSION_HEADER in response.text
14171471
or "protocol version" in response.text.lower()
14181472
)
14191473

@@ -1424,13 +1478,13 @@ def test_server_validates_protocol_version_header(basic_server, basic_server_url
14241478
"Accept": "application/json, text/event-stream",
14251479
"Content-Type": "application/json",
14261480
MCP_SESSION_ID_HEADER: session_id,
1427-
"MCP-Protocol-Version": "invalid-version",
1481+
MCP_PROTOCOL_VERSION_HEADER: "invalid-version",
14281482
},
14291483
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"},
14301484
)
14311485
assert response.status_code == 400
14321486
assert (
1433-
"MCP-Protocol-Version" in response.text
1487+
MCP_PROTOCOL_VERSION_HEADER in response.text
14341488
or "protocol version" in response.text.lower()
14351489
)
14361490

@@ -1441,13 +1495,13 @@ def test_server_validates_protocol_version_header(basic_server, basic_server_url
14411495
"Accept": "application/json, text/event-stream",
14421496
"Content-Type": "application/json",
14431497
MCP_SESSION_ID_HEADER: session_id,
1444-
"MCP-Protocol-Version": "1999-01-01", # Very old unsupported version
1498+
MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version
14451499
},
14461500
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"},
14471501
)
14481502
assert response.status_code == 400
14491503
assert (
1450-
"MCP-Protocol-Version" in response.text
1504+
MCP_PROTOCOL_VERSION_HEADER in response.text
14511505
or "protocol version" in response.text.lower()
14521506
)
14531507

@@ -1468,7 +1522,7 @@ def test_server_validates_protocol_version_header(basic_server, basic_server_url
14681522
"Accept": "application/json, text/event-stream",
14691523
"Content-Type": "application/json",
14701524
MCP_SESSION_ID_HEADER: session_id,
1471-
"MCP-Protocol-Version": negotiated_version,
1525+
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
14721526
},
14731527
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"},
14741528
)

0 commit comments

Comments
 (0)