77import anyio .lowlevel
88import httpx
99from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
10- from pydantic import BaseModel , RootModel
10+ from pydantic import BaseModel
1111from typing_extensions import Self
1212
1313from mcp .shared .exceptions import McpError
2222 JSONRPCNotification ,
2323 JSONRPCRequest ,
2424 JSONRPCResponse ,
25+ MessageFrame ,
2526 RequestParams ,
2627 ServerNotification ,
2728 ServerRequest ,
2829 ServerResult ,
2930)
3031
31- RawT = TypeVar ("RawT" )
32-
33-
34- class MessageFrame (RootModel [JSONRPCMessage ], Generic [RawT ]):
35- root : JSONRPCMessage
36- raw : RawT | None = None
37-
38- class Config :
39- arbitrary_types_allowed = True
40-
41-
42- ReadStream = MemoryObjectReceiveStream [MessageFrame [RawT ] | Exception ]
43- ReadStreamWriter = MemoryObjectSendStream [MessageFrame [RawT ] | Exception ]
44- WriteStream = MemoryObjectSendStream [MessageFrame [RawT ]]
45- WriteStreamReader = MemoryObjectReceiveStream [MessageFrame [RawT ]]
32+ ReadStream = MemoryObjectReceiveStream [MessageFrame | Exception ]
33+ ReadStreamWriter = MemoryObjectSendStream [MessageFrame | Exception ]
34+ WriteStream = MemoryObjectSendStream [MessageFrame ]
35+ WriteStreamReader = MemoryObjectReceiveStream [MessageFrame ]
4636
4737SendRequestT = TypeVar ("SendRequestT" , ClientRequest , ServerRequest )
4838SendResultT = TypeVar ("SendResultT" , ClientResult , ServerResult )
@@ -259,7 +249,7 @@ async def send_request(
259249 # TODO: Support progress callbacks
260250
261251 await self ._write_stream .send (
262- MessageFrame (JSONRPCMessage (jsonrpc_request ), None )
252+ MessageFrame (root = JSONRPCMessage (jsonrpc_request ), raw = None )
263253 )
264254
265255 try :
@@ -297,15 +287,17 @@ async def send_notification(self, notification: SendNotificationT) -> None:
297287 )
298288
299289 await self ._write_stream .send (
300- MessageFrame (JSONRPCMessage (jsonrpc_notification ))
290+ MessageFrame (root = JSONRPCMessage (jsonrpc_notification ), raw = None )
301291 )
302292
303293 async def _send_response (
304294 self , request_id : RequestId , response : SendResultT | ErrorData
305295 ) -> None :
306296 if isinstance (response , ErrorData ):
307297 jsonrpc_error = JSONRPCError (jsonrpc = "2.0" , id = request_id , error = response )
308- await self ._write_stream .send (MessageFrame (JSONRPCMessage (jsonrpc_error )))
298+ await self ._write_stream .send (
299+ MessageFrame (root = JSONRPCMessage (jsonrpc_error ), raw = None )
300+ )
309301 else :
310302 jsonrpc_response = JSONRPCResponse (
311303 jsonrpc = "2.0" ,
@@ -315,7 +307,7 @@ async def _send_response(
315307 ),
316308 )
317309 await self ._write_stream .send (
318- MessageFrame (JSONRPCMessage (jsonrpc_response ))
310+ MessageFrame (root = JSONRPCMessage (jsonrpc_response ), raw = None )
319311 )
320312
321313 async def _receive_loop (self ) -> None :
0 commit comments