3939GetSessionIdCallback = Callable [[], str | None ]
4040
4141MCP_SESSION_ID = "mcp-session-id"
42+ MCP_PROTOCOL_VERSION = "MCP-Protocol-Version"
4243LAST_EVENT_ID = "last-event-id"
4344CONTENT_TYPE = "content-type"
4445ACCEPT = "Accept"
@@ -97,17 +98,22 @@ def __init__(
9798 )
9899 self .auth = auth
99100 self .session_id = None
101+ self .protocol_version = None
100102 self .request_headers = {
101103 ACCEPT : f"{ JSON } , { SSE } " ,
102104 CONTENT_TYPE : JSON ,
103105 ** self .headers ,
104106 }
105107
106- def _update_headers_with_session (self , base_headers : dict [str , str ]) -> dict [str , str ]:
107- """Update headers with session ID if available."""
108+ def _update_headers_with_session (
109+ self , base_headers : dict [str , str ]
110+ ) -> dict [str , str ]:
111+ """Update headers with session ID and protocol version if available."""
108112 headers = base_headers .copy ()
109113 if self .session_id :
110114 headers [MCP_SESSION_ID ] = self .session_id
115+ if self .protocol_version :
116+ headers [MCP_PROTOCOL_VERSION ] = self .protocol_version
111117 return headers
112118
113119 def _is_initialization_request (self , message : JSONRPCMessage ) -> bool :
@@ -128,19 +134,36 @@ def _maybe_extract_session_id_from_response(
128134 self .session_id = new_session_id
129135 logger .info (f"Received session ID: { self .session_id } " )
130136
137+ def _maybe_extract_protocol_version_from_message (
138+ self ,
139+ message : JSONRPCMessage ,
140+ ) -> None :
141+ """Extract protocol version from initialization response message."""
142+ if isinstance (message .root , JSONRPCResponse ) and message .root .result :
143+ # Check if result has protocolVersion field
144+ result = message .root .result
145+ if "protocolVersion" in result :
146+ self .protocol_version = result ["protocolVersion" ]
147+ logger .info (f"Negotiated protocol version: { self .protocol_version } " )
148+
131149 async def _handle_sse_event (
132150 self ,
133151 sse : ServerSentEvent ,
134152 read_stream_writer : StreamWriter ,
135153 original_request_id : RequestId | None = None ,
136154 resumption_callback : Callable [[str ], Awaitable [None ]] | None = None ,
155+ is_initialization : bool = False ,
137156 ) -> bool :
138157 """Handle an SSE event, returning True if the response is complete."""
139158 if sse .event == "message" :
140159 try :
141160 message = JSONRPCMessage .model_validate_json (sse .data )
142161 logger .debug (f"SSE message: { message } " )
143162
163+ # Extract protocol version from initialization response
164+ if is_initialization :
165+ self ._maybe_extract_protocol_version_from_message (message )
166+
144167 # If this is a response and we have original_request_id, replace it
145168 if original_request_id is not None and isinstance (message .root , JSONRPCResponse | JSONRPCError ):
146169 message .root .id = original_request_id
@@ -256,9 +279,11 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
256279 content_type = response .headers .get (CONTENT_TYPE , "" ).lower ()
257280
258281 if content_type .startswith (JSON ):
259- await self ._handle_json_response (response , ctx .read_stream_writer )
282+ await self ._handle_json_response (
283+ response , ctx .read_stream_writer , is_initialization
284+ )
260285 elif content_type .startswith (SSE ):
261- await self ._handle_sse_response (response , ctx )
286+ await self ._handle_sse_response (response , ctx , is_initialization )
262287 else :
263288 await self ._handle_unexpected_content_type (
264289 content_type ,
@@ -269,26 +294,42 @@ async def _handle_json_response(
269294 self ,
270295 response : httpx .Response ,
271296 read_stream_writer : StreamWriter ,
297+ is_initialization : bool = False ,
272298 ) -> None :
273299 """Handle JSON response from the server."""
274300 try :
275301 content = await response .aread ()
276302 message = JSONRPCMessage .model_validate_json (content )
303+
304+ # Extract protocol version from initialization response
305+ if is_initialization :
306+ self ._maybe_extract_protocol_version_from_message (message )
307+
277308 session_message = SessionMessage (message )
278309 await read_stream_writer .send (session_message )
279310 except Exception as exc :
280311 logger .error (f"Error parsing JSON response: { exc } " )
281312 await read_stream_writer .send (exc )
282313
283- async def _handle_sse_response (self , response : httpx .Response , ctx : RequestContext ) -> None :
314+ async def _handle_sse_response (
315+ self ,
316+ response : httpx .Response ,
317+ ctx : RequestContext ,
318+ is_initialization : bool = False ,
319+ ) -> None :
284320 """Handle SSE response from the server."""
285321 try :
286322 event_source = EventSource (response )
287323 async for sse in event_source .aiter_sse ():
288324 is_complete = await self ._handle_sse_event (
289325 sse ,
290326 ctx .read_stream_writer ,
291- resumption_callback = (ctx .metadata .on_resumption_token_update if ctx .metadata else None ),
327+ resumption_callback = (
328+ ctx .metadata .on_resumption_token_update
329+ if ctx .metadata
330+ else None
331+ ),
332+ is_initialization = is_initialization ,
292333 )
293334 # If the SSE event indicates completion, like returning respose/error
294335 # break the loop
0 commit comments