1313from abc import ABC , abstractmethod
1414from collections .abc import AsyncGenerator , Awaitable , Callable
1515from contextlib import asynccontextmanager
16+ from dataclasses import dataclass
1617from http import HTTPStatus
17- from typing import Any
1818
1919import anyio
2020from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
6363EventId = str
6464
6565
66+ @dataclass
6667class EventMessage :
6768 """
6869 A JSONRPCMessage with an optional event ID for stream resumability.
6970 """
7071
7172 message : JSONRPCMessage
72- event_id : str | None
73-
74- def __init__ (self , message : JSONRPCMessage , event_id : str | None = None ):
75- self .message = message
76- self .event_id = event_id
73+ event_id : str | None = None
7774
7875
7976EventCallback = Callable [[EventMessage ], Awaitable [None ]]
@@ -226,6 +223,21 @@ def _get_session_id(self, request: Request) -> str | None:
226223 """Extract the session ID from request headers."""
227224 return request .headers .get (MCP_SESSION_ID_HEADER )
228225
226+ def _create_event_data (self , event_message : EventMessage ) -> dict [str , str ]:
227+ """Create event data dictionary from an EventMessage."""
228+ event_data = {
229+ "event" : "message" ,
230+ "data" : event_message .message .model_dump_json (
231+ by_alias = True , exclude_none = True
232+ ),
233+ }
234+
235+ # If an event ID was provided, include it
236+ if event_message .event_id :
237+ event_data ["id" ] = event_message .event_id
238+
239+ return event_data
240+
229241 async def handle_request (self , scope : Scope , receive : Receive , send : Send ) -> None :
230242 """Application entry point that handles all HTTP requests"""
231243 request = Request (scope , receive )
@@ -434,7 +446,7 @@ async def _handle_post_request(
434446 else :
435447 # Create SSE stream
436448 sse_stream_writer , sse_stream_reader = (
437- anyio .create_memory_object_stream [dict [str , Any ]](0 )
449+ anyio .create_memory_object_stream [dict [str , str ]](0 )
438450 )
439451
440452 async def sse_writer ():
@@ -444,17 +456,7 @@ async def sse_writer():
444456 # Process messages from the request-specific stream
445457 async for event_message in request_stream_reader :
446458 # Build the event data
447- event_data = {
448- "event" : "message" ,
449- "data" : event_message .message .model_dump_json (
450- by_alias = True , exclude_none = True
451- ),
452- }
453-
454- # If an event ID was provided, include it
455- if event_message .event_id :
456- event_data ["id" ] = event_message .event_id
457-
459+ event_data = self ._create_event_data (event_message )
458460 await sse_stream_writer .send (event_data )
459461
460462 # If response, remove from pending streams and close
@@ -571,7 +573,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
571573
572574 # Create SSE stream
573575 sse_stream_writer , sse_stream_reader = anyio .create_memory_object_stream [
574- dict [str , Any ]
576+ dict [str , str ]
575577 ](0 )
576578
577579 async def standalone_sse_writer ():
@@ -593,17 +595,7 @@ async def standalone_sse_writer():
593595 # We should NOT receive JSONRPCResponse
594596
595597 # Send the message via SSE
596- event_data = {
597- "event" : "message" ,
598- "data" : event_message .message .model_dump_json (
599- by_alias = True , exclude_none = True
600- ),
601- }
602-
603- # If an event ID was provided, include it in the SSE stream
604- if event_message .event_id :
605- event_data ["id" ] = event_message .event_id
606-
598+ event_data = self ._create_event_data (event_message )
607599 await sse_stream_writer .send (event_data )
608600 except Exception as e :
609601 logger .exception (f"Error in standalone SSE writer: { e } " )
@@ -744,23 +736,16 @@ async def _replay_events(
744736
745737 # Create SSE stream for replay
746738 sse_stream_writer , sse_stream_reader = anyio .create_memory_object_stream [
747- dict [str , Any ]
739+ dict [str , str ]
748740 ](0 )
749741
750742 async def replay_sender ():
751743 try :
752744 async with sse_stream_writer :
753745 # Define an async callback for sending events
754746 async def send_event (event_message : EventMessage ) -> None :
755- await sse_stream_writer .send (
756- {
757- "event" : "message" ,
758- "id" : event_message .event_id ,
759- "data" : event_message .message .model_dump_json (
760- by_alias = True , exclude_none = True
761- ),
762- }
763- )
747+ event_data = self ._create_event_data (event_message )
748+ await sse_stream_writer .send (event_data )
764749
765750 # Replay past events and get the stream ID
766751 stream_id = await event_store .replay_events_after (
@@ -777,16 +762,9 @@ async def send_event(event_message: EventMessage) -> None:
777762 # Forward messages to SSE
778763 async with msg_reader :
779764 async for event_message in msg_reader :
780- event_data = event_message .message .model_dump_json (
781- by_alias = True , exclude_none = True
782- )
783- await sse_stream_writer .send (
784- {
785- "event" : "message" ,
786- "id" : event_message .event_id ,
787- "data" : event_data ,
788- }
789- )
765+ event_data = self ._create_event_data (event_message )
766+
767+ await sse_stream_writer .send (event_data )
790768 except Exception as e :
791769 logger .exception (f"Error in replay sender: { e } " )
792770
0 commit comments