Skip to content

Commit a4793f1

Browse files
feat: track protocol version in StreamableHttpTransport
The client now tracks the negotiated protocol version from the server's response headers, enabling version-aware communication between client and server.
1 parent 2363398 commit a4793f1

File tree

1 file changed

+37
-4
lines changed

1 file changed

+37
-4
lines changed

src/mcp/client/streamable_http.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
GetSessionIdCallback = Callable[[], str | None]
4141

4242
MCP_SESSION_ID = "mcp-session-id"
43+
MCP_PROTOCOL_VERSION = "MCP-Protocol-Version"
4344
LAST_EVENT_ID = "last-event-id"
4445
CONTENT_TYPE = "content-type"
4546
ACCEPT = "Accept"
@@ -100,6 +101,7 @@ def __init__(
100101
self.sse_read_timeout = sse_read_timeout
101102
self.auth = auth
102103
self.session_id: str | None = None
104+
self.protocol_version: str | None = None
103105
self.request_headers = {
104106
ACCEPT: f"{JSON}, {SSE}",
105107
CONTENT_TYPE: JSON,
@@ -109,10 +111,12 @@ def __init__(
109111
def _update_headers_with_session(
110112
self, base_headers: dict[str, str]
111113
) -> dict[str, str]:
112-
"""Update headers with session ID if available."""
114+
"""Update headers with session ID and protocol version if available."""
113115
headers = base_headers.copy()
114116
if self.session_id:
115117
headers[MCP_SESSION_ID] = self.session_id
118+
if self.protocol_version:
119+
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
116120
return headers
117121

118122
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
@@ -139,19 +143,36 @@ def _maybe_extract_session_id_from_response(
139143
self.session_id = new_session_id
140144
logger.info(f"Received session ID: {self.session_id}")
141145

146+
def _maybe_extract_protocol_version_from_message(
147+
self,
148+
message: JSONRPCMessage,
149+
) -> None:
150+
"""Extract protocol version from initialization response message."""
151+
if isinstance(message.root, JSONRPCResponse) and message.root.result:
152+
# Check if result has protocolVersion field
153+
result = message.root.result
154+
if "protocolVersion" in result:
155+
self.protocol_version = result["protocolVersion"]
156+
logger.info(f"Negotiated protocol version: {self.protocol_version}")
157+
142158
async def _handle_sse_event(
143159
self,
144160
sse: ServerSentEvent,
145161
read_stream_writer: StreamWriter,
146162
original_request_id: RequestId | None = None,
147163
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
164+
is_initialization: bool = False,
148165
) -> bool:
149166
"""Handle an SSE event, returning True if the response is complete."""
150167
if sse.event == "message":
151168
try:
152169
message = JSONRPCMessage.model_validate_json(sse.data)
153170
logger.debug(f"SSE message: {message}")
154171

172+
# Extract protocol version from initialization response
173+
if is_initialization:
174+
self._maybe_extract_protocol_version_from_message(message)
175+
155176
# If this is a response and we have original_request_id, replace it
156177
if original_request_id is not None and isinstance(
157178
message.root, JSONRPCResponse | JSONRPCError
@@ -273,9 +294,11 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
273294
content_type = response.headers.get(CONTENT_TYPE, "").lower()
274295

275296
if content_type.startswith(JSON):
276-
await self._handle_json_response(response, ctx.read_stream_writer)
297+
await self._handle_json_response(
298+
response, ctx.read_stream_writer, is_initialization
299+
)
277300
elif content_type.startswith(SSE):
278-
await self._handle_sse_response(response, ctx)
301+
await self._handle_sse_response(response, ctx, is_initialization)
279302
else:
280303
await self._handle_unexpected_content_type(
281304
content_type,
@@ -286,19 +309,28 @@ async def _handle_json_response(
286309
self,
287310
response: httpx.Response,
288311
read_stream_writer: StreamWriter,
312+
is_initialization: bool = False,
289313
) -> None:
290314
"""Handle JSON response from the server."""
291315
try:
292316
content = await response.aread()
293317
message = JSONRPCMessage.model_validate_json(content)
318+
319+
# Extract protocol version from initialization response
320+
if is_initialization:
321+
self._maybe_extract_protocol_version_from_message(message)
322+
294323
session_message = SessionMessage(message)
295324
await read_stream_writer.send(session_message)
296325
except Exception as exc:
297326
logger.error(f"Error parsing JSON response: {exc}")
298327
await read_stream_writer.send(exc)
299328

300329
async def _handle_sse_response(
301-
self, response: httpx.Response, ctx: RequestContext
330+
self,
331+
response: httpx.Response,
332+
ctx: RequestContext,
333+
is_initialization: bool = False,
302334
) -> None:
303335
"""Handle SSE response from the server."""
304336
try:
@@ -312,6 +344,7 @@ async def _handle_sse_response(
312344
if ctx.metadata
313345
else None
314346
),
347+
is_initialization=is_initialization,
315348
)
316349
# If the SSE event indicates completion, like returning respose/error
317350
# break the loop

0 commit comments

Comments
 (0)