66from asyncio import CancelledError
77from collections .abc import AsyncGenerator , AsyncIterator , Callable , Generator
88from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager
9- from typing import NamedTuple , TypeAlias , cast
9+ from datetime import datetime , timedelta
10+ from typing import NamedTuple , TypeAlias , TypedDict , cast
1011
1112import janus
13+ from a2a .client import create_text_message_object
1214from a2a .server .agent_execution import AgentExecutor , RequestContext
1315from a2a .server .events import EventQueue , QueueManager
1416from a2a .server .tasks import TaskUpdater
4042from beeai_sdk .server .dependencies import extract_dependencies
4143from beeai_sdk .server .logging import logger
4244from beeai_sdk .server .store .context_store import ContextStore
43- from beeai_sdk .server .utils import cancel_task
45+ from beeai_sdk .server .utils import cancel_task , close_queue
4446
4547AgentFunction : TypeAlias = Callable [[], AsyncGenerator [RunYield , RunYieldResume ]]
4648AgentFunctionFactory : TypeAlias = Callable [
@@ -278,15 +280,25 @@ async def agent_generator():
278280 return decorator
279281
280282
class RunningTask(TypedDict):
    """Bookkeeping record the executor keeps for one running agent task."""

    # The asyncio task that drives the agent generator.
    task: asyncio.Task
    # Set when the task is registered and refreshed before every generator
    # step; compared against the task timeout to detect stuck runs.
    last_invocation: datetime
286+
287+
281288class Executor (AgentExecutor ):
def __init__(
    self,
    execute_fn: AgentFunctionFactory,
    queue_manager: QueueManager,
    context_store: ContextStore,
    task_timeout: timedelta,
) -> None:
    """Store the executor's collaborators and create empty per-task registries.

    Args:
        execute_fn: Agent function factory, kept as ``_agent_executor_span``.
        queue_manager: Queue manager used to create, tap and close task queues.
        context_store: Context store made available to agent runs.
        task_timeout: Maximum time allowed since a task's last invocation
            before the cleanup watcher requests its cancellation.
    """
    # Collaborators supplied by the caller.
    self._agent_executor_span = execute_fn
    self._queue_manager = queue_manager
    self._context_store = context_store
    self._task_timeout = task_timeout

    # Per-task bookkeeping keyed by task id; both registries start empty.
    self._running_tasks: dict[str, RunningTask] = {}
    self._cancel_queues: dict[str, EventQueue] = {}
290302
291303 async def _watch_for_cancellation (self , task_id : str , task : asyncio .Task ) -> None :
292304 cancel_queue = await self ._queue_manager .create_or_tap (f"_cancel_{ task_id } " )
@@ -333,7 +345,11 @@ def with_context(message: Message | None = None) -> Message | None:
333345 value : RunYieldResume = None
334346 opened_artifacts : set [str ] = set ()
335347 while True :
348+ # update invocation time
349+ self ._running_tasks [task_updater .task_id ]["last_invocation" ] = datetime .now ()
350+
336351 yielded_value = await agent_generator_fn .asend (value )
352+
337353 match yielded_value :
338354 case str (text ):
339355 await task_updater .update_status (
@@ -438,10 +454,16 @@ def with_context(message: Message | None = None) -> Message | None:
438454 except Exception as ex :
439455 logger .error ("Error when executing agent" , exc_info = ex )
440456 await task_updater .failed (task_updater .new_agent_message (parts = [Part (root = TextPart (text = str (ex )))]))
441- finally :
442- await self ._queue_manager .close (f"_event_{ task_updater .task_id } " )
443- await self ._queue_manager .close (f"_resume_{ task_updater .task_id } " )
457+ finally : # cleanup
444458 await cancel_task (cancellation_task )
459+ is_cancelling = bool (current_task .cancelling ())
460+ try :
461+ async with asyncio .timeout (10 ): # grace period to read all events from queue
462+ await close_queue (self ._queue_manager , f"_event_{ context .task_id } " , immediate = is_cancelling )
463+ await close_queue (self ._queue_manager , f"_resume_{ context .task_id } " , immediate = is_cancelling )
464+ except (TimeoutError , CancelledError ):
465+ await close_queue (self ._queue_manager , f"_event_{ context .task_id } " , immediate = True )
466+ await close_queue (self ._queue_manager , f"_resume_{ context .task_id } " , immediate = True )
445467
446468 async def execute (self , context : RequestContext , event_queue : EventQueue ) -> None :
447469 assert context .message # this is only executed in the context of SendMessage request
@@ -472,10 +494,12 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non
472494 resume_queue = resume_queue ,
473495 )
474496
475- self ._running_tasks [context .task_id ] = asyncio .create_task (run_generator )
476- self ._running_tasks [context .task_id ].add_done_callback (
477- lambda _ : self ._running_tasks .pop (context .task_id ) # pyright: ignore [reportArgumentType]
497+ self ._running_tasks [context .task_id ] = RunningTask (
498+ task = asyncio .create_task (run_generator ), last_invocation = datetime .now ()
478499 )
500+ asyncio .create_task (
501+ self ._schedule_run_cleanup (task_id = context .task_id , task_timeout = self ._task_timeout )
502+ ).add_done_callback (lambda _ : ...)
479503
480504 while True :
481505 # Forward messages to local event queue
@@ -490,18 +514,39 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non
490514 # When a streaming request is canceled, this executor is canceled first meaning that "cancellation" event
491515 # passed from the agent's long_running_event_queue is not forwarded. Instead of shielding this function,
492516 # we report the cancellation explicitly
517+ await self ._cancel_task (context .task_id )
493518 local_updater = TaskUpdater (event_queue , task_id = context .task_id , context_id = context .context_id )
494519 await local_updater .cancel ()
495520 except Exception as ex :
496521 logger .error ("Error executing agent" , exc_info = ex )
497522 local_updater = TaskUpdater (event_queue , task_id = context .task_id , context_id = context .context_id )
498523 await local_updater .failed (local_updater .new_agent_message (parts = [Part (root = TextPart (text = str (ex )))]))
499524
async def _cancel_task(self, task_id: str):
    """Request cancellation of ``task_id`` through its cancel queue, if one exists.

    Does nothing when no cancel queue is registered for the task.
    """
    cancel_queue = self._cancel_queues.get(task_id)
    if not cancel_queue:
        return
    # The enqueued message acts as the cancellation signal for the task.
    cancel_signal = create_text_message_object(content="canceled")
    await cancel_queue.enqueue_event(cancel_signal)
528+
async def _schedule_run_cleanup(self, task_id: str, task_timeout: timedelta, poll_interval: float = 5):
    """Watch a registered task and drop its registry entry once it is done or stale.

    The watcher polls until the task finishes. If the task is still running
    and its ``last_invocation`` is older than ``task_timeout``, cancellation
    is requested through ``_cancel_task`` and the watcher stops.

    Args:
        task_id: Key of the task in ``self._running_tasks``.
        task_timeout: Maximum allowed time since the task's last invocation.
        poll_interval: Seconds to sleep between checks.
    """
    entry = self._running_tasks.get(task_id)
    if entry is None:
        # Nothing is registered under this id (e.g. it was already cleaned up),
        # so there is nothing to watch or to remove.
        return

    try:
        while not entry["task"].done():
            await asyncio.sleep(poll_interval)
            # ``last_invocation`` is compared with naive ``datetime.now()``,
            # so it must be a naive timestamp as well.
            if not entry["task"].done() and entry["last_invocation"] + task_timeout < datetime.now():
                # Task might be stuck waiting for queue events to be processed
                logger.warning(f"Task {task_id} did not finish in {task_timeout}")
                await self._cancel_task(task_id)
                break
    except Exception as ex:
        logger.error("Error when cleaning up task", exc_info=ex)
    finally:
        # Remove only the entry this watcher was started for: the id may have
        # been removed already, or re-registered with a newer entry while we
        # were sleeping, and that newer entry must stay in the registry.
        if self._running_tasks.get(task_id) is entry:
            del self._running_tasks[task_id]
545+
async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
    """Cancel the task named by ``context`` and report the cancellation.

    Raises:
        ValueError: If the context has no task id or no context id.
    """
    if not (context.task_id and context.context_id):
        raise ValueError("Task ID and context ID must be set to cancel a task")

    try:
        # Signal the running task, if any, through its cancel queue.
        await self._cancel_task(task_id=context.task_id)
    finally:
        # The cancelled status is always published, even if signalling failed.
        updater = TaskUpdater(event_queue, task_id=context.task_id, context_id=context.context_id)
        await updater.cancel()
0 commit comments