Skip to content

Commit a20df2c

Browse files
committed
fix:Handle propagating agent exceptions
1 parent dfd6e6c commit a20df2c

File tree

4 files changed

+19
-49
lines changed

4 files changed

+19
-49
lines changed

examples/helloworld/agent_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ async def execute(
2424
context: RequestContext,
2525
event_queue: EventQueue,
2626
) -> None:
27+
raise Exception('cancel not supported')
2728
result = await self.agent.invoke()
2829
event_queue.enqueue_event(new_agent_text_message(result))
2930

src/a2a/server/agent_execution/base_agent_executor.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

src/a2a/server/events/event_consumer.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ class EventConsumer:
2222

2323
def __init__(self, queue: EventQueue):
2424
self.queue = queue
25+
self._timeout = 0.5
26+
self._exception: BaseException | None = None
2527
logger.debug('EventConsumer initialized')
2628

2729
async def consume_one(self) -> Event:
@@ -45,8 +47,10 @@ async def consume_all(self) -> AsyncGenerator[Event]:
4547
"""Consume all the generated streaming events from the agent."""
4648
logger.debug('Starting to consume all events from the queue.')
4749
while True:
50+
if self._exception:
51+
raise self._exception
4852
try:
49-
event = await self.queue.dequeue_event()
53+
event = await asyncio.wait_for(self.queue.dequeue_event(), timeout=self._timeout)
5054
logger.debug(
5155
f'Dequeued event of type: {type(event)} in consume_all.'
5256
)
@@ -74,5 +78,16 @@ async def consume_all(self) -> AsyncGenerator[Event]:
7478
logger.debug('Stopping event consumption in consume_all.')
7579
self.queue.close()
7680
break
81+
except asyncio.TimeoutError:
82+
# continue polling until there is a final event
83+
continue
7784
except asyncio.QueueShutDown:
7885
break
86+
87+
88+
89+
90+
91+
def agent_task_callback(self, agent_task: asyncio.Task[None]):
92+
if agent_task.exception() is not None:
93+
self._exception = agent_task.exception()

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ async def on_message_send(
138138
await self._register_producer(task_id, producer_task)
139139

140140
consumer = EventConsumer(queue)
141+
producer_task.add_done_callback(consumer.agent_task_callback)
141142

142143
interrupted = False
143144
try:
@@ -192,6 +193,7 @@ async def on_message_send_stream(
192193

193194
try:
194195
consumer = EventConsumer(queue)
196+
producer_task.add_done_callback(consumer.agent_task_callback)
195197
async for event in result_aggregator.consume_and_emit(consumer):
196198
# Now we know we have a Task, register the queue
197199
if isinstance(event, Task):

0 commit comments

Comments
 (0)