44from collections .abc import AsyncGenerator
55
66from a2a .server .agent_execution import AgentExecutor , RequestContext
7- from a2a .server .events import Event , EventConsumer , EventQueue
7+ from a2a .server .events import (
8+ EventConsumer ,
9+ EventQueue ,
10+ Event ,
11+ QueueManager ,
12+ TaskQueueExists ,
13+ NoTaskQueue ,
14+ InMemoryQueueManager ,
15+ )
816from a2a .server .request_handlers .request_handler import RequestHandler
917from a2a .server .tasks import ResultAggregator , TaskManager , TaskStore
1018from a2a .types import (
@@ -28,13 +36,14 @@ class DefaultRequestHandler(RequestHandler):
2836 """Default request handler for all incoming requests."""
2937
3038 def __init__ (
31- self , agent_executor : AgentExecutor , task_store : TaskStore
39+ self ,
40+ agent_executor : AgentExecutor ,
41+ task_store : TaskStore ,
42+ queue_manager : QueueManager = InMemoryQueueManager (),
3243 ) -> None :
3344 self .agent_executor = agent_executor
3445 self .task_store = task_store
35- # This works for single binary solution. Needs a distributed approach for
36- # true scalable deployment.
37- self ._task_queue : dict [str , EventQueue ] = {}
46+ self ._queue_manager = queue_manager
3847
3948 async def on_get_task (self , params : TaskQueryParams ) -> Task | None :
4049 """Default handler for 'tasks/get'."""
@@ -56,8 +65,10 @@ async def on_cancel_task(self, params: TaskIdParams) -> Task | None:
5665 initial_message = None ,
5766 )
5867 result_aggregator = ResultAggregator (task_manager )
59-
60- queue = EventQueue ()
68+ try :
69+ queue = await self ._queue_manager .tap (task .id )
70+ except :
71+ queue = EventQueue ()
6172 await self .agent_executor .cancel (
6273 RequestContext (
6374 None ,
@@ -95,15 +106,11 @@ async def on_message_send(
95106 task : Task | None = await task_manager .get_task ()
96107 if task :
97108 task = task_manager .update_with_message (params .message , task )
109+ queue = await self ._queue_manager .create_or_tap (task .id )
110+ else :
111+ queue = EventQueue ()
98112 result_aggregator = ResultAggregator (task_manager )
99113 # TODO to manage the non-blocking flows.
100-
101- queue = EventQueue ()
102- # If this is a follow up on an existing task, register the queue now
103- task_id : str | None = task .id if task else None
104- if task_id :
105- self ._task_queue [task_id ] = queue
106-
107114 producer_task = asyncio .create_task (
108115 self ._run_event_stream (
109116 RequestContext (
@@ -127,6 +134,11 @@ async def on_message_send(
127134 return result
128135 finally :
129136 await producer_task
137+ if task :
138+ try :
139+ await self ._queue_manager .close (task .id )
140+ except NoTaskQueue :
141+ pass
130142
131143 async def on_message_send_stream (
132144 self , params : MessageSendParams
@@ -142,14 +154,11 @@ async def on_message_send_stream(
142154
143155 if task :
144156 task = task_manager .update_with_message (params .message , task )
145-
157+ queue = await self ._queue_manager .create_or_tap (task .id )
158+ else :
159+ queue = EventQueue ()
146160 result_aggregator = ResultAggregator (task_manager )
147- queue = EventQueue ()
148-
149161 task_id : str | None = task .id if task else None
150- if task_id :
151- self ._task_queue [task_id ] = queue
152-
153162 producer_task = asyncio .create_task (
154163 self ._run_event_stream (
155164 RequestContext (
@@ -165,13 +174,20 @@ async def on_message_send_stream(
165174 consumer = EventConsumer (queue )
166175 async for event in result_aggregator .consume_and_emit (consumer ):
167176 # Now we know we have a Task, register the queue
168- if isinstance (event , Task ) and event .id not in self ._task_queue :
169- self ._task_queue [event .id ] = queue
170- task_id = event .id
171- yield event
172-
177+ if isinstance (event , Task ):
178+ try :
179+ await self ._queue_manager .add (event .id , queue )
180+ task_id = event .id
181+ except TaskQueueExists :
182+ logging .info (
183+ 'Multiple Task objects created in event stream.' )
184+ yield event
173185 finally :
174186 await producer_task
187+ try :
188+ await self ._queue_manager .close (task_id )
189+ except NoTaskQueue :
190+ pass
175191
176192 async def on_set_task_push_notification_config (
177193 self , request : TaskPushNotificationConfig
@@ -203,12 +219,11 @@ async def on_resubscribe_to_task(
203219
204220 result_aggregator = ResultAggregator (task_manager )
205221
206- # Need to tap the existing queue.
207- if not task .id in self ._task_queue :
222+ try :
223+ queue = await self ._queue_manager .tap (task .id )
224+ except NoTaskQueue :
208225 raise ServerError (error = TaskNotFoundError ())
209- return
210226
211- queue = self ._task_queue [task .id ].tap ()
212227 consumer = EventConsumer (queue )
213228 async for event in result_aggregator .consume_and_emit (consumer ):
214229 yield event
0 commit comments