|
6 | 6 |
|
7 | 7 | import asyncio |
8 | 8 | import urllib.parse |
9 | | -from collections.abc import AsyncIterator |
10 | | -from contextlib import asynccontextmanager |
| 9 | +from collections.abc import AsyncIterator, Awaitable, Callable |
11 | 10 | from datetime import datetime, timedelta |
12 | 11 | from typing import Any, Final |
13 | 12 |
|
@@ -77,29 +76,35 @@ def empty_context() -> TaskContext: |
77 | 76 | return {} |
78 | 77 |
|
79 | 78 |
|
80 | | -@asynccontextmanager |
81 | | -async def get_tasks_manager( |
82 | | - redis_settings: RedisSettings, |
83 | | -) -> AsyncIterator[TasksManager]: |
84 | | - tasks_manager = TasksManager( |
85 | | - stale_task_check_interval=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S), |
86 | | - stale_task_detect_timeout=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S), |
87 | | - redis_settings=redis_settings, |
88 | | - namespace="test", |
89 | | - ) |
90 | | - await tasks_manager.setup() |
| 79 | +@pytest.fixture |
| 80 | +async def get_tasks_manager() -> ( |
| 81 | + AsyncIterator[Callable[[RedisSettings], Awaitable[TasksManager]]] |
| 82 | +): |
| 83 | + managers: list[TasksManager] = [] |
| 84 | + |
| 85 | + async def _(redis_settings: RedisSettings) -> TasksManager: |
| 86 | + tasks_manager = TasksManager( |
| 87 | + stale_task_check_interval=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S), |
| 88 | + stale_task_detect_timeout=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S), |
| 89 | + redis_settings=redis_settings, |
| 90 | + namespace="test", |
| 91 | + ) |
| 92 | + await tasks_manager.setup() |
| 93 | + managers.append(tasks_manager) |
| 94 | + return tasks_manager |
91 | 95 |
|
92 | | - yield tasks_manager |
| 96 | + yield _ |
93 | 97 |
|
94 | | - await tasks_manager.teardown() |
| 98 | + for manager in managers: |
| 99 | + await manager.teardown() |
95 | 100 |
|
96 | 101 |
|
97 | 102 | @pytest.fixture |
98 | 103 | async def tasks_manager( |
99 | 104 | use_in_memory_redis: RedisSettings, |
100 | | -) -> AsyncIterator[TasksManager]: |
101 | | - async with get_tasks_manager(use_in_memory_redis) as manager: |
102 | | - yield manager |
| 105 | + get_tasks_manager: Callable[[RedisSettings], Awaitable[TasksManager]], |
| 106 | +) -> TasksManager: |
| 107 | + return await get_tasks_manager(use_in_memory_redis) |
103 | 108 |
|
104 | 109 |
|
105 | 110 | @pytest.mark.parametrize("check_task_presence_before", [True, False]) |
@@ -325,34 +330,36 @@ async def test_get_result_finished_with_error( |
325 | 330 |
|
326 | 331 | async def test_cancel_task_from_different_manager( |
327 | 332 | use_in_memory_redis: RedisSettings, |
| 333 | + get_tasks_manager: Callable[[RedisSettings], Awaitable[TasksManager]], |
328 | 334 | empty_context: TaskContext, |
329 | 335 | ): |
330 | | - async with get_tasks_manager(use_in_memory_redis) as manager_1, get_tasks_manager( |
331 | | - use_in_memory_redis |
332 | | - ) as manager_2, get_tasks_manager(use_in_memory_redis) as manager_3: |
333 | | - task_id = await lrt_api.start_task( |
334 | | - manager_1, |
335 | | - a_background_task.__name__, |
336 | | - raise_when_finished=False, |
337 | | - total_sleep=1, |
338 | | - task_context=empty_context, |
339 | | - ) |
| 336 | + manager_1 = await get_tasks_manager(use_in_memory_redis) |
| 337 | + manager_2 = await get_tasks_manager(use_in_memory_redis) |
| 338 | + manager_3 = await get_tasks_manager(use_in_memory_redis) |
| 339 | + |
| 340 | + task_id = await lrt_api.start_task( |
| 341 | + manager_1, |
| 342 | + a_background_task.__name__, |
| 343 | + raise_when_finished=False, |
| 344 | + total_sleep=1, |
| 345 | + task_context=empty_context, |
| 346 | + ) |
340 | 347 |
|
341 | | - # wati for task to complete |
342 | | - for manager in (manager_1, manager_2, manager_3): |
343 | | - status = await manager.get_task_status(task_id, empty_context) |
344 | | - assert status.done is False |
345 | | - |
346 | | - async for attempt in AsyncRetrying(**_RETRY_PARAMS): |
347 | | - with attempt: |
348 | | - for manager in (manager_1, manager_2, manager_3): |
349 | | - status = await manager.get_task_status(task_id, empty_context) |
350 | | - assert status.done is True |
351 | | - |
352 | | - # check all provide the same result |
353 | | - for manager in (manager_1, manager_2, manager_3): |
354 | | - task_result = await manager.get_task_result(task_id, empty_context) |
355 | | - assert task_result == 42 |
| 348 | + # wati for task to complete |
| 349 | + for manager in (manager_1, manager_2, manager_3): |
| 350 | + status = await manager.get_task_status(task_id, empty_context) |
| 351 | + assert status.done is False |
| 352 | + |
| 353 | + async for attempt in AsyncRetrying(**_RETRY_PARAMS): |
| 354 | + with attempt: |
| 355 | + for manager in (manager_1, manager_2, manager_3): |
| 356 | + status = await manager.get_task_status(task_id, empty_context) |
| 357 | + assert status.done is True |
| 358 | + |
| 359 | + # check all provide the same result |
| 360 | + for manager in (manager_1, manager_2, manager_3): |
| 361 | + task_result = await manager.get_task_result(task_id, empty_context) |
| 362 | + assert task_result == 42 |
356 | 363 |
|
357 | 364 |
|
358 | 365 | async def test_remove_task(tasks_manager: TasksManager, empty_context: TaskContext): |
|
0 commit comments