77import anyio .lowlevel
88import httpx
99from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
10- from pydantic import BaseModel
10+ from pydantic import BaseModel , RootModel
1111
1212from mcp .shared .exceptions import McpError
1313from mcp .types import (
2727 ServerResult ,
2828)
2929
30+ RawT = TypeVar ("RawT" )
31+
32+
33+ class ParsedMessage (RootModel [JSONRPCMessage ], Generic [RawT ]):
34+ root : JSONRPCMessage
35+ raw : RawT | None = None
36+
37+ class Config :
38+ arbitrary_types_allowed = True
39+
40+
41+ ReadStream = MemoryObjectReceiveStream [ParsedMessage [RawT ] | Exception ]
42+ ReadStreamWriter = MemoryObjectSendStream [ParsedMessage [RawT ] | Exception ]
43+ WriteStream = MemoryObjectSendStream [ParsedMessage [RawT ]]
44+ WriteStreamReader = MemoryObjectReceiveStream [ParsedMessage [RawT ]]
45+
3046SendRequestT = TypeVar ("SendRequestT" , ClientRequest , ServerRequest )
3147SendResultT = TypeVar ("SendResultT" , ClientResult , ServerResult )
3248SendNotificationT = TypeVar ("SendNotificationT" , ClientNotification , ServerNotification )
@@ -159,8 +175,8 @@ class BaseSession(
159175
160176 def __init__ (
161177 self ,
162- read_stream : MemoryObjectReceiveStream [ JSONRPCMessage | Exception ] ,
163- write_stream : MemoryObjectSendStream [ JSONRPCMessage ] ,
178+ read_stream : ReadStream ,
179+ write_stream : WriteStream ,
164180 receive_request_type : type [ReceiveRequestT ],
165181 receive_notification_type : type [ReceiveNotificationT ],
166182 # If none, reading will never time out
@@ -225,7 +241,9 @@ async def send_request(
225241
226242 # TODO: Support progress callbacks
227243
228- await self ._write_stream .send (JSONRPCMessage (jsonrpc_request ))
244+ await self ._write_stream .send (
245+ ParsedMessage (JSONRPCMessage (jsonrpc_request ), None )
246+ )
229247
230248 try :
231249 with anyio .fail_after (
@@ -261,14 +279,16 @@ async def send_notification(self, notification: SendNotificationT) -> None:
261279 ** notification .model_dump (by_alias = True , mode = "json" , exclude_none = True ),
262280 )
263281
264- await self ._write_stream .send (JSONRPCMessage (jsonrpc_notification ))
282+ await self ._write_stream .send (
283+ ParsedMessage (JSONRPCMessage (jsonrpc_notification ))
284+ )
265285
266286 async def _send_response (
267287 self , request_id : RequestId , response : SendResultT | ErrorData
268288 ) -> None :
269289 if isinstance (response , ErrorData ):
270290 jsonrpc_error = JSONRPCError (jsonrpc = "2.0" , id = request_id , error = response )
271- await self ._write_stream .send (JSONRPCMessage (jsonrpc_error ))
291+ await self ._write_stream .send (ParsedMessage ( JSONRPCMessage (jsonrpc_error ) ))
272292 else :
273293 jsonrpc_response = JSONRPCResponse (
274294 jsonrpc = "2.0" ,
@@ -277,18 +297,23 @@ async def _send_response(
277297 by_alias = True , mode = "json" , exclude_none = True
278298 ),
279299 )
280- await self ._write_stream .send (JSONRPCMessage (jsonrpc_response ))
300+ await self ._write_stream .send (
301+ ParsedMessage (JSONRPCMessage (jsonrpc_response ))
302+ )
281303
282304 async def _receive_loop (self ) -> None :
283305 async with (
284306 self ._read_stream ,
285307 self ._write_stream ,
286308 self ._incoming_message_stream_writer ,
287309 ):
288- async for message in self ._read_stream :
289- if isinstance (message , Exception ):
290- await self ._incoming_message_stream_writer .send (message )
291- elif isinstance (message .root , JSONRPCRequest ):
310+ async for raw_message in self ._read_stream :
311+ if isinstance (raw_message , Exception ):
312+ await self ._incoming_message_stream_writer .send (raw_message )
313+ continue
314+
315+ message = raw_message .root
316+ if isinstance (message .root , JSONRPCRequest ):
292317 validated_request = self ._receive_request_type .model_validate (
293318 message .root .model_dump (
294319 by_alias = True , mode = "json" , exclude_none = True
0 commit comments