diff --git a/tests/server/tasks/test_task_updater.py b/tests/server/tasks/test_task_updater.py index 1a633e6a..e80328b2 100644 --- a/tests/server/tasks/test_task_updater.py +++ b/tests/server/tasks/test_task_updater.py @@ -535,3 +535,33 @@ async def test_concurrent_updates_race_condition(event_queue): assert len(successes) == 1 assert len(failures) == 1 assert event_queue.enqueue_event.call_count == 1 + + +@pytest.mark.asyncio +async def test_reject_concurrently_with_complete(event_queue): + """Test for race conditions when reject and complete are called concurrently.""" + task_updater = TaskUpdater( + event_queue=event_queue, + task_id='concurrent-task', + context_id='concurrent-context', + ) + + tasks = [ + task_updater.reject(), + task_updater.complete(), + ] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + successes = [r for r in results if not isinstance(r, Exception)] + failures = [r for r in results if isinstance(r, RuntimeError)] + + assert len(successes) == 1 + assert len(failures) == 1 + + assert event_queue.enqueue_event.call_count == 1 + + event = event_queue.enqueue_event.call_args[0][0] + assert isinstance(event, TaskStatusUpdateEvent) + assert event.final is True + assert event.status.state in [TaskState.rejected, TaskState.completed]