Skip to content

Commit 129db8b

Browse files
committed
Fixes for types and random issues
Change-Id: I5acd71e8903825bd31b9d483471b4a3cf0f7e340
1 parent 2ce2bdb commit 129db8b

File tree

5 files changed

+29
-31
lines changed

5 files changed

+29
-31
lines changed

examples/helloworld/agent_executor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ def __init__(self):
2424
@override
2525
async def execute(
2626
self,
27-
request: RequestContext,
27+
context: RequestContext,
2828
event_queue: EventQueue,
2929
) -> None:
3030
result = await self.agent.invoke()
3131
event_queue.enqueue_event(new_agent_text_message(result))
3232

3333
@override
3434
async def cancel(
35-
self, request: RequestContext, event_queue: EventQueue
36-
) -> Task | None:
35+
self, context: RequestContext, event_queue: EventQueue
36+
) -> None:
3737
raise Exception('cancel not supported')

examples/langgraph/agent_executor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from a2a.server.agent_execution import AgentExecutor, RequestContext
55
from a2a.server.events.event_queue import EventQueue
66
from a2a.types import (
7-
Task,
87
TaskArtifactUpdateEvent,
98
TaskState,
109
TaskStatus,
@@ -93,6 +92,6 @@ async def execute(
9392

9493
@override
9594
async def cancel(
96-
self, request: RequestContext, event_queue: EventQueue
97-
) -> Task | None:
95+
self, context: RequestContext, event_queue: EventQueue
96+
) -> None:
9897
raise Exception('cancel not supported')

src/a2a/server/events/event_queue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,4 @@ def close(self):
6868
"""Closes the queue for future push events."""
6969
self.queue.shutdown()
7070
for child in self._children:
71-
child.shutdown()
71+
child.close()

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextlib
23
import logging
34

45
from collections.abc import AsyncGenerator
@@ -39,11 +40,11 @@ def __init__(
3940
self,
4041
agent_executor: AgentExecutor,
4142
task_store: TaskStore,
42-
queue_manager: QueueManager = InMemoryQueueManager(),
43+
queue_manager: QueueManager | None = None,
4344
) -> None:
4445
self.agent_executor = agent_executor
4546
self.task_store = task_store
46-
self._queue_manager = queue_manager
47+
self._queue_manager = queue_manager or InMemoryQueueManager()
4748

4849
async def on_get_task(self, params: TaskQueryParams) -> Task | None:
4950
"""Default handler for 'tasks/get'."""
@@ -65,10 +66,11 @@ async def on_cancel_task(self, params: TaskIdParams) -> Task | None:
6566
initial_message=None,
6667
)
6768
result_aggregator = ResultAggregator(task_manager)
68-
try:
69-
queue = await self._queue_manager.tap(task.id)
70-
except:
69+
70+
queue = await self._queue_manager.tap(task.id)
71+
if not queue:
7172
queue = EventQueue()
73+
7274
await self.agent_executor.cancel(
7375
RequestContext(
7476
None,
@@ -134,11 +136,9 @@ async def on_message_send(
134136
return result
135137
finally:
136138
await producer_task
137-
if task:
138-
try:
139-
await self._queue_manager.close(task.id)
140-
except NoTaskQueue:
141-
pass
139+
if task:
140+
with contextlib.suppress(NoTaskQueue):
141+
await self._queue_manager.close(task.id)
142142

143143
async def on_message_send_stream(
144144
self, params: MessageSendParams
@@ -182,13 +182,12 @@ async def on_message_send_stream(
182182
logging.info(
183183
'Multiple Task objects created in event stream.'
184184
)
185-
yield event
185+
yield event
186186
finally:
187187
await producer_task
188-
try:
189-
await self._queue_manager.close(task_id)
190-
except NoTaskQueue:
191-
pass
188+
if task_id:
189+
with contextlib.suppress(NoTaskQueue):
190+
await self._queue_manager.close(task_id)
192191

193192
async def on_set_task_push_notification_config(
194193
self, params: TaskPushNotificationConfig
@@ -209,7 +208,6 @@ async def on_resubscribe_to_task(
209208
task: Task | None = await self.task_store.get(params.id)
210209
if not task:
211210
raise ServerError(error=TaskNotFoundError())
212-
return
213211

214212
task_manager = TaskManager(
215213
task_id=task.id,
@@ -220,10 +218,9 @@ async def on_resubscribe_to_task(
220218

221219
result_aggregator = ResultAggregator(task_manager)
222220

223-
try:
224-
queue = await self._queue_manager.tap(task.id)
225-
except NoTaskQueue as e:
226-
raise ServerError(error=TaskNotFoundError()) from e
221+
queue = await self._queue_manager.tap(task.id)
222+
if not queue:
223+
raise ServerError(error=TaskNotFoundError())
227224

228225
consumer = EventConsumer(queue)
229226
async for event in result_aggregator.consume_and_emit(consumer):

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
DefaultRequestHandler,
99
JSONRPCHandler,
1010
)
11+
from a2a.server.events import (
12+
QueueManager,
13+
)
1114
from a2a.server.tasks import TaskStore
1215
from a2a.types import (
1316
GetTaskRequest,
@@ -375,13 +378,12 @@ async def test_on_resubscribe_existing_task_success(
375378
) -> None:
376379
mock_agent_executor = AsyncMock(spec=AgentExecutor)
377380
mock_task_store = AsyncMock(spec=TaskStore)
378-
mock_queue = AsyncMock(spec=EventQueue)
381+
mock_queue_manager = AsyncMock(spec=QueueManager)
379382
request_handler = DefaultRequestHandler(
380-
mock_agent_executor, mock_task_store
383+
mock_agent_executor, mock_task_store, mock_queue_manager
381384
)
382385
handler = JSONRPCHandler(None, request_handler)
383386
mock_task = Task(**MINIMAL_TASK, history=[])
384-
request_handler._task_queue[mock_task.id] = mock_queue
385387
events: list[Any] = [
386388
TaskArtifactUpdateEvent(
387389
taskId='task_123',
@@ -407,7 +409,7 @@ async def streaming_coro():
407409
return_value=streaming_coro(),
408410
):
409411
mock_task_store.get.return_value = mock_task
410-
mock_queue.tap.return_value = EventQueue()
412+
mock_queue_manager.tap.return_value = EventQueue()
411413
request = TaskResubscriptionRequest(
412414
id='1', params=TaskIdParams(id=mock_task.id)
413415
)

0 commit comments

Comments
 (0)