diff --git a/src/a2a/server/tasks/task_updater.py b/src/a2a/server/tasks/task_updater.py index 6f8aca73..169f3b97 100644 --- a/src/a2a/server/tasks/task_updater.py +++ b/src/a2a/server/tasks/task_updater.py @@ -134,6 +134,30 @@ async def start_work(self, message: Message | None = None) -> None: message=message, ) + async def cancel(self, message: Message | None = None) -> None: + """Marks the task as cancelled and publishes a finalstatus update.""" + await self.update_status( + TaskState.canceled, message=message, final=True + ) + + async def requires_input( + self, message: Message | None = None, final: bool = False + ) -> None: + """Marks the task as input required and publishes a status update.""" + await self.update_status( + TaskState.input_required, + message=message, + final=final, + ) + + async def requires_auth( + self, message: Message | None = None, final: bool = False + ) -> None: + """Marks the task as auth required and publishes a status update.""" + await self.update_status( + TaskState.auth_required, message=message, final=final + ) + def new_agent_message( self, parts: list[Part], diff --git a/tests/server/tasks/test_task_updater.py b/tests/server/tasks/test_task_updater.py index 8b105b7b..1a23baf9 100644 --- a/tests/server/tasks/test_task_updater.py +++ b/tests/server/tasks/test_task_updater.py @@ -324,3 +324,151 @@ async def test_reject_with_message(task_updater, event_queue, sample_message): assert event.status.state == TaskState.rejected assert event.final is True assert event.status.message == sample_message + + +@pytest.mark.asyncio +async def test_requires_input_without_message(task_updater, event_queue): + """Test marking a task as input required without a message.""" + await task_updater.requires_input() + + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.input_required + assert event.final is False + assert event.status.message is None + + +@pytest.mark.asyncio +async def test_requires_input_with_message( + task_updater, event_queue, sample_message +): + """Test marking a task as input required with a message.""" + await task_updater.requires_input(message=sample_message) + + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.input_required + assert event.final is False + assert event.status.message == sample_message + + +@pytest.mark.asyncio +async def test_requires_input_final_true(task_updater, event_queue): + """Test marking a task as input required with final=True.""" + await task_updater.requires_input(final=True) + + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.input_required + assert event.final is True + assert event.status.message is None + + +@pytest.mark.asyncio +async def test_requires_input_with_message_and_final( + task_updater, event_queue, sample_message +): + """Test marking a task as input required with message and final=True.""" + await task_updater.requires_input(message=sample_message, final=True) + + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.input_required + assert event.final is True + assert event.status.message == sample_message + + +@pytest.mark.asyncio +async def test_requires_auth_without_message(task_updater, event_queue): + """Test marking a task as auth required without a message.""" + await task_updater.requires_auth() + + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.auth_required + assert event.final is False + assert event.status.message is None + + +@pytest.mark.asyncio +async def test_requires_auth_with_message( + task_updater, event_queue, sample_message +): + """Test marking a task as auth required with a message.""" + await task_updater.requires_auth(message=sample_message) + + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.auth_required + assert event.final is False + assert event.status.message == sample_message + + +@pytest.mark.asyncio +async def test_requires_auth_final_true(task_updater, event_queue): + """Test marking a task as auth required with final=True.""" + await task_updater.requires_auth(final=True) + + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.auth_required + assert event.final is True + assert event.status.message is None + + +@pytest.mark.asyncio +async def test_requires_auth_with_message_and_final( + task_updater, event_queue, sample_message +): + """Test marking a task as auth required with message and final=True.""" + await task_updater.requires_auth(message=sample_message, final=True) + + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.auth_required + assert event.final is True + assert event.status.message == sample_message + + +@pytest.mark.asyncio +async def test_cancel_without_message(task_updater, event_queue): + """Test marking a task as cancelled without a message.""" + await task_updater.cancel() + + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.canceled + assert event.final is True + assert event.status.message is None + + +@pytest.mark.asyncio +async def test_cancel_with_message(task_updater, event_queue, sample_message): + """Test marking a task as cancelled with a message.""" + await task_updater.cancel(message=sample_message) + + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.canceled + assert event.final is True + assert event.status.message == sample_message