@@ -263,6 +263,56 @@ async def test_on_cancel_task_cancels_running_agent():
263263 mock_agent_executor .cancel .assert_awaited_once ()
264264
265265
266+ @pytest .mark .asyncio
267+ async def test_on_cancel_task_completes_during_cancellation ():
268+ """Test on_cancel_task fails to cancel a task due to concurrent task completion."""
269+ task_id = 'running_agent_task_to_cancel'
270+ sample_task = create_sample_task (task_id = task_id )
271+ mock_task_store = AsyncMock (spec = TaskStore )
272+ mock_task_store .get .return_value = sample_task
273+
274+ mock_queue_manager = AsyncMock (spec = QueueManager )
275+ mock_event_queue = AsyncMock (spec = EventQueue )
276+ mock_queue_manager .tap .return_value = mock_event_queue
277+
278+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
279+
280+ # Mock ResultAggregator
281+ mock_result_aggregator_instance = AsyncMock (spec = ResultAggregator )
282+ mock_result_aggregator_instance .consume_all .return_value = (
283+ create_sample_task (task_id = task_id , status_state = TaskState .completed )
284+ )
285+
286+ request_handler = DefaultRequestHandler (
287+ agent_executor = mock_agent_executor ,
288+ task_store = mock_task_store ,
289+ queue_manager = mock_queue_manager ,
290+ )
291+
292+ # Simulate a running agent task
293+ mock_producer_task = AsyncMock (spec = asyncio .Task )
294+ request_handler ._running_agents [task_id ] = mock_producer_task
295+
296+ from a2a .utils .errors import (
297+ ServerError , # Local import
298+ TaskNotCancelableError , # Local import
299+ )
300+
301+ with patch (
302+ 'a2a.server.request_handlers.default_request_handler.ResultAggregator' ,
303+ return_value = mock_result_aggregator_instance ,
304+ ):
305+ params = TaskIdParams (id = task_id )
306+ with pytest .raises (ServerError ) as exc_info :
307+ await request_handler .on_cancel_task (
308+ params , create_server_call_context ()
309+ )
310+
311+ mock_producer_task .cancel .assert_called_once ()
312+ mock_agent_executor .cancel .assert_awaited_once ()
313+ assert isinstance (exc_info .value .error , TaskNotCancelableError )
314+
315+
266316@pytest .mark .asyncio
267317async def test_on_cancel_task_invalid_result_type ():
268318 """Test on_cancel_task when result_aggregator returns a Message instead of a Task."""
0 commit comments