diff --git a/packages/celery-library/tests/unit/test_tasks.py b/packages/celery-library/tests/unit/test_tasks.py index 80a474ad04d8..757cbd3aaff2 100644 --- a/packages/celery-library/tests/unit/test_tasks.py +++ b/packages/celery-library/tests/unit/test_tasks.py @@ -210,6 +210,7 @@ async def test_listing_task_uuids_contains_submitted_task( async def test_filtering_listing_tasks( celery_task_manager: CeleryTaskManager, + with_celery_worker: WorkController, ): class MyFilter(TaskFilter): user_id: int @@ -218,39 +219,46 @@ class MyFilter(TaskFilter): user_id = 42 expected_task_uuids: set[TaskUUID] = set() + all_tasks: list[tuple[TaskUUID, MyFilter]] = [] + + try: + for _ in range(5): + task_filter = MyFilter( + user_id=user_id, + product_name=_faker.word(), + client_app=_faker.word(), + ) + task_uuid = await celery_task_manager.submit_task( + TaskMetadata( + name=dreamer_task.__name__, + ), + task_filter=task_filter, + ) + expected_task_uuids.add(task_uuid) + all_tasks.append((task_uuid, task_filter)) + + for _ in range(3): + task_filter = MyFilter( + user_id=_faker.pyint(min_value=100, max_value=200), + product_name=_faker.word(), + client_app=_faker.word(), + ) + task_uuid = await celery_task_manager.submit_task( + TaskMetadata( + name=dreamer_task.__name__, + ), + task_filter=task_filter, + ) + all_tasks.append((task_uuid, task_filter)) - for _ in range(5): - task_filter = MyFilter( + search_filter = MyFilter( user_id=user_id, - product_name=_faker.word(), - client_app=_faker.word(), - ) - task_uuid = await celery_task_manager.submit_task( - TaskMetadata( - name=dreamer_task.__name__, - ), - task_filter=task_filter, + product_name=Wildcard(), + client_app=Wildcard(), ) - expected_task_uuids.add(task_uuid) - - for _ in range(3): - task_filter = MyFilter( - user_id=_faker.pyint(min_value=100, max_value=200), - product_name=_faker.word(), - client_app=_faker.word(), - ) - await celery_task_manager.submit_task( - TaskMetadata( - name=dreamer_task.__name__, - ), - task_filter=task_filter, - ) - - search_filter = MyFilter( - user_id=user_id, - product_name=Wildcard(), - client_app=Wildcard(), - ) - tasks = await celery_task_manager.list_tasks(search_filter) - assert expected_task_uuids == {task.uuid for task in tasks} - await asyncio.sleep(5 * 60) + tasks = await celery_task_manager.list_tasks(search_filter) + assert expected_task_uuids == {task.uuid for task in tasks} + finally: + # clean up all tasks. this should ideally be done in the fixture + for task_uuid, task_filter in all_tasks: + await celery_task_manager.cancel_task(task_filter, task_uuid)