@@ -32,6 +32,7 @@ async def handle_sse(request):
3232"""
3333
3434import logging
35+ from collections .abc import Callable
3536from contextlib import asynccontextmanager
3637from typing import Any
3738from urllib .parse import quote
@@ -41,6 +42,7 @@ async def handle_sse(request):
4142from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
4243from pydantic import ValidationError
4344from sse_starlette import EventSourceResponse
45+ from starlette .background import BackgroundTask
4446from starlette .requests import Request
4547from starlette .responses import Response
4648from starlette .types import Receive , Scope , Send
@@ -79,7 +81,13 @@ def __init__(self, endpoint: str) -> None:
7981 logger .debug (f"SseServerTransport initialized with endpoint: { endpoint } " )
8082
8183 @asynccontextmanager
82- async def connect_sse (self , scope : Scope , receive : Receive , send : Send ):
84+ async def connect_sse (
85+ self ,
86+ scope : Scope ,
87+ receive : Receive ,
88+ send : Send ,
89+ callback : Callable [[], None ] | None = None ,
90+ ):
8391 if scope ["type" ] != "http" :
8492 logger .error ("connect_sse received non-HTTP request" )
8593 raise ValueError ("connect_sse can only handle HTTP requests" )
@@ -120,9 +128,19 @@ async def sse_writer():
120128 }
121129 )
122130
131+ async def _remove_stream_writer () -> None :
132+ await read_stream_writer .aclose ()
133+ await write_stream_reader .aclose ()
134+ del self ._read_stream_writers [session_id ]
135+ if callback :
136+ callback ()
137+ logger .debug (f"Closed SSE session with ID: { session_id } " )
138+
123139 async with anyio .create_task_group () as tg :
124140 response = EventSourceResponse (
125- content = sse_stream_reader , data_sender_callable = sse_writer
141+ content = sse_stream_reader ,
142+ data_sender_callable = sse_writer ,
143+ background = BackgroundTask (_remove_stream_writer ),
126144 )
127145 logger .debug ("Starting SSE response task" )
128146 tg .start_soon (response , scope , receive , send )
0 commit comments