Skip to content

Commit bd9e25f

Browse files
authored
Merge branch 'master' into enh/sort-functions
2 parents 55a1b88 + edf0acc commit bd9e25f

File tree

14 files changed

+262
-83
lines changed

14 files changed

+262
-83
lines changed

packages/service-library/src/servicelib/aiohttp/long_running_tasks/_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from aiohttp import web
22

3-
from ...long_running_tasks.base_long_running_manager import BaseLongRunningManager
3+
from ...long_running_tasks.manager import LongRunningManager
44
from ...long_running_tasks.models import TaskContext
55
from ._constants import APP_LONG_RUNNING_MANAGER_KEY
66
from ._request import get_task_context
77

88

9-
class AiohttpLongRunningManager(BaseLongRunningManager):
9+
class AiohttpLongRunningManager(LongRunningManager):
1010

1111
@staticmethod
1212
def get_task_context(request: web.Request) -> TaskContext:

packages/service-library/src/servicelib/fastapi/long_running_tasks/_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from fastapi import Request
22

3-
from ...long_running_tasks.base_long_running_manager import BaseLongRunningManager
3+
from ...long_running_tasks.manager import LongRunningManager
44
from ...long_running_tasks.models import TaskContext
55

66

7-
class FastAPILongRunningManager(BaseLongRunningManager):
7+
class FastAPILongRunningManager(LongRunningManager):
88
@staticmethod
99
def get_task_context(request: Request) -> TaskContext:
1010
_ = request

packages/service-library/src/servicelib/long_running_tasks/_rpc_server.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@
1515
_logger = logging.getLogger(__name__)
1616

1717
if TYPE_CHECKING:
18-
from .base_long_running_manager import BaseLongRunningManager
18+
from .manager import LongRunningManager
1919

2020

2121
router = RPCRouter()
2222

2323

