77import anyio .lowlevel
88import httpx
99from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
10- from pydantic import BaseModel
10+ from pydantic import BaseModel , RootModel
1111from typing_extensions import Self
1212
1313from mcp .shared .exceptions import McpError
2828 ServerResult ,
2929)
3030
31+ RawT = TypeVar ("RawT" )
32+
33+
34+ class ParsedMessage (RootModel [JSONRPCMessage ], Generic [RawT ]):
35+ root : JSONRPCMessage
36+ raw : RawT | None = None
37+
38+ class Config :
39+ arbitrary_types_allowed = True
40+
41+
42+ ReadStream = MemoryObjectReceiveStream [ParsedMessage [RawT ] | Exception ]
43+ ReadStreamWriter = MemoryObjectSendStream [ParsedMessage [RawT ] | Exception ]
44+ WriteStream = MemoryObjectSendStream [ParsedMessage [RawT ]]
45+ WriteStreamReader = MemoryObjectReceiveStream [ParsedMessage [RawT ]]
46+
3147SendRequestT = TypeVar ("SendRequestT" , ClientRequest , ServerRequest )
3248SendResultT = TypeVar ("SendResultT" , ClientResult , ServerResult )
3349SendNotificationT = TypeVar ("SendNotificationT" , ClientNotification , ServerNotification )
@@ -165,8 +181,8 @@ class BaseSession(
165181
166182 def __init__ (
167183 self ,
168- read_stream : MemoryObjectReceiveStream [ JSONRPCMessage | Exception ] ,
169- write_stream : MemoryObjectSendStream [ JSONRPCMessage ] ,
184+ read_stream : ReadStream ,
185+ write_stream : WriteStream ,
170186 receive_request_type : type [ReceiveRequestT ],
171187 receive_notification_type : type [ReceiveNotificationT ],
172188 # If none, reading will never time out
@@ -242,7 +258,9 @@ async def send_request(
242258
243259 # TODO: Support progress callbacks
244260
245- await self ._write_stream .send (JSONRPCMessage (jsonrpc_request ))
261+ await self ._write_stream .send (
262+ ParsedMessage (JSONRPCMessage (jsonrpc_request ), None )
263+ )
246264
247265 try :
248266 with anyio .fail_after (
@@ -278,14 +296,16 @@ async def send_notification(self, notification: SendNotificationT) -> None:
278296 ** notification .model_dump (by_alias = True , mode = "json" , exclude_none = True ),
279297 )
280298
281- await self ._write_stream .send (JSONRPCMessage (jsonrpc_notification ))
299+ await self ._write_stream .send (
300+ ParsedMessage (JSONRPCMessage (jsonrpc_notification ))
301+ )
282302
283303 async def _send_response (
284304 self , request_id : RequestId , response : SendResultT | ErrorData
285305 ) -> None :
286306 if isinstance (response , ErrorData ):
287307 jsonrpc_error = JSONRPCError (jsonrpc = "2.0" , id = request_id , error = response )
288- await self ._write_stream .send (JSONRPCMessage (jsonrpc_error ))
308+ await self ._write_stream .send (ParsedMessage ( JSONRPCMessage (jsonrpc_error ) ))
289309 else :
290310 jsonrpc_response = JSONRPCResponse (
291311 jsonrpc = "2.0" ,
@@ -294,18 +314,23 @@ async def _send_response(
294314 by_alias = True , mode = "json" , exclude_none = True
295315 ),
296316 )
297- await self ._write_stream .send (JSONRPCMessage (jsonrpc_response ))
317+ await self ._write_stream .send (
318+ ParsedMessage (JSONRPCMessage (jsonrpc_response ))
319+ )
298320
299321 async def _receive_loop (self ) -> None :
300322 async with (
301323 self ._read_stream ,
302324 self ._write_stream ,
303325 self ._incoming_message_stream_writer ,
304326 ):
305- async for message in self ._read_stream :
306- if isinstance (message , Exception ):
307- await self ._incoming_message_stream_writer .send (message )
308- elif isinstance (message .root , JSONRPCRequest ):
327+ async for raw_message in self ._read_stream :
328+ if isinstance (raw_message , Exception ):
329+ await self ._incoming_message_stream_writer .send (raw_message )
330+ continue
331+
332+ message = raw_message .root
333+ if isinstance (message .root , JSONRPCRequest ):
309334 validated_request = self ._receive_request_type .model_validate (
310335 message .root .model_dump (
311336 by_alias = True , mode = "json" , exclude_none = True
0 commit comments