Skip to content

Commit c6cca13

Browse files
feat: track protocol version in StreamableHttpTransport
The client now tracks the negotiated protocol version from the server's response headers, enabling version-aware communication between client and server.
1 parent 2314b5a commit c6cca13

File tree

1 file changed

+47
-6
lines changed

1 file changed

+47
-6
lines changed

src/mcp/client/streamable_http.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
GetSessionIdCallback = Callable[[], str | None]
4040

4141
MCP_SESSION_ID = "mcp-session-id"
42+
MCP_PROTOCOL_VERSION = "MCP-Protocol-Version"
4243
LAST_EVENT_ID = "last-event-id"
4344
CONTENT_TYPE = "content-type"
4445
ACCEPT = "Accept"
@@ -97,17 +98,22 @@ def __init__(
9798
)
9899
self.auth = auth
99100
self.session_id = None
101+
self.protocol_version = None
100102
self.request_headers = {
101103
ACCEPT: f"{JSON}, {SSE}",
102104
CONTENT_TYPE: JSON,
103105
**self.headers,
104106
}
105107

106-
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
107-
"""Update headers with session ID if available."""
108+
def _update_headers_with_session(
109+
self, base_headers: dict[str, str]
110+
) -> dict[str, str]:
111+
"""Update headers with session ID and protocol version if available."""
108112
headers = base_headers.copy()
109113
if self.session_id:
110114
headers[MCP_SESSION_ID] = self.session_id
115+
if self.protocol_version:
116+
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
111117
return headers
112118

113119
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
@@ -128,19 +134,36 @@ def _maybe_extract_session_id_from_response(
128134
self.session_id = new_session_id
129135
logger.info(f"Received session ID: {self.session_id}")
130136

137+
def _maybe_extract_protocol_version_from_message(
138+
self,
139+
message: JSONRPCMessage,
140+
) -> None:
141+
"""Extract protocol version from initialization response message."""
142+
if isinstance(message.root, JSONRPCResponse) and message.root.result:
143+
# Check if result has protocolVersion field
144+
result = message.root.result
145+
if "protocolVersion" in result:
146+
self.protocol_version = result["protocolVersion"]
147+
logger.info(f"Negotiated protocol version: {self.protocol_version}")
148+
131149
async def _handle_sse_event(
132150
self,
133151
sse: ServerSentEvent,
134152
read_stream_writer: StreamWriter,
135153
original_request_id: RequestId | None = None,
136154
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
155+
is_initialization: bool = False,
137156
) -> bool:
138157
"""Handle an SSE event, returning True if the response is complete."""
139158
if sse.event == "message":
140159
try:
141160
message = JSONRPCMessage.model_validate_json(sse.data)
142161
logger.debug(f"SSE message: {message}")
143162

163+
# Extract protocol version from initialization response
164+
if is_initialization:
165+
self._maybe_extract_protocol_version_from_message(message)
166+
144167
# If this is a response and we have original_request_id, replace it
145168
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
146169
message.root.id = original_request_id
@@ -256,9 +279,11 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
256279
content_type = response.headers.get(CONTENT_TYPE, "").lower()
257280

258281
if content_type.startswith(JSON):
259-
await self._handle_json_response(response, ctx.read_stream_writer)
282+
await self._handle_json_response(
283+
response, ctx.read_stream_writer, is_initialization
284+
)
260285
elif content_type.startswith(SSE):
261-
await self._handle_sse_response(response, ctx)
286+
await self._handle_sse_response(response, ctx, is_initialization)
262287
else:
263288
await self._handle_unexpected_content_type(
264289
content_type,
@@ -269,26 +294,42 @@ async def _handle_json_response(
269294
self,
270295
response: httpx.Response,
271296
read_stream_writer: StreamWriter,
297+
is_initialization: bool = False,
272298
) -> None:
273299
"""Handle JSON response from the server."""
274300
try:
275301
content = await response.aread()
276302
message = JSONRPCMessage.model_validate_json(content)
303+
304+
# Extract protocol version from initialization response
305+
if is_initialization:
306+
self._maybe_extract_protocol_version_from_message(message)
307+
277308
session_message = SessionMessage(message)
278309
await read_stream_writer.send(session_message)
279310
except Exception as exc:
280311
logger.error(f"Error parsing JSON response: {exc}")
281312
await read_stream_writer.send(exc)
282313

283-
async def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
314+
async def _handle_sse_response(
315+
self,
316+
response: httpx.Response,
317+
ctx: RequestContext,
318+
is_initialization: bool = False,
319+
) -> None:
284320
"""Handle SSE response from the server."""
285321
try:
286322
event_source = EventSource(response)
287323
async for sse in event_source.aiter_sse():
288324
is_complete = await self._handle_sse_event(
289325
sse,
290326
ctx.read_stream_writer,
291-
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
327+
resumption_callback=(
328+
ctx.metadata.on_resumption_token_update
329+
if ctx.metadata
330+
else None
331+
),
332+
is_initialization=is_initialization,
292333
)
293334
# If the SSE event indicates completion, like returning respose/error
294335
# break the loop

0 commit comments

Comments
 (0)