Skip to content

Commit 48a0fd7

Browse files
committed
handle concurrent task completion during cancellation
1 parent 9da9ecc commit 48a0fd7

File tree

2 files changed

+63
-6
lines changed

2 files changed

+63
-6
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,14 +180,21 @@ async def on_cancel_task(
180180

181181
consumer = EventConsumer(queue)
182182
result = await result_aggregator.consume_all(consumer)
183-
if isinstance(result, Task):
184-
return result
183+
if not isinstance(result, Task):
184+
raise ServerError(
185+
error=InternalError(
186+
message='Agent did not return valid response for cancel'
187+
)
188+
)
185189

186-
raise ServerError(
187-
error=InternalError(
188-
message='Agent did not return valid response for cancel'
190+
if result.status.state != TaskState.canceled:
191+
raise ServerError(
192+
error=TaskNotCancelableError(
193+
message=f'Task cannot be canceled - current state: {result.status.state}'
194+
)
189195
)
190-
)
196+
197+
return result
191198

192199
async def _run_event_stream(
193200
self, request: RequestContext, queue: EventQueue

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,56 @@ async def test_on_cancel_task_cancels_running_agent():
263263
mock_agent_executor.cancel.assert_awaited_once()
264264

265265

266+
@pytest.mark.asyncio
267+
async def test_on_cancel_task_completes_during_cancellation():
268+
"""Test on_cancel_task fails to cancel a task due to concurrent task completion."""
269+
task_id = 'running_agent_task_to_cancel'
270+
sample_task = create_sample_task(task_id=task_id)
271+
mock_task_store = AsyncMock(spec=TaskStore)
272+
mock_task_store.get.return_value = sample_task
273+
274+
mock_queue_manager = AsyncMock(spec=QueueManager)
275+
mock_event_queue = AsyncMock(spec=EventQueue)
276+
mock_queue_manager.tap.return_value = mock_event_queue
277+
278+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
279+
280+
# Mock ResultAggregator
281+
mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator)
282+
mock_result_aggregator_instance.consume_all.return_value = (
283+
create_sample_task(task_id=task_id, status_state=TaskState.completed)
284+
)
285+
286+
request_handler = DefaultRequestHandler(
287+
agent_executor=mock_agent_executor,
288+
task_store=mock_task_store,
289+
queue_manager=mock_queue_manager,
290+
)
291+
292+
# Simulate a running agent task
293+
mock_producer_task = AsyncMock(spec=asyncio.Task)
294+
request_handler._running_agents[task_id] = mock_producer_task
295+
296+
from a2a.utils.errors import (
297+
ServerError, # Local import
298+
TaskNotCancelableError, # Local import
299+
)
300+
301+
with patch(
302+
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
303+
return_value=mock_result_aggregator_instance,
304+
):
305+
params = TaskIdParams(id=task_id)
306+
with pytest.raises(ServerError) as exc_info:
307+
await request_handler.on_cancel_task(
308+
params, create_server_call_context()
309+
)
310+
311+
mock_producer_task.cancel.assert_called_once()
312+
mock_agent_executor.cancel.assert_awaited_once()
313+
assert isinstance(exc_info.value.error, TaskNotCancelableError)
314+
315+
266316
@pytest.mark.asyncio
267317
async def test_on_cancel_task_invalid_result_type():
268318
"""Test on_cancel_task when result_aggregator returns a Message instead of a Task."""

0 commit comments

Comments
 (0)