2424
@router.expose(reraise_if_error_type=(BaseLongRunningError,))
2525
async def start_task(
26-
long_running_manager: "BaseLongRunningManager",
26+
long_running_manager: "LongRunningManager",
2727
*,
2828
registered_task_name: RegisteredTaskName,
2929
unique: bool = False,
@@ -44,7 +44,7 @@ async def start_task(
4444

4545
@router.expose(reraise_if_error_type=(BaseLongRunningError,))
4646
async def list_tasks(
47-
long_running_manager: "BaseLongRunningManager", *, task_context: TaskContext
47+
long_running_manager: "LongRunningManager", *, task_context: TaskContext
4848
) -> list[TaskBase]:
4949
return await long_running_manager.tasks_manager.list_tasks(
5050
with_task_context=task_context
@@ -53,7 +53,7 @@ async def list_tasks(
5353

5454
@router.expose(reraise_if_error_type=(BaseLongRunningError,))
5555
async def get_task_status(
56-
long_running_manager: "BaseLongRunningManager",
56+
long_running_manager: "LongRunningManager",
5757
*,
5858
task_context: TaskContext,
5959
task_id: TaskId,
@@ -65,7 +65,7 @@ async def get_task_status(
6565

6666
@router.expose(reraise_if_error_type=(BaseLongRunningError, RPCTransferrableTaskError))
6767
async def get_task_result(
68-
long_running_manager: "BaseLongRunningManager",
68+
long_running_manager: "LongRunningManager",
6969
*,
7070
task_context: TaskContext,
7171
task_id: TaskId,
@@ -94,7 +94,7 @@ async def get_task_result(
9494

9595
@router.expose(reraise_if_error_type=(BaseLongRunningError,))
9696
async def remove_task(
97-
long_running_manager: "BaseLongRunningManager",
97+
long_running_manager: "LongRunningManager",
9898
*,
9999
task_context: TaskContext,
100100
task_id: TaskId,
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import logging
2+
3+
import redis.asyncio as aioredis
4+
from settings_library.redis import RedisDatabase, RedisSettings
5+
6+
from ..logging_utils import log_context
7+
from ..redis._client import RedisClientSDK
8+
from .models import LRTNamespace
9+
10+
_logger = logging.getLogger(__name__)
11+
12+
13+
class LongRunningClientHelper:
14+
def __init__(self, redis_settings: RedisSettings):
15+
self.redis_settings = redis_settings
16+
17+
self._client: RedisClientSDK | None = None
18+
19+
async def setup(self) -> None:
20+
self._client = RedisClientSDK(
21+
self.redis_settings.build_redis_dsn(RedisDatabase.LONG_RUNNING_TASKS),
22+
client_name="long_running_tasks_cleanup_client",
23+
)
24+
await self._client.setup()
25+
26+
async def shutdown(self) -> None:
27+
if self._client:
28+
await self._client.shutdown()
29+
30+
@property
31+
def _redis(self) -> aioredis.Redis:
32+
assert self._client # nosec
33+
return self._client.redis
34+
35+
async def cleanup(self, lrt_namespace: LRTNamespace) -> None:
36+
"""removes Redis keys associated to the LRTNamespace if they exist"""
37+
keys_to_remove: list[str] = [
38+
x async for x in self._redis.scan_iter(f"{lrt_namespace}*")
39+
]
40+
with log_context(
41+
_logger,
42+
logging.DEBUG,
43+
msg=f"Removing {keys_to_remove=} from Redis for {lrt_namespace=}",
44+
):
45+
if len(keys_to_remove) > 0:
46+
await self._redis.delete(*keys_to_remove)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .task import TasksManager
1212

1313

14-
class BaseLongRunningManager(ABC):
14+
class LongRunningManager(ABC):
1515
"""
1616
Provides a commond inteface for aiohttp and fastapi services
1717
"""

packages/service-library/src/servicelib/long_running_tasks/task.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from common_library.async_tools import cancel_wait_task
1212
from models_library.api_schemas_long_running_tasks.base import TaskProgress
1313
from pydantic import NonNegativeFloat, PositiveFloat
14+
from servicelib.utils import limited_gather
1415
from settings_library.redis import RedisDatabase, RedisSettings
1516
from tenacity import (
1617
AsyncRetrying,
@@ -21,7 +22,7 @@
2122

2223
from ..background_task import create_periodic_task
2324
from ..logging_errors import create_troubleshootting_log_kwargs
24-
from ..logging_utils import log_context
25+
from ..logging_utils import log_catch, log_context
2526
from ..redis import RedisClientSDK, exclusive
2627
from ._redis_store import RedisStore
2728
from ._serialization import dumps
@@ -50,6 +51,7 @@
5051
_STATUS_UPDATE_CHECK_INTERNAL: Final[datetime.timedelta] = datetime.timedelta(seconds=1)
5152
_MAX_EXCLUSIVE_TASK_CANCEL_TIMEOUT: Final[NonNegativeFloat] = 5
5253
_TASK_REMOVAL_MAX_WAIT: Final[NonNegativeFloat] = 60
54+
_PARALLEL_TASKS_CANCELLATION: Final[int] = 5
5355

5456
AllowedErrrors: TypeAlias = tuple[type[BaseException], ...]
5557

@@ -205,34 +207,40 @@ async def setup(self) -> None:
205207
await self._started_event_task_tasks_monitor.wait()
206208

207209
async def teardown(self) -> None:
208-
# ensure all created tasks are cancelled
209-
for tracked_task in await self._tasks_data.list_tasks_data():
210-
with suppress(TaskNotFoundError):
210+
# stop cancelled_tasks_removal
211+
if self._task_cancelled_tasks_removal:
212+
await cancel_wait_task(self._task_cancelled_tasks_removal)
213+
214+
# stopping only tasks that are handled by this manager
215+
# otherwise it will cancel long running tasks that were running in diffierent processes
216+
async def _remove_local_task(task_data: TaskData) -> None:
217+
with log_catch(_logger, reraise=False):
211218
await self.remove_task(
212-
tracked_task.task_id,
213-
tracked_task.task_context,
214-
wait_for_removal=True,
219+
task_data.task_id,
220+
task_data.task_context,
221+
wait_for_removal=False,
215222
)
223+
await self._attempt_to_remove_local_task(task_data.task_id)
224+
225+
await limited_gather(
226+
*[
227+
_remove_local_task(tracked_task)
228+
for task_id in self._created_tasks
229+
if (tracked_task := await self._tasks_data.get_task_data(task_id))
230+
is not None
231+
],
232+
log=_logger,
233+
limit=_PARALLEL_TASKS_CANCELLATION,
234+
)
216235

217-
for task in self._created_tasks.values():
218-
_logger.warning(
219-
"Task %s was not completed before shutdown, cancelling it",
220-
task.get_name(),
221-
)
222-
await cancel_wait_task(task)
223-
224-
# stale_tasks_monitor
236+
# stop stale_tasks_monitor
225237
if self._task_stale_tasks_monitor:
226238
await cancel_wait_task(
227239
self._task_stale_tasks_monitor,
228240
max_delay=_MAX_EXCLUSIVE_TASK_CANCEL_TIMEOUT,
229241
)
230242

231-
# cancelled_tasks_removal
232-
if self._task_cancelled_tasks_removal:
233-
await cancel_wait_task(self._task_cancelled_tasks_removal)
234-
235-
# tasks_monitor
243+
# stop tasks_monitor
236244
if self._task_tasks_monitor:
237245
await cancel_wait_task(self._task_tasks_monitor)
238246

packages/service-library/tests/long_running_tasks/conftest.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from faker import Faker
1212
from pytest_mock import MockerFixture
1313
from servicelib.logging_utils import log_catch
14-
from servicelib.long_running_tasks.base_long_running_manager import (
15-
BaseLongRunningManager,
14+
from servicelib.long_running_tasks.manager import (
15+
LongRunningManager,
1616
)
1717
from servicelib.long_running_tasks.models import LRTNamespace, TaskContext
1818
from servicelib.long_running_tasks.task import TasksManager
@@ -24,7 +24,7 @@
2424
_logger = logging.getLogger(__name__)
2525

2626

27-
class _TestingLongRunningManager(BaseLongRunningManager):
27+
class _TestingLongRunningManager(LongRunningManager):
2828
@staticmethod
2929
def get_task_context(request) -> TaskContext:
3030
_ = request
@@ -37,16 +37,16 @@ async def get_long_running_manager(
3737
) -> AsyncIterator[
3838
Callable[
3939
[RedisSettings, RabbitSettings, LRTNamespace | None],
40-
Awaitable[BaseLongRunningManager],
40+
Awaitable[LongRunningManager],
4141
]
4242
]:
43-
managers: list[BaseLongRunningManager] = []
43+
managers: list[LongRunningManager] = []
4444

4545
async def _(
4646
redis_settings: RedisSettings,
4747
rabbit_settings: RabbitSettings,
4848
lrt_namespace: LRTNamespace | None,
49-
) -> BaseLongRunningManager:
49+
) -> LongRunningManager:
5050
manager = _TestingLongRunningManager(
5151
stale_task_check_interval=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S),
5252
stale_task_detect_timeout=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S),
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# pylint:disable=redefined-outer-name
2+
3+
from collections.abc import AsyncIterable, Callable
4+
from contextlib import AbstractAsyncContextManager
5+
6+
import pytest
7+
from pydantic import TypeAdapter
8+
from servicelib.long_running_tasks._redis_store import RedisStore
9+
from servicelib.long_running_tasks.long_running_client_helper import (
10+
LongRunningClientHelper,
11+
)
12+
from servicelib.long_running_tasks.models import LRTNamespace, TaskData
13+
from servicelib.redis._client import RedisClientSDK
14+
from settings_library.redis import RedisDatabase, RedisSettings
15+
16+
17+
@pytest.fixture
18+
def task_data() -> TaskData:
19+
return TypeAdapter(TaskData).validate_python(
20+
TaskData.model_json_schema()["examples"][0]
21+
)
22+
23+
24+
@pytest.fixture
25+
def lrt_namespace() -> LRTNamespace:
26+
return "TEST-NAMESPACE"
27+
28+
29+
@pytest.fixture
30+
async def store(
31+
use_in_memory_redis: RedisSettings,
32+
get_redis_client_sdk: Callable[
33+
[RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]
34+
],
35+
lrt_namespace: LRTNamespace,
36+
) -> AsyncIterable[RedisStore]:
37+
store = RedisStore(redis_settings=use_in_memory_redis, namespace=lrt_namespace)
38+
39+
await store.setup()
40+
yield store
41+
await store.shutdown()
42+
43+
# triggers cleanup of all redis data
44+
async with get_redis_client_sdk(RedisDatabase.LONG_RUNNING_TASKS):
45+
pass
46+
47+
48+
@pytest.fixture
49+
async def long_running_client_helper(
50+
use_in_memory_redis: RedisSettings,
51+
) -> AsyncIterable[LongRunningClientHelper]:
52+
helper = LongRunningClientHelper(redis_settings=use_in_memory_redis)
53+
54+
await helper.setup()
55+
yield helper
56+
await helper.shutdown()
57+
58+
59+
async def test_cleanup_namespace(
60+
store: RedisStore,
61+
task_data: TaskData,
62+
long_running_client_helper: LongRunningClientHelper,
63+
lrt_namespace: LRTNamespace,
64+
) -> None:
65+
# create entries in both sides
66+
await store.add_task_data(task_data.task_id, task_data)
67+
await store.mark_task_for_removal(task_data.task_id, task_data.task_context)
68+
69+
# entries exit
70+
assert await store.list_tasks_data() == [task_data]
71+
assert await store.list_tasks_to_remove() == {
72+
task_data.task_id: task_data.task_context
73+
}
74+
75+
# removes
76+
await long_running_client_helper.cleanup(lrt_namespace)
77+
78+
# entris were removed
79+
assert await store.list_tasks_data() == []
80+
assert await store.list_tasks_to_remove() == {}
81+
82+
# ensore it does not raise errors if there is nothing to remove
83+
await long_running_client_helper.cleanup(lrt_namespace)

0 commit comments

Comments
 (0)