Skip to content

Commit c436f76

Browse files
authored
fix(sdk): agent block due to incorrect queue close (#1260)
Signed-off-by: Radek Ježek <radek.jezek@ibm.com>
1 parent 9e0e9a7 commit c436f76

File tree

11 files changed, with 109 additions and 44 deletions.

agents/chat/uv.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

agents/form/uv.lock

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

agents/rag/uv.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

apps/beeai-cli/uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

apps/beeai-sdk/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ readme = "README.md"
66
authors = [{ name = "IBM Corp." }]
77
requires-python = ">=3.11"
88
dependencies = [
9-
"a2a-sdk==0.3.5",
9+
"a2a-sdk==0.3.7",
1010
"objprint>=0.3.0",
1111
"uvicorn>=0.35.0",
1212
"asyncclick>=8.1.8",

apps/beeai-sdk/src/beeai_sdk/server/agent.py

Lines changed: 57 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,11 @@
66
from asyncio import CancelledError
77
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Generator
88
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
9-
from typing import NamedTuple, TypeAlias, cast
9+
from datetime import datetime, timedelta
10+
from typing import NamedTuple, TypeAlias, TypedDict, cast
1011

1112
import janus
13+
from a2a.client import create_text_message_object
1214
from a2a.server.agent_execution import AgentExecutor, RequestContext
1315
from a2a.server.events import EventQueue, QueueManager
1416
from a2a.server.tasks import TaskUpdater
@@ -40,7 +42,7 @@
4042
from beeai_sdk.server.dependencies import extract_dependencies
4143
from beeai_sdk.server.logging import logger
4244
from beeai_sdk.server.store.context_store import ContextStore
43-
from beeai_sdk.server.utils import cancel_task
45+
from beeai_sdk.server.utils import cancel_task, close_queue
4446

4547
AgentFunction: TypeAlias = Callable[[], AsyncGenerator[RunYield, RunYieldResume]]
4648
AgentFunctionFactory: TypeAlias = Callable[
@@ -278,15 +280,25 @@ async def agent_generator():
278280
return decorator
279281

280282

283+
class RunningTask(TypedDict):
284+
task: asyncio.Task
285+
last_invocation: datetime
286+
287+
281288
class Executor(AgentExecutor):
282289
def __init__(
283-
self, execute_fn: AgentFunctionFactory, queue_manager: QueueManager, context_store: ContextStore
290+
self,
291+
execute_fn: AgentFunctionFactory,
292+
queue_manager: QueueManager,
293+
context_store: ContextStore,
294+
task_timeout: timedelta,
284295
) -> None:
285296
self._agent_executor_span = execute_fn
286297
self._queue_manager = queue_manager
287-
self._running_tasks: dict[str, asyncio.Task] = {}
298+
self._running_tasks: dict[str, RunningTask] = {}
288299
self._cancel_queues: dict[str, EventQueue] = {}
289300
self._context_store = context_store
301+
self._task_timeout = task_timeout
290302

291303
async def _watch_for_cancellation(self, task_id: str, task: asyncio.Task) -> None:
292304
cancel_queue = await self._queue_manager.create_or_tap(f"_cancel_{task_id}")
@@ -333,7 +345,11 @@ def with_context(message: Message | None = None) -> Message | None:
333345
value: RunYieldResume = None
334346
opened_artifacts: set[str] = set()
335347
while True:
348+
# update invocation time
349+
self._running_tasks[task_updater.task_id]["last_invocation"] = datetime.now()
350+
336351
yielded_value = await agent_generator_fn.asend(value)
352+
337353
match yielded_value:
338354
case str(text):
339355
await task_updater.update_status(
@@ -438,10 +454,16 @@ def with_context(message: Message | None = None) -> Message | None:
438454
except Exception as ex:
439455
logger.error("Error when executing agent", exc_info=ex)
440456
await task_updater.failed(task_updater.new_agent_message(parts=[Part(root=TextPart(text=str(ex)))]))
441-
finally:
442-
await self._queue_manager.close(f"_event_{task_updater.task_id}")
443-
await self._queue_manager.close(f"_resume_{task_updater.task_id}")
457+
finally: # cleanup
444458
await cancel_task(cancellation_task)
459+
is_cancelling = bool(current_task.cancelling())
460+
try:
461+
async with asyncio.timeout(10): # grace period to read all events from queue
462+
await close_queue(self._queue_manager, f"_event_{context.task_id}", immediate=is_cancelling)
463+
await close_queue(self._queue_manager, f"_resume_{context.task_id}", immediate=is_cancelling)
464+
except (TimeoutError, CancelledError):
465+
await close_queue(self._queue_manager, f"_event_{context.task_id}", immediate=True)
466+
await close_queue(self._queue_manager, f"_resume_{context.task_id}", immediate=True)
445467

446468
async def execute(self, context: RequestContext, event_queue: EventQueue) -> None:
447469
assert context.message # this is only executed in the context of SendMessage request
@@ -472,10 +494,12 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non
472494
resume_queue=resume_queue,
473495
)
474496

475-
self._running_tasks[context.task_id] = asyncio.create_task(run_generator)
476-
self._running_tasks[context.task_id].add_done_callback(
477-
lambda _: self._running_tasks.pop(context.task_id) # pyright: ignore [reportArgumentType]
497+
self._running_tasks[context.task_id] = RunningTask(
498+
task=asyncio.create_task(run_generator), last_invocation=datetime.now()
478499
)
500+
asyncio.create_task(
501+
self._schedule_run_cleanup(task_id=context.task_id, task_timeout=self._task_timeout)
502+
).add_done_callback(lambda _: ...)
479503

480504
while True:
481505
# Forward messages to local event queue
@@ -490,18 +514,39 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non
490514
# When a streaming request is canceled, this executor is canceled first meaning that "cancellation" event
491515
# passed from the agent's long_running_event_queue is not forwarded. Instead of shielding this function,
492516
# we report the cancellation explicitly
517+
await self._cancel_task(context.task_id)
493518
local_updater = TaskUpdater(event_queue, task_id=context.task_id, context_id=context.context_id)
494519
await local_updater.cancel()
495520
except Exception as ex:
496521
logger.error("Error executing agent", exc_info=ex)
497522
local_updater = TaskUpdater(event_queue, task_id=context.task_id, context_id=context.context_id)
498523
await local_updater.failed(local_updater.new_agent_message(parts=[Part(root=TextPart(text=str(ex)))]))
499524

525+
async def _cancel_task(self, task_id: str):
526+
if queue := self._cancel_queues.get(task_id):
527+
await queue.enqueue_event(create_text_message_object(content="canceled"))
528+
529+
async def _schedule_run_cleanup(self, task_id: str, task_timeout: timedelta):
530+
task = self._running_tasks.get(task_id)
531+
assert task
532+
533+
try:
534+
while not task["task"].done():
535+
await asyncio.sleep(5)
536+
if not task["task"].done() and task["last_invocation"] + task_timeout < datetime.now():
537+
# Task might be stuck waiting for queue events to be processed
538+
logger.warning(f"Task {task_id} did not finish in {task_timeout}")
539+
await self._cancel_task(task_id)
540+
break
541+
except Exception as ex:
542+
logger.error("Error when cleaning up task", exc_info=ex)
543+
finally:
544+
self._running_tasks.pop(task_id)
545+
500546
async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
501547
if not context.task_id or not context.context_id:
502548
raise ValueError("Task ID and context ID must be set to cancel a task")
503549
try:
504-
if context.current_task and (queue := self._cancel_queues.get(context.task_id)):
505-
await queue.enqueue_event(context.current_task)
550+
await self._cancel_task(task_id=context.task_id)
506551
finally:
507552
await TaskUpdater(event_queue, task_id=context.task_id, context_id=context.context_id).cancel()

apps/beeai-sdk/src/beeai_sdk/server/app.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44

5+
from datetime import timedelta
6+
57
from a2a.server.agent_execution import RequestContextBuilder
68
from a2a.server.apps.jsonrpc import A2AFastAPIApplication
79
from a2a.server.apps.rest import A2ARESTFastAPIApplication
@@ -29,13 +31,14 @@ def create_app(
2931
lifespan: Lifespan[AppType] | None = None,
3032
dependencies: list[Depends] | None = None, # pyright: ignore [reportGeneralTypeIssues]
3133
override_interfaces: bool = True,
34+
task_timeout: timedelta = timedelta(minutes=10),
3235
**kwargs,
3336
) -> FastAPI:
3437
queue_manager = queue_manager or InMemoryQueueManager()
3538
task_store = task_store or InMemoryTaskStore()
3639
context_store = context_store or InMemoryContextStore()
3740
http_handler = DefaultRequestHandler(
38-
agent_executor=Executor(agent.execute, queue_manager, context_store=context_store),
41+
agent_executor=Executor(agent.execute, queue_manager, context_store=context_store, task_timeout=task_timeout),
3942
task_store=task_store,
4043
queue_manager=queue_manager,
4144
push_config_store=push_config_store,

apps/beeai-sdk/src/beeai_sdk/server/server.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from collections.abc import AsyncGenerator, Awaitable, Callable
99
from configparser import RawConfigParser
1010
from contextlib import asynccontextmanager, nullcontext, suppress
11+
from datetime import timedelta
1112
from ssl import CERT_NONE
1213
from typing import IO, Any, Literal
1314
from urllib.parse import urljoin
@@ -65,6 +66,7 @@ async def serve(
6566
task_store: TaskStore | None = None,
6667
context_store: ContextStore | None = None,
6768
queue_manager: QueueManager | None = None,
69+
task_timeout: timedelta = timedelta(minutes=10),
6870
push_config_store: PushNotificationConfigStore | None = None,
6971
push_sender: PushNotificationSender | None = None,
7072
request_context_builder: RequestContextBuilder | None = None,
@@ -162,6 +164,7 @@ async def _lifespan_fn(app: FastAPI) -> AsyncGenerator[None, None]:
162164
queue_manager=queue_manager,
163165
push_config_store=push_config_store,
164166
push_sender=push_sender,
167+
task_timeout=task_timeout,
165168
request_context_builder=request_context_builder,
166169
)
167170

0 commit comments

Comments
 (0)