Skip to content

Commit bfee5d4

Browse files
authored
Merge branch 'main' into fix_task_execution_cancelled
2 parents 1acb483 + 42ff0d4 commit bfee5d4

File tree

4 files changed

+66
-7
lines changed

4 files changed

+66
-7
lines changed

src/a2a/server/events/event_consumer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ async def consume_all(self) -> AsyncGenerator[Event]:
133133
# continue polling until there is a final event
134134
continue
135135
except asyncio.TimeoutError: # pyright: ignore [reportUnusedExcept]
136-
# This class was made an alias of build-in TimeoutError after 3.11
136+
# This class was made an alias of built-in TimeoutError after 3.11
137137
continue
138138
except (QueueClosed, asyncio.QueueEmpty):
139139
# Confirm that the queue is closed, e.g. we aren't on

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,14 +180,21 @@ async def on_cancel_task(
180180

181181
consumer = EventConsumer(queue)
182182
result = await result_aggregator.consume_all(consumer)
183-
if isinstance(result, Task):
184-
return result
183+
if not isinstance(result, Task):
184+
raise ServerError(
185+
error=InternalError(
186+
message='Agent did not return valid response for cancel'
187+
)
188+
)
185189

186-
raise ServerError(
187-
error=InternalError(
188-
message='Agent did not return valid response for cancel'
190+
if result.status.state != TaskState.canceled:
191+
raise ServerError(
192+
error=TaskNotCancelableError(
193+
message=f'Task cannot be canceled - current state: {result.status.state}'
194+
)
189195
)
190-
)
196+
197+
return result
191198

192199
async def _run_event_stream(
193200
self, request: RequestContext, queue: EventQueue

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,56 @@ async def test_on_cancel_task_cancels_running_agent():
263263
mock_agent_executor.cancel.assert_awaited_once()
264264

265265

266+
@pytest.mark.asyncio
267+
async def test_on_cancel_task_completes_during_cancellation():
268+
"""Test on_cancel_task fails to cancel a task due to concurrent task completion."""
269+
task_id = 'running_agent_task_to_cancel'
270+
sample_task = create_sample_task(task_id=task_id)
271+
mock_task_store = AsyncMock(spec=TaskStore)
272+
mock_task_store.get.return_value = sample_task
273+
274+
mock_queue_manager = AsyncMock(spec=QueueManager)
275+
mock_event_queue = AsyncMock(spec=EventQueue)
276+
mock_queue_manager.tap.return_value = mock_event_queue
277+
278+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
279+
280+
# Mock ResultAggregator
281+
mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator)
282+
mock_result_aggregator_instance.consume_all.return_value = (
283+
create_sample_task(task_id=task_id, status_state=TaskState.completed)
284+
)
285+
286+
request_handler = DefaultRequestHandler(
287+
agent_executor=mock_agent_executor,
288+
task_store=mock_task_store,
289+
queue_manager=mock_queue_manager,
290+
)
291+
292+
# Simulate a running agent task
293+
mock_producer_task = AsyncMock(spec=asyncio.Task)
294+
request_handler._running_agents[task_id] = mock_producer_task
295+
296+
from a2a.utils.errors import (
297+
ServerError, # Local import
298+
TaskNotCancelableError, # Local import
299+
)
300+
301+
with patch(
302+
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
303+
return_value=mock_result_aggregator_instance,
304+
):
305+
params = TaskIdParams(id=task_id)
306+
with pytest.raises(ServerError) as exc_info:
307+
await request_handler.on_cancel_task(
308+
params, create_server_call_context()
309+
)
310+
311+
mock_producer_task.cancel.assert_called_once()
312+
mock_agent_executor.cancel.assert_awaited_once()
313+
assert isinstance(exc_info.value.error, TaskNotCancelableError)
314+
315+
266316
@pytest.mark.asyncio
267317
async def test_on_cancel_task_invalid_result_type():
268318
"""Test on_cancel_task when result_aggregator returns a Message instead of a Task."""

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ async def test_on_cancel_task_success(self) -> None:
150150
call_context = ServerCallContext(state={'foo': 'bar'})
151151

152152
async def streaming_coro():
153+
mock_task.status.state = TaskState.canceled
153154
yield mock_task
154155

155156
with patch(
@@ -161,6 +162,7 @@ async def streaming_coro():
161162
assert mock_agent_executor.cancel.call_count == 1
162163
self.assertIsInstance(response.root, CancelTaskSuccessResponse)
163164
assert response.root.result == mock_task # type: ignore
165+
assert response.root.result.status.state == TaskState.canceled
164166
mock_agent_executor.cancel.assert_called_once()
165167

166168
async def test_on_cancel_task_not_supported(self) -> None:

0 commit comments

Comments
 (0)