4040GetSessionIdCallback = Callable [[], str | None ]
4141
4242MCP_SESSION_ID = "mcp-session-id"
43+ MCP_PROTOCOL_VERSION = "MCP-Protocol-Version"
4344LAST_EVENT_ID = "last-event-id"
4445CONTENT_TYPE = "content-type"
4546ACCEPT = "Accept"
@@ -100,6 +101,7 @@ def __init__(
100101 self .sse_read_timeout = sse_read_timeout
101102 self .auth = auth
102103 self .session_id : str | None = None
104+ self .protocol_version : str | None = None
103105 self .request_headers = {
104106 ACCEPT : f"{ JSON } , { SSE } " ,
105107 CONTENT_TYPE : JSON ,
@@ -109,10 +111,12 @@ def __init__(
109111 def _update_headers_with_session (
110112 self , base_headers : dict [str , str ]
111113 ) -> dict [str , str ]:
112- """Update headers with session ID if available."""
114+ """Update headers with session ID and protocol version if available."""
113115 headers = base_headers .copy ()
114116 if self .session_id :
115117 headers [MCP_SESSION_ID ] = self .session_id
118+ if self .protocol_version :
119+ headers [MCP_PROTOCOL_VERSION ] = self .protocol_version
116120 return headers
117121
118122 def _is_initialization_request (self , message : JSONRPCMessage ) -> bool :
@@ -139,19 +143,36 @@ def _maybe_extract_session_id_from_response(
139143 self .session_id = new_session_id
140144 logger .info (f"Received session ID: { self .session_id } " )
141145
146+ def _maybe_extract_protocol_version_from_message (
147+ self ,
148+ message : JSONRPCMessage ,
149+ ) -> None :
150+ """Extract protocol version from initialization response message."""
151+ if isinstance (message .root , JSONRPCResponse ) and message .root .result :
152+ # Check if result has protocolVersion field
153+ result = message .root .result
154+ if "protocolVersion" in result :
155+ self .protocol_version = result ["protocolVersion" ]
156+ logger .info (f"Negotiated protocol version: { self .protocol_version } " )
157+
142158 async def _handle_sse_event (
143159 self ,
144160 sse : ServerSentEvent ,
145161 read_stream_writer : StreamWriter ,
146162 original_request_id : RequestId | None = None ,
147163 resumption_callback : Callable [[str ], Awaitable [None ]] | None = None ,
164+ is_initialization : bool = False ,
148165 ) -> bool :
149166 """Handle an SSE event, returning True if the response is complete."""
150167 if sse .event == "message" :
151168 try :
152169 message = JSONRPCMessage .model_validate_json (sse .data )
153170 logger .debug (f"SSE message: { message } " )
154171
172+ # Extract protocol version from initialization response
173+ if is_initialization :
174+ self ._maybe_extract_protocol_version_from_message (message )
175+
155176 # If this is a response and we have original_request_id, replace it
156177 if original_request_id is not None and isinstance (
157178 message .root , JSONRPCResponse | JSONRPCError
@@ -273,9 +294,11 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
273294 content_type = response .headers .get (CONTENT_TYPE , "" ).lower ()
274295
275296 if content_type .startswith (JSON ):
276- await self ._handle_json_response (response , ctx .read_stream_writer )
297+ await self ._handle_json_response (
298+ response , ctx .read_stream_writer , is_initialization
299+ )
277300 elif content_type .startswith (SSE ):
278- await self ._handle_sse_response (response , ctx )
301+ await self ._handle_sse_response (response , ctx , is_initialization )
279302 else :
280303 await self ._handle_unexpected_content_type (
281304 content_type ,
@@ -286,19 +309,28 @@ async def _handle_json_response(
286309 self ,
287310 response : httpx .Response ,
288311 read_stream_writer : StreamWriter ,
312+ is_initialization : bool = False ,
289313 ) -> None :
290314 """Handle JSON response from the server."""
291315 try :
292316 content = await response .aread ()
293317 message = JSONRPCMessage .model_validate_json (content )
318+
319+ # Extract protocol version from initialization response
320+ if is_initialization :
321+ self ._maybe_extract_protocol_version_from_message (message )
322+
294323 session_message = SessionMessage (message )
295324 await read_stream_writer .send (session_message )
296325 except Exception as exc :
297326 logger .error (f"Error parsing JSON response: { exc } " )
298327 await read_stream_writer .send (exc )
299328
300329 async def _handle_sse_response (
301- self , response : httpx .Response , ctx : RequestContext
330+ self ,
331+ response : httpx .Response ,
332+ ctx : RequestContext ,
333+ is_initialization : bool = False ,
302334 ) -> None :
303335 """Handle SSE response from the server."""
304336 try :
@@ -312,6 +344,7 @@ async def _handle_sse_response(
312344 if ctx .metadata
313345 else None
314346 ),
347+ is_initialization = is_initialization ,
315348 )
316349 # If the SSE event indicates completion, like returning respose/error
317350 # break the loop
0 commit comments