diff --git a/packages/pytest-simcore/src/pytest_simcore/helpers/long_running_tasks.py b/packages/pytest-simcore/src/pytest_simcore/helpers/long_running_tasks.py new file mode 100644 index 000000000000..ad85744951f1 --- /dev/null +++ b/packages/pytest-simcore/src/pytest_simcore/helpers/long_running_tasks.py @@ -0,0 +1,39 @@ +# pylint: disable=protected-access + +import pytest +from fastapi import FastAPI +from servicelib.long_running_tasks.errors import TaskNotFoundError +from servicelib.long_running_tasks.manager import ( + LongRunningManager, +) +from servicelib.long_running_tasks.models import TaskContext +from servicelib.long_running_tasks.task import TaskId +from tenacity import ( + AsyncRetrying, + retry_if_not_exception_type, + stop_after_delay, + wait_fixed, +) + + +def get_fastapi_long_running_manager(app: FastAPI) -> LongRunningManager: + manager = app.state.long_running_manager + assert isinstance(manager, LongRunningManager) + return manager + + +async def assert_task_is_no_longer_present( + manager: LongRunningManager, task_id: TaskId, task_context: TaskContext +) -> None: + async for attempt in AsyncRetrying( + reraise=True, + wait=wait_fixed(0.1), + stop=stop_after_delay(60), + retry=retry_if_not_exception_type(TaskNotFoundError), + ): + with attempt: # noqa: SIM117 + with pytest.raises(TaskNotFoundError): + # use internals to detirmine when it's no longer here + await manager._tasks_manager._get_tracked_task( # noqa: SLF001 + task_id, task_context + ) diff --git a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py index 9e8ac646c330..55879e34ef13 100644 --- a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py +++ b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py @@ -1,15 +1,13 @@ -from typing import Annotated, Any +from typing import Any from aiohttp import web -from models_library.rest_base import RequestParameters -from pydantic import BaseModel, Field +from pydantic import BaseModel from ...aiohttp import status from ...long_running_tasks import lrt_api from ...long_running_tasks.models import TaskGet, TaskId from ..requests_validation import ( parse_request_path_parameters_as, - parse_request_query_parameters_as, ) from ..rest_responses import create_data_response from ._manager import get_long_running_manager @@ -69,22 +67,9 @@ async def get_task_result(request: web.Request) -> web.Response | Any: ) -class _RemoveTaskQueryParams(RequestParameters): - wait_for_removal: Annotated[ - bool, - Field( - description=( - "when True waits for the task to be removed " - "completly instead of returning immediately" - ) - ), - ] = True - - @routes.delete("/{task_id}", name="remove_task") async def remove_task(request: web.Request) -> web.Response: path_params = parse_request_path_parameters_as(_PathParam, request) - query_params = parse_request_query_parameters_as(_RemoveTaskQueryParams, request) long_running_manager = get_long_running_manager(request.app) await lrt_api.remove_task( @@ -92,6 +77,5 @@ async def remove_task(request: web.Request) -> web.Response: long_running_manager.lrt_namespace, long_running_manager.get_task_context(request), path_params.task_id, - wait_for_removal=query_params.wait_for_removal, ) return web.json_response(status=status.HTTP_204_NO_CONTENT) diff --git a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py index b5ae54cb07f9..09c50be9685e 100644 --- a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py +++ b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py @@ -108,7 +108,6 @@ async def start_long_running_task( long_running_manager.lrt_namespace, task_context, task_id, - wait_for_removal=True, ) raise diff --git a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py index bf347ba0d0aa..95b28ceec2fb 100644 --- a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py +++ b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py @@ -1,6 +1,6 @@ from typing import Annotated, Any -from fastapi import APIRouter, Depends, Query, Request, status +from fastapi import APIRouter, Depends, Request, status from ...long_running_tasks import lrt_api from ...long_running_tasks.models import TaskGet, TaskId, TaskResult, TaskStatus @@ -101,16 +101,6 @@ async def remove_task( FastAPILongRunningManager, Depends(get_long_running_manager) ], task_id: TaskId, - *, - wait_for_removal: Annotated[ - bool, - Query( - description=( - "when True waits for the task to be removed " - "completly instead of returning immediately" - ), - ), - ] = True, ) -> None: assert request # nosec await lrt_api.remove_task( @@ -118,5 +108,4 @@ async def remove_task( long_running_manager.lrt_namespace, long_running_manager.get_task_context(request), task_id=task_id, - wait_for_removal=wait_for_removal, ) diff --git a/packages/service-library/src/servicelib/long_running_tasks/_redis_store.py b/packages/service-library/src/servicelib/long_running_tasks/_redis_store.py index acf70bb87e48..fbed41205a95 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/_redis_store.py +++ b/packages/service-library/src/servicelib/long_running_tasks/_redis_store.py @@ -8,11 +8,11 @@ from ..redis._client import RedisClientSDK from ..redis._utils import handle_redis_returns_union_types from ..utils import limited_gather -from .models import LRTNamespace, TaskContext, TaskData, TaskId +from .models import LRTNamespace, TaskData, TaskId _STORE_TYPE_TASK_DATA: Final[str] = "TD" -_STORE_TYPE_CANCELLED_TASKS: Final[str] = "CT" -_LIST_CONCURRENCY: Final[int] = 2 +_LIST_CONCURRENCY: Final[int] = 3 +_MARKED_FOR_REMOVAL_FIELD: Final[str] = "marked_for_removal" def _to_redis_hash_mapping(data: dict[str, Any]) -> dict[str, str]: @@ -52,11 +52,6 @@ def _get_redis_key_task_data_match(self) -> str: def _get_redis_task_data_key(self, task_id: TaskId) -> str: return f"{self.namespace}:{_STORE_TYPE_TASK_DATA}:{task_id}" - def _get_key_to_remove(self) -> str: - return f"{self.namespace}:{_STORE_TYPE_CANCELLED_TASKS}" - - # TaskData - async def get_task_data(self, task_id: TaskId) -> TaskData | None: result: dict[str, Any] = await handle_redis_returns_union_types( self._redis.hgetall( @@ -115,24 +110,18 @@ async def delete_task_data(self, task_id: TaskId) -> None: self._redis.delete(self._get_redis_task_data_key(task_id)) ) - # to cancel - - async def mark_task_for_removal( - self, task_id: TaskId, with_task_context: TaskContext - ) -> None: + async def mark_for_removal(self, task_id: TaskId) -> None: await handle_redis_returns_union_types( self._redis.hset( - self._get_key_to_remove(), task_id, json_dumps(with_task_context) + self._get_redis_task_data_key(task_id), + mapping=_to_redis_hash_mapping({_MARKED_FOR_REMOVAL_FIELD: True}), ) ) - async def completed_task_removal(self, task_id: TaskId) -> None: - await handle_redis_returns_union_types( - self._redis.hdel(self._get_key_to_remove(), task_id) - ) - - async def list_tasks_to_remove(self) -> dict[TaskId, TaskContext]: - result: dict[str, str | None] = await handle_redis_returns_union_types( - self._redis.hgetall(self._get_key_to_remove()) + async def is_marked_for_removal(self, task_id: TaskId) -> bool: + result = await handle_redis_returns_union_types( + self._redis.hget( + self._get_redis_task_data_key(task_id), _MARKED_FOR_REMOVAL_FIELD + ) ) - return {task_id: json_loads(context) for task_id, context in result.items()} + return False if result is None else json_loads(result) diff --git a/packages/service-library/src/servicelib/long_running_tasks/_rpc_client.py b/packages/service-library/src/servicelib/long_running_tasks/_rpc_client.py index 3bc3caf5e804..6ad3fe9785ba 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/_rpc_client.py +++ b/packages/service-library/src/servicelib/long_running_tasks/_rpc_client.py @@ -118,26 +118,13 @@ async def remove_task( *, task_context: TaskContext, task_id: TaskId, - wait_for_removal: bool, - cancellation_timeout: timedelta | None, ) -> None: - timeout_s = ( - None - if cancellation_timeout is None - else int(cancellation_timeout.total_seconds()) - ) - - # NOTE: task always gets cancelled even if not waiting for it - # request will return immediatlye, no need to wait so much - if wait_for_removal is False: - timeout_s = _RPC_TIMEOUT_SHORT_REQUESTS result = await rabbitmq_rpc_client.request( get_rabbit_namespace(namespace), TypeAdapter(RPCMethodName).validate_python("remove_task"), task_context=task_context, task_id=task_id, - wait_for_removal=wait_for_removal, - timeout_s=timeout_s, + timeout_s=_RPC_TIMEOUT_SHORT_REQUESTS, ) assert result is None # nosec diff --git a/packages/service-library/src/servicelib/long_running_tasks/_rpc_server.py b/packages/service-library/src/servicelib/long_running_tasks/_rpc_server.py index bf9edfdadc66..2d7ff79ac087 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/_rpc_server.py +++ b/packages/service-library/src/servicelib/long_running_tasks/_rpc_server.py @@ -88,7 +88,7 @@ async def get_task_result( await long_running_manager.tasks_manager.remove_task( task_id, with_task_context=task_context, - wait_for_removal=True, + wait_for_removal=False, ) @@ -98,10 +98,7 @@ async def remove_task( *, task_context: TaskContext, task_id: TaskId, - wait_for_removal: bool, ) -> None: await long_running_manager.tasks_manager.remove_task( - task_id, - with_task_context=task_context, - wait_for_removal=wait_for_removal, + task_id, with_task_context=task_context, wait_for_removal=False ) diff --git a/packages/service-library/src/servicelib/long_running_tasks/lrt_api.py b/packages/service-library/src/servicelib/long_running_tasks/lrt_api.py index 73fdebb4cfa9..02f4c5265f38 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/lrt_api.py +++ b/packages/service-library/src/servicelib/long_running_tasks/lrt_api.py @@ -1,4 +1,3 @@ -from datetime import timedelta from typing import Any from ..rabbitmq._client_rpc import RabbitMQRPCClient @@ -103,9 +102,6 @@ async def remove_task( lrt_namespace: LRTNamespace, task_context: TaskContext, task_id: TaskId, - *, - wait_for_removal: bool, - cancellation_timeout: timedelta | None = None, ) -> None: """cancels and removes a task @@ -116,6 +112,4 @@ async def remove_task( lrt_namespace, task_id=task_id, task_context=task_context, - wait_for_removal=wait_for_removal, - cancellation_timeout=cancellation_timeout, ) diff --git a/packages/service-library/src/servicelib/long_running_tasks/models.py b/packages/service-library/src/servicelib/long_running_tasks/models.py index 7a99f9a5cf34..193c5eadbde3 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/models.py +++ b/packages/service-library/src/servicelib/long_running_tasks/models.py @@ -50,7 +50,7 @@ class TaskData(BaseModel): task_id: str task_progress: TaskProgress # NOTE: this context lifetime is with the tracked task (similar to aiohttp storage concept) - task_context: dict[str, Any] + task_context: TaskContext fire_and_forget: Annotated[ bool, Field( @@ -78,6 +78,10 @@ class TaskData(BaseModel): result_field: Annotated[ ResultField | None, Field(description="the result of the task") ] = None + marked_for_removal: Annotated[ + bool, + Field(description=("if True, indicates the task is marked for removal")), + ] = False model_config = ConfigDict( arbitrary_types_allowed=True, diff --git a/packages/service-library/src/servicelib/long_running_tasks/task.py b/packages/service-library/src/servicelib/long_running_tasks/task.py index dfb79f29662c..256a60fb4b59 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/task.py +++ b/packages/service-library/src/servicelib/long_running_tasks/task.py @@ -282,7 +282,7 @@ async def _stale_tasks_monitor(self) -> None: # we just print the status from where one can infer the above with suppress(TaskNotFoundError): task_status = await self.get_task_status( - task_id, with_task_context=task_context + task_id, with_task_context=task_context, exclude_to_remove=False ) with log_context( _logger, @@ -300,11 +300,17 @@ async def _cancelled_tasks_removal(self) -> None: """ self._started_event_task_cancelled_tasks_removal.set() - to_remove = await self._tasks_data.list_tasks_to_remove() - for task_id in to_remove: - await self._attempt_to_remove_local_task(task_id) + tasks_data = await self._tasks_data.list_tasks_data() + await limited_gather( + *( + self._attempt_to_remove_local_task(x.task_id) + for x in tasks_data + if x.marked_for_removal is True + ), + limit=_PARALLEL_TASKS_CANCELLATION, + ) - async def _tasks_monitor(self) -> None: + async def _tasks_monitor(self) -> None: # noqa: C901 """ A task which monitors locally running tasks and updates their status in the Redis store when they are done. @@ -396,12 +402,14 @@ async def list_tasks(self, with_task_context: TaskContext | None) -> list[TaskBa return [ TaskBase(task_id=task.task_id) for task in (await self._tasks_data.list_tasks_data()) + if task.marked_for_removal is False ] return [ TaskBase(task_id=task.task_id) for task in (await self._tasks_data.list_tasks_data()) if task.task_context == with_task_context + and task.marked_for_removal is False ] async def _get_tracked_task( @@ -418,7 +426,11 @@ async def _get_tracked_task( return task_data async def get_task_status( - self, task_id: TaskId, with_task_context: TaskContext + self, + task_id: TaskId, + with_task_context: TaskContext, + *, + exclude_to_remove: bool = True, ) -> TaskStatus: """ returns: the status of the task, along with updates @@ -426,6 +438,9 @@ async def get_task_status( raises TaskNotFoundError if the task cannot be found """ + if exclude_to_remove and await self._tasks_data.is_marked_for_removal(task_id): + raise TaskNotFoundError(task_id=task_id) + task_data = await self._get_tracked_task(task_id, with_task_context) await self._tasks_data.update_task_data( @@ -460,6 +475,9 @@ async def get_task_result( raises TaskNotFoundError if the task cannot be found raises TaskNotCompletedError if the task is not completed """ + if await self._tasks_data.is_marked_for_removal(task_id): + raise TaskNotFoundError(task_id=task_id) + tracked_task = await self._get_tracked_task(task_id, with_task_context) if not tracked_task.is_done or tracked_task.result_field is None: @@ -473,7 +491,6 @@ async def _attempt_to_remove_local_task(self, task_id: TaskId) -> None: task_to_cancel = self._created_tasks.pop(task_id, None) if task_to_cancel is not None: await cancel_wait_task(task_to_cancel) - await self._tasks_data.completed_task_removal(task_id) await self._tasks_data.delete_task_data(task_id) async def remove_task( @@ -487,11 +504,12 @@ async def remove_task( cancels and removes task raises TaskNotFoundError if the task cannot be found """ + if await self._tasks_data.is_marked_for_removal(task_id): + raise TaskNotFoundError(task_id=task_id) + tracked_task = await self._get_tracked_task(task_id, with_task_context) - await self._tasks_data.mark_task_for_removal( - tracked_task.task_id, tracked_task.task_context - ) + await self._tasks_data.mark_for_removal(tracked_task.task_id) if not wait_for_removal: return diff --git a/packages/service-library/tests/long_running_tasks/test_long_running_tasks__store.py b/packages/service-library/tests/long_running_tasks/test_long_running_tasks__redis_store.py similarity index 73% rename from packages/service-library/tests/long_running_tasks/test_long_running_tasks__store.py rename to packages/service-library/tests/long_running_tasks/test_long_running_tasks__redis_store.py index 218af7a9aaae..fc08de586864 100644 --- a/packages/service-library/tests/long_running_tasks/test_long_running_tasks__store.py +++ b/packages/service-library/tests/long_running_tasks/test_long_running_tasks__redis_store.py @@ -2,15 +2,26 @@ from collections.abc import AsyncIterable, Callable from contextlib import AbstractAsyncContextManager +from copy import deepcopy import pytest from pydantic import TypeAdapter -from servicelib.long_running_tasks._redis_store import RedisStore +from servicelib.long_running_tasks._redis_store import ( + _MARKED_FOR_REMOVAL_FIELD, + RedisStore, +) from servicelib.long_running_tasks.models import TaskData from servicelib.redis._client import RedisClientSDK from settings_library.redis import RedisDatabase, RedisSettings +def test_ensure_task_data_field_name_and_type(): + # NOTE: ensure thse do not change, if you want to change them remeber that the db is invalid + assert _MARKED_FOR_REMOVAL_FIELD == "marked_for_removal" + field = TaskData.model_fields[_MARKED_FOR_REMOVAL_FIELD] + assert field.annotation is bool + + @pytest.fixture def task_data() -> TaskData: return TypeAdapter(TaskData).validate_python( @@ -50,13 +61,13 @@ async def test_workflow(store: RedisStore, task_data: TaskData) -> None: assert await store.list_tasks_data() == [] # cancelled tasks - assert await store.list_tasks_to_remove() == {} + await store.add_task_data(task_data.task_id, task_data) + + assert await store.is_marked_for_removal(task_data.task_id) is False - await store.mark_task_for_removal(task_data.task_id, task_data.task_context) + await store.mark_for_removal(task_data.task_id) - assert await store.list_tasks_to_remove() == { - task_data.task_id: task_data.task_context - } + assert await store.is_marked_for_removal(task_data.task_id) is True @pytest.fixture @@ -89,15 +100,15 @@ async def test_workflow_multiple_redis_stores_with_different_namespaces( for store in redis_stores: assert await store.list_tasks_data() == [] - assert await store.list_tasks_to_remove() == {} for store in redis_stores: await store.add_task_data(task_data.task_id, task_data) - await store.mark_task_for_removal(task_data.task_id, {}) + await store.mark_for_removal(task_data.task_id) + marked_as_removed_task_data = deepcopy(task_data) + marked_as_removed_task_data.marked_for_removal = True for store in redis_stores: - assert await store.list_tasks_data() == [task_data] - assert await store.list_tasks_to_remove() == {task_data.task_id: {}} + assert await store.list_tasks_data() == [marked_as_removed_task_data] for store in redis_stores: await store.delete_task_data(task_data.task_id) diff --git a/packages/service-library/tests/long_running_tasks/test_long_running_tasks_client_long_running_manager.py b/packages/service-library/tests/long_running_tasks/test_long_running_tasks_client_long_running_manager.py index 236dd2f30814..27369fe08d64 100644 --- a/packages/service-library/tests/long_running_tasks/test_long_running_tasks_client_long_running_manager.py +++ b/packages/service-library/tests/long_running_tasks/test_long_running_tasks_client_long_running_manager.py @@ -2,6 +2,7 @@ from collections.abc import AsyncIterable, Callable from contextlib import AbstractAsyncContextManager +from copy import deepcopy import pytest from pydantic import TypeAdapter @@ -64,20 +65,18 @@ async def test_cleanup_namespace( ) -> None: # create entries in both sides await store.add_task_data(task_data.task_id, task_data) - await store.mark_task_for_removal(task_data.task_id, task_data.task_context) + await store.mark_for_removal(task_data.task_id) # entries exit - assert await store.list_tasks_data() == [task_data] - assert await store.list_tasks_to_remove() == { - task_data.task_id: task_data.task_context - } + marked_for_removal = deepcopy(task_data) + marked_for_removal.marked_for_removal = True + assert await store.list_tasks_data() == [marked_for_removal] # removes await long_running_client_helper.cleanup(lrt_namespace) # entris were removed assert await store.list_tasks_data() == [] - assert await store.list_tasks_to_remove() == {} # ensore it does not raise errors if there is nothing to remove await long_running_client_helper.cleanup(lrt_namespace) diff --git a/packages/service-library/tests/long_running_tasks/test_long_running_tasks_lrt_api.py b/packages/service-library/tests/long_running_tasks/test_long_running_tasks_lrt_api.py index 8b998f700ad1..88e464ee5b01 100644 --- a/packages/service-library/tests/long_running_tasks/test_long_running_tasks_lrt_api.py +++ b/packages/service-library/tests/long_running_tasks/test_long_running_tasks_lrt_api.py @@ -1,3 +1,4 @@ +# pylint: disable=protected-access # pylint: disable=redefined-outer-name # pylint: disable=unused-argument @@ -9,11 +10,9 @@ import pytest from models_library.api_schemas_long_running_tasks.base import TaskProgress from pydantic import NonNegativeInt +from pytest_simcore.helpers.long_running_tasks import assert_task_is_no_longer_present from servicelib.long_running_tasks import lrt_api -from servicelib.long_running_tasks.errors import TaskNotFoundError -from servicelib.long_running_tasks.manager import ( - LongRunningManager, -) +from servicelib.long_running_tasks.manager import LongRunningManager from servicelib.long_running_tasks.models import LRTNamespace, TaskContext from servicelib.long_running_tasks.task import TaskId, TaskRegistry from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient @@ -160,21 +159,6 @@ async def _assert_list_tasks_from_all_managers( assert len(tasks) == task_count -async def _assert_task_is_no_longer_present( - rabbitmq_rpc_client: RabbitMQRPCClient, - long_running_managers: list[LongRunningManager], - task_context: TaskContext, - task_id: TaskId, -) -> None: - with pytest.raises(TaskNotFoundError): - await lrt_api.get_task_status( - rabbitmq_rpc_client, - _get_long_running_manager(long_running_managers).lrt_namespace, - task_context, - task_id, - ) - - _TASK_CONTEXT: Final[list[TaskContext | None]] = [{"a": "context"}, None] _IS_UNIQUE: Final[list[bool]] = [False, True] _TASK_COUNT: Final[list[int]] = [5] @@ -185,6 +169,8 @@ async def _assert_task_is_no_longer_present( @pytest.mark.parametrize("is_unique", _IS_UNIQUE) @pytest.mark.parametrize("to_return", [{"key": "value"}]) async def test_workflow_with_result( + disable_stale_tasks_monitor: None, + fast_long_running_tasks_cancellation: None, long_running_managers: list[LongRunningManager], rabbitmq_rpc_client: RabbitMQRPCClient, task_count: int, @@ -232,8 +218,8 @@ async def test_workflow_with_result( ) assert result == to_return - await _assert_task_is_no_longer_present( - rabbitmq_rpc_client, long_running_managers, saved_context, task_id + await assert_task_is_no_longer_present( + _get_long_running_manager(long_running_managers), task_id, saved_context ) @@ -241,6 +227,8 @@ async def test_workflow_with_result( @pytest.mark.parametrize("task_context", _TASK_CONTEXT) @pytest.mark.parametrize("is_unique", _IS_UNIQUE) async def test_workflow_raises_error( + disable_stale_tasks_monitor: None, + fast_long_running_tasks_cancellation: None, long_running_managers: list[LongRunningManager], rabbitmq_rpc_client: RabbitMQRPCClient, task_count: int, @@ -286,14 +274,16 @@ async def test_workflow_raises_error( task_id, ) - await _assert_task_is_no_longer_present( - rabbitmq_rpc_client, long_running_managers, saved_context, task_id + await assert_task_is_no_longer_present( + _get_long_running_manager(long_running_managers), task_id, saved_context ) @pytest.mark.parametrize("task_context", _TASK_CONTEXT) @pytest.mark.parametrize("is_unique", _IS_UNIQUE) async def test_remove_task( + disable_stale_tasks_monitor: None, + fast_long_running_tasks_cancellation: None, long_running_managers: list[LongRunningManager], rabbitmq_rpc_client: RabbitMQRPCClient, is_unique: bool, @@ -319,9 +309,8 @@ async def test_remove_task( _get_long_running_manager(long_running_managers).lrt_namespace, saved_context, task_id, - wait_for_removal=True, ) - await _assert_task_is_no_longer_present( - rabbitmq_rpc_client, long_running_managers, saved_context, task_id + await assert_task_is_no_longer_present( + _get_long_running_manager(long_running_managers), task_id, saved_context ) diff --git a/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py b/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py index d7586026e49b..0808878818a9 100644 --- a/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py +++ b/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py @@ -548,8 +548,8 @@ async def test__cancelled_tasks_worker_equivalent_of_cancellation_from_a_differe total_sleep=10, task_context=empty_context, ) - await long_running_manager.tasks_manager._tasks_data.mark_task_for_removal( # noqa: SLF001 - task_id, with_task_context=empty_context + await long_running_manager.tasks_manager._tasks_data.mark_for_removal( # noqa: SLF001 + task_id ) async for attempt in AsyncRetrying(**_RETRY_PARAMS): diff --git a/services/dynamic-sidecar/tests/unit/api/rest/test_containers_long_running_tasks.py b/services/dynamic-sidecar/tests/unit/api/rest/test_containers_long_running_tasks.py index 4040ee7ffe96..c8d97fb52d8c 100644 --- a/services/dynamic-sidecar/tests/unit/api/rest/test_containers_long_running_tasks.py +++ b/services/dynamic-sidecar/tests/unit/api/rest/test_containers_long_running_tasks.py @@ -32,6 +32,10 @@ from models_library.services_creation import CreateServiceMetricsAdditionalParams from pydantic import AnyHttpUrl, TypeAdapter from pytest_mock.plugin import MockerFixture +from pytest_simcore.helpers.long_running_tasks import ( + assert_task_is_no_longer_present, + get_fastapi_long_running_manager, +) from pytest_simcore.helpers.monkeypatch_envs import EnvVarsDict, setenvs_from_dict from servicelib.fastapi.long_running_tasks.client import ( HttpClient, @@ -39,7 +43,7 @@ ) from servicelib.fastapi.long_running_tasks.client import setup as client_setup from servicelib.long_running_tasks.errors import TaskExceptionError -from servicelib.long_running_tasks.models import TaskId +from servicelib.long_running_tasks.models import ProgressCallback, TaskId from servicelib.long_running_tasks.task import TaskRegistry from settings_library.rabbit import RabbitSettings from simcore_sdk.node_ports_common.exceptions import NodeNotFound @@ -61,8 +65,8 @@ "rabbit", ] -FAST_STATUS_POLL: Final[float] = 0.1 -CREATE_SERVICE_CONTAINERS_TIMEOUT: Final[float] = 60 +_FAST_STATUS_POLL: Final[float] = 0.1 +_CREATE_SERVICE_CONTAINERS_TIMEOUT: Final[float] = 60 DEFAULT_COMMAND_TIMEOUT: Final[int] = 5 @@ -118,6 +122,9 @@ async def auto_remove_task( yield finally: await http_client.remove_task(task_id, timeout=10) + await assert_task_is_no_longer_present( + get_fastapi_long_running_manager(http_client.app), task_id, {} + ) async def _get_container_timestamps( @@ -199,7 +206,10 @@ def mock_environment( @pytest.fixture -async def app(app: FastAPI) -> AsyncIterable[FastAPI]: +async def app( + app: FastAPI, + fast_long_running_tasks_cancellation: None, +) -> AsyncIterable[FastAPI]: # add the client setup to the same application # this is only required for testing, in reality # this will be in a different process @@ -430,9 +440,32 @@ async def _assert_progress_finished( assert last_progress_message == ("finished", 1.0) +async def _perioduc_result_and_task_removed( + app: FastAPI, + http_client: HttpClient, + task_id: TaskId, + *, + progress_callback: ProgressCallback | None = None, +) -> Any | None: + try: + async with periodic_task_result( + client=http_client, + task_id=task_id, + task_timeout=_CREATE_SERVICE_CONTAINERS_TIMEOUT, + status_poll_interval=_FAST_STATUS_POLL, + progress_callback=progress_callback, + ) as result: + return result + finally: + await assert_task_is_no_longer_present( + get_fastapi_long_running_manager(app), task_id, {} + ) + + async def test_create_containers_task( httpx_async_client: AsyncClient, http_client: HttpClient, + app: FastAPI, compose_spec: str, mock_stop_heart_beat_task: AsyncMock, mock_metrics_params: CreateServiceMetricsAdditionalParams, @@ -448,16 +481,15 @@ async def create_progress( last_progress_message = (message, percent) print(message, percent) - async with periodic_task_result( - client=http_client, - task_id=await _get_task_id_create_service_containers( + result = await _perioduc_result_and_task_removed( + app, + http_client, + await _get_task_id_create_service_containers( httpx_async_client, compose_spec, mock_metrics_params ), - task_timeout=CREATE_SERVICE_CONTAINERS_TIMEOUT, - status_poll_interval=FAST_STATUS_POLL, progress_callback=create_progress, - ) as result: - assert shared_store.container_names == result + ) + assert shared_store.container_names == result await _assert_progress_finished(last_progress_message) @@ -465,6 +497,7 @@ async def create_progress( async def test_pull_user_servcices_docker_images( httpx_async_client: AsyncClient, http_client: HttpClient, + app: FastAPI, compose_spec: str, mock_stop_heart_beat_task: AsyncMock, mock_metrics_params: CreateServiceMetricsAdditionalParams, @@ -480,49 +513,44 @@ async def create_progress( last_progress_message = (message, percent) print(message, percent) - async with periodic_task_result( - client=http_client, - task_id=await _get_task_id_create_service_containers( + result = await _perioduc_result_and_task_removed( + app, + http_client, + await _get_task_id_create_service_containers( httpx_async_client, compose_spec, mock_metrics_params ), - task_timeout=CREATE_SERVICE_CONTAINERS_TIMEOUT, - status_poll_interval=FAST_STATUS_POLL, progress_callback=create_progress, - ) as result: - assert shared_store.container_names == result - + ) + assert shared_store.container_names == result await _assert_progress_finished(last_progress_message) - async with periodic_task_result( - client=http_client, - task_id=await _get_task_id_pull_user_servcices_docker_images( + result = await _perioduc_result_and_task_removed( + app, + http_client, + await _get_task_id_pull_user_servcices_docker_images( httpx_async_client, compose_spec, mock_metrics_params ), - task_timeout=CREATE_SERVICE_CONTAINERS_TIMEOUT, - status_poll_interval=FAST_STATUS_POLL, - progress_callback=_debug_progress, - ) as result: - assert result is None + progress_callback=create_progress, + ) + assert result is None await _assert_progress_finished(last_progress_message) async def test_create_containers_task_invalid_yaml_spec( httpx_async_client: AsyncClient, http_client: HttpClient, + app: FastAPI, mock_stop_heart_beat_task: AsyncMock, mock_metrics_params: CreateServiceMetricsAdditionalParams, ): with pytest.raises(InvalidComposeSpecError) as exec_info: - async with periodic_task_result( - client=http_client, - task_id=await _get_task_id_create_service_containers( + await _perioduc_result_and_task_removed( + app, + http_client, + await _get_task_id_create_service_containers( httpx_async_client, "", mock_metrics_params ), - task_timeout=CREATE_SERVICE_CONTAINERS_TIMEOUT, - status_poll_interval=FAST_STATUS_POLL, - progress_callback=_debug_progress, - ): - pass + ) assert "Provided yaml is not valid" in f"{exec_info.value}" @@ -577,6 +605,7 @@ async def test_containers_down_after_starting( mock_ensure_read_permissions_on_user_service_data: None, httpx_async_client: AsyncClient, http_client: HttpClient, + app: FastAPI, compose_spec: str, mock_stop_heart_beat_task: AsyncMock, mock_metrics_params: CreateServiceMetricsAdditionalParams, @@ -585,68 +614,70 @@ async def test_containers_down_after_starting( mocker: MockerFixture, ): # start containers - async with periodic_task_result( - client=http_client, - task_id=await _get_task_id_create_service_containers( + result = await _perioduc_result_and_task_removed( + app, + http_client, + await _get_task_id_create_service_containers( httpx_async_client, compose_spec, mock_metrics_params ), - task_timeout=CREATE_SERVICE_CONTAINERS_TIMEOUT, - status_poll_interval=FAST_STATUS_POLL, progress_callback=_debug_progress, - ) as result: - assert shared_store.container_names == result + ) + assert shared_store.container_names == result # put down containers - async with periodic_task_result( - client=http_client, - task_id=await _get_task_id_docker_compose_down(httpx_async_client), - task_timeout=CREATE_SERVICE_CONTAINERS_TIMEOUT, - status_poll_interval=FAST_STATUS_POLL, + result = await _perioduc_result_and_task_removed( + app, + http_client, + await _get_task_id_docker_compose_down(httpx_async_client), progress_callback=_debug_progress, - ) as result: - assert result is None + ) + assert result is None async def test_containers_down_missing_spec( httpx_async_client: AsyncClient, http_client: HttpClient, + app: FastAPI, caplog_info_debug: pytest.LogCaptureFixture, ): - async with periodic_task_result( - client=http_client, - task_id=await _get_task_id_docker_compose_down(httpx_async_client), - task_timeout=CREATE_SERVICE_CONTAINERS_TIMEOUT, - status_poll_interval=FAST_STATUS_POLL, + result = await _perioduc_result_and_task_removed( + app, + http_client, + await _get_task_id_docker_compose_down(httpx_async_client), progress_callback=_debug_progress, - ) as result: - assert result is None + ) + assert result is None assert "No compose-spec was found" in caplog_info_debug.text async def test_container_restore_state( - httpx_async_client: AsyncClient, http_client: HttpClient, mock_data_manager: None + httpx_async_client: AsyncClient, + http_client: HttpClient, + app: FastAPI, + mock_data_manager: None, ): - async with periodic_task_result( - client=http_client, - task_id=await _get_task_id_state_restore(httpx_async_client), - task_timeout=CREATE_SERVICE_CONTAINERS_TIMEOUT, - status_poll_interval=FAST_STATUS_POLL, + result = await _perioduc_result_and_task_removed( + app, + http_client, + await _get_task_id_state_restore(httpx_async_client), progress_callback=_debug_progress, - ) as result: - assert isinstance(result, int) + ) + assert isinstance(result, int) async def test_container_save_state( - httpx_async_client: AsyncClient, http_client: HttpClient, mock_data_manager: None + httpx_async_client: AsyncClient, + http_client: HttpClient, + app: FastAPI, + mock_data_manager: None, ): - async with periodic_task_result( - client=http_client, - task_id=await _get_task_id_state_save(httpx_async_client), - task_timeout=CREATE_SERVICE_CONTAINERS_TIMEOUT, - status_poll_interval=FAST_STATUS_POLL, + result = await _perioduc_result_and_task_removed( + app, + http_client, + await _get_task_id_state_save(httpx_async_client), progress_callback=_debug_progress, - ) as result: - assert isinstance(result, int) + ) + assert isinstance(result, int) @pytest.mark.parametrize("inputs_pulling_enabled", [True, False]) @@ -661,57 +692,51 @@ async def test_container_pull_input_ports( if inputs_pulling_enabled: enable_inputs_pulling(app) - async with periodic_task_result( - client=http_client, - task_id=await _get_task_id_task_ports_inputs_pull( - httpx_async_client, mock_port_keys - ), - task_timeout=CREATE_SERVICE_CONTAINERS_TIMEOUT, - status_poll_interval=FAST_STATUS_POLL, + result = await _perioduc_result_and_task_removed( + app, + http_client, + await _get_task_id_task_ports_inputs_pull(httpx_async_client, mock_port_keys), progress_callback=_debug_progress, - ) as result: - assert result == (42 if inputs_pulling_enabled else 0) + ) + assert result == (42 if inputs_pulling_enabled else 0) async def test_container_pull_output_ports( httpx_async_client: AsyncClient, http_client: HttpClient, + app: FastAPI, mock_port_keys: list[str] | None, mock_nodeports: None, ): - async with periodic_task_result( - client=http_client, - task_id=await _get_task_id_task_ports_outputs_pull( - httpx_async_client, mock_port_keys - ), - task_timeout=CREATE_SERVICE_CONTAINERS_TIMEOUT, - status_poll_interval=FAST_STATUS_POLL, + result = await _perioduc_result_and_task_removed( + app, + http_client, + await _get_task_id_task_ports_outputs_pull(httpx_async_client, mock_port_keys), progress_callback=_debug_progress, - ) as result: - assert result == 42 + ) + assert result == 42 async def test_container_push_output_ports( httpx_async_client: AsyncClient, http_client: HttpClient, + app: FastAPI, mock_port_keys: list[str] | None, mock_nodeports: None, ): - async with periodic_task_result( - client=http_client, - task_id=await _get_task_id_task_ports_outputs_push( - httpx_async_client, mock_port_keys - ), - task_timeout=CREATE_SERVICE_CONTAINERS_TIMEOUT, - status_poll_interval=FAST_STATUS_POLL, + result = await _perioduc_result_and_task_removed( + app, + http_client, + await _get_task_id_task_ports_outputs_push(httpx_async_client, mock_port_keys), progress_callback=_debug_progress, - ) as result: - assert result is None + ) + assert result is None async def test_container_push_output_ports_missing_node( httpx_async_client: AsyncClient, http_client: HttpClient, + app: FastAPI, mock_port_keys: list[str] | None, missing_node_uuid: str, mock_node_missing: None, @@ -721,16 +746,14 @@ async def test_container_push_output_ports_missing_node( await outputs_manager.port_key_content_changed(port_key) async def _test_code() -> None: - async with periodic_task_result( - client=http_client, - task_id=await _get_task_id_task_ports_outputs_push( + await _perioduc_result_and_task_removed( + app, + http_client, + await _get_task_id_task_ports_outputs_push( httpx_async_client, mock_port_keys ), - task_timeout=CREATE_SERVICE_CONTAINERS_TIMEOUT, - status_poll_interval=FAST_STATUS_POLL, progress_callback=_debug_progress, - ): - pass + ) if not mock_port_keys: await _test_code() @@ -743,36 +766,34 @@ async def _test_code() -> None: async def test_containers_restart( httpx_async_client: AsyncClient, http_client: HttpClient, + app: FastAPI, compose_spec: str, mock_stop_heart_beat_task: AsyncMock, mock_metrics_params: CreateServiceMetricsAdditionalParams, shared_store: SharedStore, ): - async with periodic_task_result( - client=http_client, - task_id=await _get_task_id_create_service_containers( + container_names = await _perioduc_result_and_task_removed( + app, + http_client, + await _get_task_id_create_service_containers( httpx_async_client, compose_spec, mock_metrics_params ), - task_timeout=CREATE_SERVICE_CONTAINERS_TIMEOUT, - status_poll_interval=FAST_STATUS_POLL, progress_callback=_debug_progress, - ) as container_names: - assert shared_store.container_names == container_names - + ) + assert shared_store.container_names == container_names assert container_names container_timestamps_before = await _get_container_timestamps(container_names) - async with periodic_task_result( - client=http_client, - task_id=await _get_task_id_task_containers_restart( + result = await _perioduc_result_and_task_removed( + app, + http_client, + await _get_task_id_task_containers_restart( httpx_async_client, DEFAULT_COMMAND_TIMEOUT ), - task_timeout=CREATE_SERVICE_CONTAINERS_TIMEOUT, - status_poll_interval=FAST_STATUS_POLL, progress_callback=_debug_progress, - ) as result: - assert result is None + ) + assert result is None container_timestamps_after = await _get_container_timestamps(container_names) diff --git a/services/dynamic-sidecar/tests/unit/api/rpc/test__containers_long_running_tasks.py b/services/dynamic-sidecar/tests/unit/api/rpc/test__containers_long_running_tasks.py index 8973bec3e444..039660eb009b 100644 --- a/services/dynamic-sidecar/tests/unit/api/rpc/test__containers_long_running_tasks.py +++ b/services/dynamic-sidecar/tests/unit/api/rpc/test__containers_long_running_tasks.py @@ -1,4 +1,5 @@ # pylint: disable=no-member +# pylint: disable=protected-access # pylint: disable=redefined-outer-name # pylint: disable=too-many-arguments # pylint: disable=unused-argument @@ -29,10 +30,13 @@ from models_library.projects_nodes_io import NodeID from models_library.services_creation import CreateServiceMetricsAdditionalParams from pytest_mock.plugin import MockerFixture +from pytest_simcore.helpers.long_running_tasks import ( + assert_task_is_no_longer_present, + get_fastapi_long_running_manager, +) from pytest_simcore.helpers.monkeypatch_envs import EnvVarsDict, setenvs_from_dict from servicelib.fastapi.long_running_tasks._manager import FastAPILongRunningManager from servicelib.long_running_tasks import lrt_api -from servicelib.long_running_tasks.errors import TaskNotFoundError from servicelib.long_running_tasks.models import LRTNamespace, TaskId from servicelib.long_running_tasks.task import TaskRegistry from servicelib.rabbitmq import RabbitMQRPCClient @@ -517,6 +521,7 @@ async def test_create_containers_task_invalid_yaml_spec( async def test_same_task_id_is_returned_if_task_exists( mock_sidecar_lrts: None, rpc_client: RabbitMQRPCClient, + app: FastAPI, node_id: NodeID, lrt_namespace: LRTNamespace, mocker: MockerFixture, @@ -536,11 +541,10 @@ def _get_awaitable() -> Awaitable[TaskId]: ) async def _assert_task_removed(task_id: TaskId) -> None: - await lrt_api.remove_task( - rpc_client, lrt_namespace, {}, task_id, wait_for_removal=True + await lrt_api.remove_task(rpc_client, lrt_namespace, {}, task_id) + await assert_task_is_no_longer_present( + get_fastapi_long_running_manager(app), task_id, {} ) - with pytest.raises(TaskNotFoundError): - await lrt_api.get_task_status(rpc_client, lrt_namespace, {}, task_id) task_id = await _get_awaitable() assert task_id.endswith("unique") diff --git a/services/dynamic-sidecar/tests/unit/test_api_rest_workflow_service_metrics.py b/services/dynamic-sidecar/tests/unit/test_api_rest_workflow_service_metrics.py index af1eb7f5b988..6786f99bfcf7 100644 --- a/services/dynamic-sidecar/tests/unit/test_api_rest_workflow_service_metrics.py +++ b/services/dynamic-sidecar/tests/unit/test_api_rest_workflow_service_metrics.py @@ -36,6 +36,10 @@ from models_library.services_creation import CreateServiceMetricsAdditionalParams from pydantic import AnyHttpUrl, TypeAdapter from pytest_mock import MockerFixture +from pytest_simcore.helpers.long_running_tasks import ( + assert_task_is_no_longer_present, + get_fastapi_long_running_manager, +) from pytest_simcore.helpers.monkeypatch_envs import EnvVarsDict, setenvs_from_dict from servicelib.fastapi.long_running_tasks.client import ( HttpClient, @@ -92,6 +96,7 @@ def backend_url() -> AnyHttpUrl: @pytest.fixture async def mock_environment( + fast_long_running_tasks_cancellation: None, mock_postgres_check: None, mock_registry_service: AsyncMock, mock_environment: EnvVarsDict, @@ -351,14 +356,18 @@ async def test_user_services_fail_to_stop_or_save_data( # in case of manual intervention multiple stops will be sent _EXPECTED_STOP_MESSAGES = 4 for _ in range(_EXPECTED_STOP_MESSAGES): + task_id = await _get_task_id_docker_compose_down(httpx_async_client) with pytest.raises(TaskExceptionError): async with periodic_task_result( client=http_client, - task_id=await _get_task_id_docker_compose_down(httpx_async_client), + task_id=task_id, task_timeout=_CREATE_SERVICE_CONTAINERS_TIMEOUT, status_poll_interval=_FAST_STATUS_POLL, ): ... + await assert_task_is_no_longer_present( + get_fastapi_long_running_manager(app), task_id, {} + ) # Ensure messages arrive in the expected order resource_tracking_messages = _get_resource_tracking_messages( @@ -482,13 +491,17 @@ async def test_user_services_crash_when_running( # will be sent due to manual intervention _EXPECTED_STOP_MESSAGES = 4 for _ in range(_EXPECTED_STOP_MESSAGES): + task_id = await _get_task_id_docker_compose_down(httpx_async_client) async with periodic_task_result( client=http_client, - task_id=await _get_task_id_docker_compose_down(httpx_async_client), + task_id=task_id, task_timeout=_CREATE_SERVICE_CONTAINERS_TIMEOUT, status_poll_interval=_FAST_STATUS_POLL, ) as result: assert result is None + await assert_task_is_no_longer_present( + get_fastapi_long_running_manager(app), task_id, {} + ) resource_tracking_messages = _get_resource_tracking_messages( mock_post_rabbit_message