diff --git a/packages/models-library/src/models_library/api_schemas_long_running_tasks/tasks.py b/packages/models-library/src/models_library/api_schemas_long_running_tasks/tasks.py index 392d8841451..8f27127c71a 100644 --- a/packages/models-library/src/models_library/api_schemas_long_running_tasks/tasks.py +++ b/packages/models-library/src/models_library/api_schemas_long_running_tasks/tasks.py @@ -2,7 +2,8 @@ from datetime import datetime from typing import Any -from pydantic import BaseModel, field_validator +from common_library.exclude import Unset +from pydantic import BaseModel, ConfigDict, model_validator from .base import TaskId, TaskProgress @@ -20,15 +21,30 @@ class TaskResult(BaseModel): class TaskBase(BaseModel): task_id: TaskId - task_name: str + task_name: str | Unset = Unset.VALUE + + @model_validator(mode="after") + def try_populate_task_name_from_task_id(self) -> "TaskBase": + # NOTE: currently this model is used to validate tasks coming from + # the celery backend and from long_running_tasks + # 1. if a task comes from Celery, it will keep its given name + # 2. 
if a task comes from long_running_tasks, it will extract it form + # the task_id, which looks like "{PREFIX}.{TASK_NAME}.UNIQUE|{UUID}" + + if self.task_id and self.task_name == Unset.VALUE: + parts = self.task_id.split(".") + if len(parts) > 1: + self.task_name = urllib.parse.unquote(parts[1]) + + if self.task_name == Unset.VALUE: + self.task_name = self.task_id + + return self + + model_config = ConfigDict(arbitrary_types_allowed=True) class TaskGet(TaskBase): status_href: str result_href: str abort_href: str - - @field_validator("task_name") - @classmethod - def unquote_str(cls, v) -> str: - return urllib.parse.unquote(v) diff --git a/packages/models-library/tests/test_api_schemas_long_running_tasks_tasks.py b/packages/models-library/tests/test_api_schemas_long_running_tasks_tasks.py new file mode 100644 index 00000000000..5acdf168ccc --- /dev/null +++ b/packages/models-library/tests/test_api_schemas_long_running_tasks_tasks.py @@ -0,0 +1,56 @@ +import pytest +from models_library.api_schemas_long_running_tasks.tasks import TaskGet +from pydantic import TypeAdapter + + +def _get_data_without_task_name(task_id: str) -> dict: + return { + "task_id": task_id, + "status_href": "", + "result_href": "", + "abort_href": "", + } + + +@pytest.mark.parametrize( + "data, expected_task_name", + [ + (_get_data_without_task_name("a.b.c.d"), "b"), + (_get_data_without_task_name("a.b.c"), "b"), + (_get_data_without_task_name("a.b"), "b"), + (_get_data_without_task_name("a"), "a"), + ], +) +def test_try_extract_task_name(data: dict, expected_task_name: str) -> None: + task_get = TaskGet(**data) + assert task_get.task_name == expected_task_name + + task_get = TypeAdapter(TaskGet).validate_python(data) + assert task_get.task_name == expected_task_name + + +def _get_data_with_task_name(task_id: str, task_name: str) -> dict: + return { + "task_id": task_id, + "task_name": task_name, + "status_href": "", + "result_href": "", + "abort_href": "", + } + + +@pytest.mark.parametrize( + 
"data, expected_task_name", + [ + (_get_data_with_task_name("a.b.c.d", "a_name"), "a_name"), + (_get_data_with_task_name("a.b.c", "a_name"), "a_name"), + (_get_data_with_task_name("a.b", "a_name"), "a_name"), + (_get_data_with_task_name("a", "a_name"), "a_name"), + ], +) +def test_task_name_is_provided(data: dict, expected_task_name: str) -> None: + task_get = TaskGet(**data) + assert task_get.task_name == expected_task_name + + task_get = TypeAdapter(TaskGet).validate_python(data) + assert task_get.task_name == expected_task_name diff --git a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py index 513203f6a1e..ec036dcb1c5 100644 --- a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py +++ b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from servicelib.aiohttp import status -from ...long_running_tasks import http_endpoint_responses +from ...long_running_tasks import lrt_api from ...long_running_tasks.models import TaskGet, TaskId, TaskStatus from ..requests_validation import parse_request_path_parameters_as from ..rest_responses import create_data_response @@ -24,12 +24,11 @@ async def list_tasks(request: web.Request) -> web.Response: [ TaskGet( task_id=t.task_id, - task_name=t.task_name, status_href=f"{request.app.router['get_task_status'].url_for(task_id=t.task_id)}", result_href=f"{request.app.router['get_task_result'].url_for(task_id=t.task_id)}", abort_href=f"{request.app.router['cancel_and_delete_task'].url_for(task_id=t.task_id)}", ) - for t in http_endpoint_responses.list_tasks( + for t in lrt_api.list_tasks( long_running_manager.tasks_manager, long_running_manager.get_task_context(request), ) @@ -42,7 +41,7 @@ async def get_task_status(request: web.Request) -> web.Response: path_params = parse_request_path_parameters_as(_PathParam, request) 
long_running_manager = get_long_running_manager(request.app) - task_status: TaskStatus = http_endpoint_responses.get_task_status( + task_status: TaskStatus = lrt_api.get_task_status( long_running_manager.tasks_manager, long_running_manager.get_task_context(request), path_params.task_id, @@ -56,7 +55,7 @@ async def get_task_result(request: web.Request) -> web.Response | Any: long_running_manager = get_long_running_manager(request.app) # NOTE: this might raise an exception that will be catched by the _error_handlers - return await http_endpoint_responses.get_task_result( + return await lrt_api.get_task_result( long_running_manager.tasks_manager, long_running_manager.get_task_context(request), path_params.task_id, @@ -68,7 +67,7 @@ async def cancel_and_delete_task(request: web.Request) -> web.Response: path_params = parse_request_path_parameters_as(_PathParam, request) long_running_manager = get_long_running_manager(request.app) - await http_endpoint_responses.remove_task( + await lrt_api.remove_task( long_running_manager.tasks_manager, long_running_manager.get_task_context(request), path_params.task_id, diff --git a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py index a68bb4a67b0..e8b2efad1dc 100644 --- a/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py +++ b/packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py @@ -9,12 +9,13 @@ from pydantic import AnyHttpUrl, TypeAdapter from ...aiohttp import status +from ...long_running_tasks import lrt_api from ...long_running_tasks.constants import ( DEFAULT_STALE_TASK_CHECK_INTERVAL, DEFAULT_STALE_TASK_DETECT_TIMEOUT, ) from ...long_running_tasks.models import TaskGet -from ...long_running_tasks.task import TaskContext, TaskProtocol, start_task +from ...long_running_tasks.task import RegisteredTaskName, TaskContext from ..typing_extension import Handler from . 
import _routes from ._constants import ( @@ -25,11 +26,11 @@ from ._manager import AiohttpLongRunningManager, get_long_running_manager -def no_ops_decorator(handler: Handler): +def _no_ops_decorator(handler: Handler): return handler -def no_task_context_decorator(handler: Handler): +def _no_task_context_decorator(handler: Handler): @wraps(handler) async def _wrap(request: web.Request): request[RQT_LONG_RUNNING_TASKS_CONTEXT_KEY] = {} @@ -45,7 +46,7 @@ def _create_task_name_from_request(request: web.Request) -> str: async def start_long_running_task( # NOTE: positional argument are suffixed with "_" to avoid name conflicts with "task_kwargs" keys request_: web.Request, - task_: TaskProtocol, + registerd_task_name: RegisteredTaskName, *, fire_and_forget: bool = False, task_context: TaskContext, @@ -55,9 +56,9 @@ async def start_long_running_task( task_name = _create_task_name_from_request(request_) task_id = None try: - task_id = start_task( + task_id = await lrt_api.start_task( long_running_manager.tasks_manager, - task_, + registerd_task_name, fire_and_forget=fire_and_forget, task_context=task_context, task_name=task_name, @@ -78,7 +79,6 @@ async def start_long_running_task( ) task_get = TaskGet( task_id=task_id, - task_name=task_name, status_href=f"{status_url}", result_href=f"{result_url}", abort_href=f"{abort_url}", @@ -121,8 +121,8 @@ def setup( app: web.Application, *, router_prefix: str, - handler_check_decorator: Callable = no_ops_decorator, - task_request_context_decorator: Callable = no_task_context_decorator, + handler_check_decorator: Callable = _no_ops_decorator, + task_request_context_decorator: Callable = _no_task_context_decorator, stale_task_check_interval: datetime.timedelta = DEFAULT_STALE_TASK_CHECK_INTERVAL, stale_task_detect_timeout: datetime.timedelta = DEFAULT_STALE_TASK_DETECT_TIMEOUT, ) -> None: diff --git a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py 
b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py index 8b474c8add9..5f8a5e0a633 100644 --- a/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py +++ b/packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, Request, status -from ...long_running_tasks import http_endpoint_responses +from ...long_running_tasks import lrt_api from ...long_running_tasks.models import TaskGet, TaskId, TaskResult, TaskStatus from ..requests_decorators import cancel_on_disconnect from ._dependencies import get_long_running_manager @@ -23,14 +23,13 @@ async def list_tasks( return [ TaskGet( task_id=t.task_id, - task_name=t.task_name, status_href=str(request.url_for("get_task_status", task_id=t.task_id)), result_href=str(request.url_for("get_task_result", task_id=t.task_id)), abort_href=str( request.url_for("cancel_and_delete_task", task_id=t.task_id) ), ) - for t in http_endpoint_responses.list_tasks( + for t in lrt_api.list_tasks( long_running_manager.tasks_manager, task_context=None ) ] @@ -52,7 +51,7 @@ async def get_task_status( ], ) -> TaskStatus: assert request # nosec - return http_endpoint_responses.get_task_status( + return lrt_api.get_task_status( long_running_manager.tasks_manager, task_context=None, task_id=task_id ) @@ -75,7 +74,7 @@ async def get_task_result( ], ) -> TaskResult | Any: assert request # nosec - return await http_endpoint_responses.get_task_result( + return await lrt_api.get_task_result( long_running_manager.tasks_manager, task_context=None, task_id=task_id ) @@ -98,6 +97,6 @@ async def cancel_and_delete_task( ], ) -> None: assert request # nosec - await http_endpoint_responses.remove_task( + await lrt_api.remove_task( long_running_manager.tasks_manager, task_context=None, task_id=task_id ) diff --git a/packages/service-library/src/servicelib/logging_errors.py b/packages/service-library/src/servicelib/logging_errors.py index 
3099aa9b863..938a4c3f62d 100644 --- a/packages/service-library/src/servicelib/logging_errors.py +++ b/packages/service-library/src/servicelib/logging_errors.py @@ -74,7 +74,7 @@ def create_troubleshootting_log_kwargs( ... except MyException as exc _logger.exception( - **create_troubleshotting_log_kwargs( + **create_troubleshootting_log_kwargs( user_error_msg=frontend_msg, error=exc, error_context={ diff --git a/packages/service-library/src/servicelib/long_running_tasks/errors.py b/packages/service-library/src/servicelib/long_running_tasks/errors.py index 33439c6436f..75e46da5b0c 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/errors.py +++ b/packages/service-library/src/servicelib/long_running_tasks/errors.py @@ -5,6 +5,13 @@ class BaseLongRunningError(OsparcErrorMixin, Exception): """base exception for this module""" +class TaskNotRegisteredError(BaseLongRunningError): + msg_template: str = ( + "no task with task_name='{task_name}' was found in the task registry. " + "Make sure it's registered before starting it." 
+ ) + + class TaskAlreadyRunningError(BaseLongRunningError): msg_template: str = "{task_name} must be unique, found: '{managed_task}'" diff --git a/packages/service-library/src/servicelib/long_running_tasks/http_endpoint_responses.py b/packages/service-library/src/servicelib/long_running_tasks/http_endpoint_responses.py deleted file mode 100644 index 81ed78a2354..00000000000 --- a/packages/service-library/src/servicelib/long_running_tasks/http_endpoint_responses.py +++ /dev/null @@ -1,63 +0,0 @@ -import logging -from typing import Any - -from common_library.error_codes import create_error_code -from servicelib.logging_errors import create_troubleshootting_log_kwargs - -from .errors import TaskNotCompletedError, TaskNotFoundError -from .models import TaskBase, TaskId, TaskStatus -from .task import TaskContext, TasksManager, TrackedTask - -_logger = logging.getLogger(__name__) - - -def list_tasks( - tasks_manager: TasksManager, task_context: TaskContext | None -) -> list[TaskBase]: - tracked_tasks: list[TrackedTask] = tasks_manager.list_tasks( - with_task_context=task_context - ) - return [TaskBase(task_id=t.task_id, task_name=t.task_name) for t in tracked_tasks] - - -def get_task_status( - tasks_manager: TasksManager, task_context: TaskContext | None, task_id: TaskId -) -> TaskStatus: - return tasks_manager.get_task_status( - task_id=task_id, with_task_context=task_context - ) - - -async def get_task_result( - tasks_manager: TasksManager, task_context: TaskContext | None, task_id: TaskId -) -> Any: - try: - task_result = tasks_manager.get_task_result( - task_id, with_task_context=task_context - ) - await tasks_manager.remove_task( - task_id, with_task_context=task_context, reraise_errors=False - ) - return task_result - except (TaskNotFoundError, TaskNotCompletedError): - raise - except Exception as exc: - _logger.exception( - **create_troubleshootting_log_kwargs( - user_error_msg=f"{task_id=} raised an exception while getting its result", - error=exc, - 
error_code=create_error_code(exc), - error_context={"task_context": task_context, "task_id": task_id}, - ), - ) - # the task shall be removed in this case - await tasks_manager.remove_task( - task_id, with_task_context=task_context, reraise_errors=False - ) - raise - - -async def remove_task( - tasks_manager: TasksManager, task_context: TaskContext | None, task_id: TaskId -) -> None: - await tasks_manager.remove_task(task_id, with_task_context=task_context) diff --git a/packages/service-library/src/servicelib/long_running_tasks/lrt_api.py b/packages/service-library/src/servicelib/long_running_tasks/lrt_api.py new file mode 100644 index 00000000000..5e8c9191f7e --- /dev/null +++ b/packages/service-library/src/servicelib/long_running_tasks/lrt_api.py @@ -0,0 +1,107 @@ +import logging +from typing import Any + +from common_library.error_codes import create_error_code +from servicelib.logging_errors import create_troubleshootting_log_kwargs + +from .errors import TaskNotCompletedError, TaskNotFoundError +from .models import TaskBase, TaskId, TaskStatus +from .task import RegisteredTaskName, TaskContext, TasksManager + +_logger = logging.getLogger(__name__) + + +async def start_task( + tasks_manager: TasksManager, + registered_task_name: RegisteredTaskName, + *, + unique: bool = False, + task_context: TaskContext | None = None, + task_name: str | None = None, + fire_and_forget: bool = False, + **task_kwargs: Any, +) -> TaskId: + """ + Creates a background task from an async function. + + An asyncio task will be created out of it by injecting a `TaskProgress` as the first + positional argument and adding all `handler_kwargs` as named parameters. + + NOTE: the progress is automatically bounded between 0 and 1 + NOTE: the `task` name must be unique in the module, otherwise when using + the unique parameter is True, it will not be able to distinguish between + them. 
+ + Args: + tasks_manager (TasksManager): the tasks manager + task (TaskProtocol): the tasks to be run in the background + unique (bool, optional): If True, then only one such named task may be run. Defaults to False. + task_context (Optional[TaskContext], optional): a task context storage can be retrieved during the task lifetime. Defaults to None. + task_name (Optional[str], optional): optional task name. Defaults to None. + fire_and_forget: if True, then the task will not be cancelled if the status is never called + + Raises: + TaskAlreadyRunningError: if unique is True, will raise if more than 1 such named task is started + + Returns: + TaskId: the task unique identifier + """ + return tasks_manager.start_task( + registered_task_name, + unique=unique, + task_context=task_context, + task_name=task_name, + fire_and_forget=fire_and_forget, + **task_kwargs, + ) + + +def list_tasks( + tasks_manager: TasksManager, task_context: TaskContext | None +) -> list[TaskBase]: + return tasks_manager.list_tasks(with_task_context=task_context) + + +def get_task_status( + tasks_manager: TasksManager, task_context: TaskContext | None, task_id: TaskId +) -> TaskStatus: + """returns the status of a task""" + return tasks_manager.get_task_status( + task_id=task_id, with_task_context=task_context + ) + + +async def get_task_result( + tasks_manager: TasksManager, task_context: TaskContext | None, task_id: TaskId +) -> Any: + try: + task_result = tasks_manager.get_task_result( + task_id, with_task_context=task_context + ) + await tasks_manager.remove_task( + task_id, with_task_context=task_context, reraise_errors=False + ) + return task_result + except (TaskNotFoundError, TaskNotCompletedError): + raise + except Exception as exc: + _logger.exception( + **create_troubleshootting_log_kwargs( + user_error_msg=f"{task_id=} raised an exception while getting its result", + error=exc, + error_code=create_error_code(exc), + error_context={"task_context": task_context, "task_id": task_id}, + ), 
+ ) + # the task shall be removed in this case + await tasks_manager.remove_task( + task_id, with_task_context=task_context, reraise_errors=False + ) + raise + + +async def remove_task( + tasks_manager: TasksManager, task_context: TaskContext | None, task_id: TaskId +) -> None: + """removes / cancels a task""" + await tasks_manager.remove_task(task_id, with_task_context=task_context) diff --git a/packages/service-library/src/servicelib/long_running_tasks/models.py b/packages/service-library/src/servicelib/long_running_tasks/models.py index 37fc968568d..15ab97515af 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/models.py +++ b/packages/service-library/src/servicelib/long_running_tasks/models.py @@ -19,8 +19,6 @@ ) from pydantic import BaseModel, ConfigDict, Field, PositiveFloat -TaskName: TypeAlias = str - TaskType: TypeAlias = Callable[..., Coroutine[Any, Any, Any]] ProgressCallback: TypeAlias = Callable[ @@ -33,7 +31,6 @@ class TrackedTask(BaseModel): task_id: str task: Task - task_name: TaskName task_progress: TaskProgress # NOTE: this context lifetime is with the tracked task (similar to aiohttp storage concept) task_context: dict[str, Any] diff --git a/packages/service-library/src/servicelib/long_running_tasks/task.py b/packages/service-library/src/servicelib/long_running_tasks/task.py index a6007e2059a..7c6039f7b89 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/task.py +++ b/packages/service-library/src/servicelib/long_running_tasks/task.py @@ -6,7 +6,7 @@ import urllib.parse from collections import deque from contextlib import suppress -from typing import Any, Final, Protocol, TypeAlias +from typing import Any, ClassVar, Final, Protocol, TypeAlias from uuid import uuid4 from common_library.async_tools import cancel_wait_task @@ -21,47 +21,75 @@ TaskExceptionError, TaskNotCompletedError, TaskNotFoundError, + TaskNotRegisteredError, ) -from .models import TaskId, TaskName, TaskStatus, TrackedTask +from .models 
import TaskBase, TaskId, TaskStatus, TrackedTask _logger = logging.getLogger(__name__) + +# NOTE: for now only this one is used, in future it will be unqiue per service +# and this default will be removed and become mandatory +_DEFAULT_NAMESPACE: Final[str] = "lrt" + _CANCEL_TASK_TIMEOUT: Final[PositiveFloat] = datetime.timedelta( seconds=1 ).total_seconds() +RegisteredTaskName: TypeAlias = str +Namespace: TypeAlias = str TrackedTaskGroupDict: TypeAlias = dict[TaskId, TrackedTask] TaskContext: TypeAlias = dict[str, Any] +class TaskProtocol(Protocol): + async def __call__( + self, progress: TaskProgress, *args: Any, **kwargs: Any + ) -> Any: ... + + @property + def __name__(self) -> str: ... + + +class TaskRegistry: + REGISTERED_TASKS: ClassVar[dict[RegisteredTaskName, TaskProtocol]] = {} + + @classmethod + def register(cls, task: TaskProtocol) -> None: + cls.REGISTERED_TASKS[task.__name__] = task + + @classmethod + def unregister(cls, task: TaskProtocol) -> None: + if task.__name__ in cls.REGISTERED_TASKS: + del cls.REGISTERED_TASKS[task.__name__] + + async def _await_task(task: asyncio.Task) -> None: await task def _get_tasks_to_remove( - tasks_groups: dict[TaskName, TrackedTaskGroupDict], + tracked_tasks: TrackedTaskGroupDict, stale_task_detect_timeout_s: PositiveFloat, ) -> list[TaskId]: utc_now = datetime.datetime.now(tz=datetime.UTC) tasks_to_remove: list[TaskId] = [] - for tasks in tasks_groups.values(): - for task_id, tracked_task in tasks.items(): - if tracked_task.fire_and_forget: - continue - - if tracked_task.last_status_check is None: - # the task just added or never received a poll request - elapsed_from_start = (utc_now - tracked_task.started).seconds - if elapsed_from_start > stale_task_detect_timeout_s: - tasks_to_remove.append(task_id) - else: - # the task status was already queried by the client - elapsed_from_last_poll = ( - utc_now - tracked_task.last_status_check - ).seconds - if elapsed_from_last_poll > stale_task_detect_timeout_s: - 
tasks_to_remove.append(task_id) + + for task_id, tracked_task in tracked_tasks.items(): + if tracked_task.fire_and_forget: + continue + + if tracked_task.last_status_check is None: + # the task just added or never received a poll request + elapsed_from_start = (utc_now - tracked_task.started).seconds + if elapsed_from_start > stale_task_detect_timeout_s: + tasks_to_remove.append(task_id) + else: + # the task status was already queried by the client + elapsed_from_last_poll = (utc_now - tracked_task.last_status_check).seconds + if elapsed_from_last_poll > stale_task_detect_timeout_s: + tasks_to_remove.append(task_id) return tasks_to_remove @@ -74,9 +102,11 @@ def __init__( self, stale_task_check_interval: datetime.timedelta, stale_task_detect_timeout: datetime.timedelta, + namespace: Namespace = _DEFAULT_NAMESPACE, ): + self.namespace = namespace # Task groups: Every taskname maps to multiple asyncio.Task within TrackedTask model - self._tasks_groups: dict[TaskName, TrackedTaskGroupDict] = {} + self._tracked_tasks: TrackedTaskGroupDict = {} self.stale_task_check_interval = stale_task_check_interval self.stale_task_detect_timeout_s: PositiveFloat = ( @@ -95,9 +125,8 @@ async def setup(self) -> None: async def teardown(self) -> None: task_ids_to_remove: deque[TaskId] = deque() - for tasks_dict in self._tasks_groups.values(): - for tracked_task in tasks_dict.values(): - task_ids_to_remove.append(tracked_task.task_id) + for tracked_task in self._tracked_tasks.values(): + task_ids_to_remove.append(tracked_task.task_id) for task_id in task_ids_to_remove: # when closing we do not care about pending errors @@ -109,9 +138,6 @@ async def teardown(self) -> None: self._stale_tasks_monitor_task, max_delay=_CANCEL_TASK_TIMEOUT ) - def get_task_group(self, task_name: TaskName) -> TrackedTaskGroupDict: - return self._tasks_groups[task_name] - async def _stale_tasks_monitor_worker(self) -> None: """ A task is considered stale, if the task status is not queried @@ -130,7 +156,7 @@ 
async def _stale_tasks_monitor_worker(self) -> None: # will not be the case. tasks_to_remove = _get_tasks_to_remove( - self._tasks_groups, self.stale_task_detect_timeout_s + self._tracked_tasks, self.stale_task_detect_timeout_s ) # finally remove tasks and warn @@ -149,37 +175,20 @@ async def _stale_tasks_monitor_worker(self) -> None: task_id, with_task_context=None, reraise_errors=False ) - @staticmethod - def create_task_id(task_name: TaskName) -> str: - assert len(task_name) > 0 - return f"{task_name}.{uuid4()}" - - def is_task_running(self, task_name: TaskName) -> bool: - """returns True if a task named `task_name` is running""" - if task_name not in self._tasks_groups: - return False - - managed_tasks_ids = list(self._tasks_groups[task_name].keys()) - return len(managed_tasks_ids) > 0 - - def list_tasks(self, with_task_context: TaskContext | None) -> list[TrackedTask]: - tasks: list[TrackedTask] = [] - for task_group in self._tasks_groups.values(): - if not with_task_context: - tasks.extend(task_group.values()) - else: - tasks.extend( - [ - task - for task in task_group.values() - if task.task_context == with_task_context - ] - ) - return tasks + def list_tasks(self, with_task_context: TaskContext | None) -> list[TaskBase]: + if not with_task_context: + return [ + TaskBase(task_id=task.task_id) for task in self._tracked_tasks.values() + ] - def add_task( + return [ + TaskBase(task_id=task.task_id) + for task in self._tracked_tasks.values() + if task.task_context == with_task_context + ] + + def _add_task( self, - task_name: TaskName, task: asyncio.Task, task_progress: TaskProgress, task_context: TaskContext, @@ -187,33 +196,30 @@ def add_task( *, fire_and_forget: bool, ) -> TrackedTask: - if task_name not in self._tasks_groups: - self._tasks_groups[task_name] = {} tracked_task = TrackedTask( task_id=task_id, task=task, - task_name=task_name, task_progress=task_progress, task_context=task_context, fire_and_forget=fire_and_forget, ) - 
self._tasks_groups[task_name][task_id] = tracked_task + self._tracked_tasks[task_id] = tracked_task return tracked_task def _get_tracked_task( self, task_id: TaskId, with_task_context: TaskContext | None ) -> TrackedTask: - for tasks in self._tasks_groups.values(): - if task_id in tasks: - if with_task_context and ( - tasks[task_id].task_context != with_task_context - ): - raise TaskNotFoundError(task_id=task_id) - return tasks[task_id] + if task_id not in self._tracked_tasks: + raise TaskNotFoundError(task_id=task_id) + + task = self._tracked_tasks[task_id] - raise TaskNotFoundError(task_id=task_id) + if with_task_context and task.task_context != with_task_context: + raise TaskNotFoundError(task_id=task_id) + + return task def get_task_status( self, task_id: TaskId, with_task_context: TaskContext | None @@ -321,95 +327,65 @@ async def remove_task( tracked_task.task, task_id, reraise_errors=reraise_errors ) finally: - del self._tasks_groups[tracked_task.task_name][task_id] - - -class TaskProtocol(Protocol): - async def __call__( - self, progress: TaskProgress, *args: Any, **kwargs: Any - ) -> Any: ... - - @property - def __name__(self) -> str: ... + del self._tracked_tasks[task_id] + def _get_task_id(self, task_name: str, *, is_unique: bool) -> TaskId: + unique_part = "unique" if is_unique else f"{uuid4()}" + return f"{self.namespace}.{task_name}.{unique_part}" -def start_task( - tasks_manager: TasksManager, - task: TaskProtocol, - *, - unique: bool = False, - task_context: TaskContext | None = None, - task_name: str | None = None, - fire_and_forget: bool = False, - **task_kwargs: Any, -) -> TaskId: - """ - Creates a background task from an async function. - - An asyncio task will be created out of it by injecting a `TaskProgress` as the first - positional argument and adding all `handler_kwargs` as named parameters. 
- - NOTE: the progress is automatically bounded between 0 and 1 - NOTE: the `task` name must be unique in the module, otherwise when using - the unique parameter is True, it will not be able to distinguish between - them. - - Args: - tasks_manager (TasksManager): the tasks manager - task (TaskProtocol): the tasks to be run in the background - unique (bool, optional): If True, then only one such named task may be run. Defaults to False. - task_context (Optional[TaskContext], optional): a task context storage can be retrieved during the task lifetime. Defaults to None. - task_name (Optional[str], optional): optional task name. Defaults to None. - fire_and_forget: if True, then the task will not be cancelled if the status is never called + def start_task( + self, + registered_task_name: RegisteredTaskName, + *, + unique: bool, + task_context: TaskContext | None, + task_name: str | None, + fire_and_forget: bool, + **task_kwargs: Any, + ) -> TaskId: + if registered_task_name not in TaskRegistry.REGISTERED_TASKS: + raise TaskNotRegisteredError(task_name=registered_task_name) + + task = TaskRegistry.REGISTERED_TASKS[registered_task_name] + + # NOTE: If not task name is given, it will be composed of the handler's module and it's name + # to keep the urls shorter and more meaningful. 
+ handler_module = inspect.getmodule(task) + handler_module_name = handler_module.__name__ if handler_module else "" + task_name = task_name or f"{handler_module_name}.{task.__name__}" + task_name = urllib.parse.quote(task_name, safe="") + + task_id = self._get_task_id(task_name, is_unique=unique) + + # only one unique task can be running + if unique and task_id in self._tracked_tasks: + raise TaskAlreadyRunningError( + task_name=task_name, managed_task=self._tracked_tasks[task_id] + ) - Raises: - TaskAlreadyRunningError: if unique is True, will raise if more than 1 such named task is started + task_progress = TaskProgress.create(task_id=task_id) - Returns: - TaskId: the task unique identifier - """ + # bind the task with progress 0 and 1 + async def _progress_task(progress: TaskProgress, handler: TaskProtocol): + progress.update(message="starting", percent=0) + try: + return await handler(progress, **task_kwargs) + finally: + progress.update(message="finished", percent=1) - # NOTE: If not task name is given, it will be composed of the handler's module and it's name - # to keep the urls shorter and more meaningful. 
- handler_module = inspect.getmodule(task) - handler_module_name = handler_module.__name__ if handler_module else "" - task_name = task_name or f"{handler_module_name}.{task.__name__}" - task_name = urllib.parse.quote(task_name, safe="") - - # only one unique task can be running - if unique and tasks_manager.is_task_running(task_name): - managed_tasks_ids = list(tasks_manager.get_task_group(task_name).keys()) - assert len(managed_tasks_ids) == 1 # nosec - managed_task: TrackedTask = tasks_manager.get_task_group(task_name)[ - managed_tasks_ids[0] - ] - raise TaskAlreadyRunningError(task_name=task_name, managed_task=managed_task) + async_task = asyncio.create_task( + _progress_task(task_progress, task), name=task_name + ) - task_id = tasks_manager.create_task_id(task_name=task_name) - task_progress = TaskProgress.create(task_id=task_id) + tracked_task = self._add_task( + task=async_task, + task_progress=task_progress, + task_context=task_context or {}, + fire_and_forget=fire_and_forget, + task_id=task_id, + ) - # bind the task with progress 0 and 1 - async def _progress_task(progress: TaskProgress, handler: TaskProtocol): - progress.update(message="starting", percent=0) - try: - return await handler(progress, **task_kwargs) - finally: - progress.update(message="finished", percent=1) - - async_task = asyncio.create_task( - _progress_task(task_progress, task), name=f"{task_name}" - ) - - tracked_task = tasks_manager.add_task( - task_name=task_name, - task=async_task, - task_progress=task_progress, - task_context=task_context or {}, - fire_and_forget=fire_and_forget, - task_id=task_id, - ) - - return tracked_task.task_id + return tracked_task.task_id __all__: tuple[str, ...] 
= ( diff --git a/packages/service-library/tests/aiohttp/long_running_tasks/conftest.py b/packages/service-library/tests/aiohttp/long_running_tasks/conftest.py index 4f39c80be08..bd8d77da3ce 100644 --- a/packages/service-library/tests/aiohttp/long_running_tasks/conftest.py +++ b/packages/service-library/tests/aiohttp/long_running_tasks/conftest.py @@ -19,7 +19,7 @@ TaskProgress, TaskStatus, ) -from servicelib.long_running_tasks.task import TaskContext +from servicelib.long_running_tasks.task import TaskContext, TaskRegistry from tenacity.asyncio import AsyncRetrying from tenacity.retry import retry_if_exception_type from tenacity.stop import stop_after_delay @@ -27,7 +27,7 @@ async def _string_list_task( - task_progress: TaskProgress, + progress: TaskProgress, num_strings: int, sleep_time: float, fail: bool, @@ -36,7 +36,7 @@ async def _string_list_task( for index in range(num_strings): generated_strings.append(f"{index}") await asyncio.sleep(sleep_time) - task_progress.update(message="generated item", percent=index / num_strings) + progress.update(message="generated item", percent=index / num_strings) if fail: msg = "We were asked to fail!!" 
raise RuntimeError(msg) @@ -47,6 +47,9 @@ async def _string_list_task( ) +TaskRegistry.register(_string_list_task) + + @pytest.fixture def task_context(faker: Faker) -> TaskContext: return {"user_id": faker.pyint(), "product": faker.pystr()} @@ -73,7 +76,7 @@ async def generate_list_strings(request: web.Request) -> web.Response: query_params = parse_request_query_parameters_as(_LongTaskQueryParams, request) return await long_running_tasks.server.start_long_running_task( request, - _string_list_task, + _string_list_task.__name__, num_strings=query_params.num_strings, sleep_time=query_params.sleep_time, fail=query_params.fail, diff --git a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks.py b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks.py index 939ba7e330a..b904d766d10 100644 --- a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks.py +++ b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks.py @@ -97,7 +97,7 @@ async def test_workflow( # now get the result result_url = client.app.router["get_task_result"].url_for(task_id=task_id) result = await client.get(f"{result_url}") - task_result, error = await assert_status(result, status.HTTP_201_CREATED) + task_result, error = await assert_status(result, status.HTTP_200_OK) assert task_result assert not error assert task_result == [f"{x}" for x in range(10)] @@ -224,7 +224,9 @@ async def test_list_tasks( # the task name is properly formatted assert all( - task.task_name == "POST /long_running_task:start?num_strings=10&sleep_time=0.2" + task.task_name.startswith( + "POST /long_running_task:start?num_strings=10&sleep_time=" + ) for task in list_of_tasks ) # now wait for them to finish diff --git a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_with_task_context.py 
b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_with_task_context.py index 20f8dbf657f..8f7ff5efd23 100644 --- a/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_with_task_context.py +++ b/packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_with_task_context.py @@ -151,7 +151,7 @@ async def test_get_task_result( await assert_status(resp, status.HTTP_404_NOT_FOUND) # calling with context should find the task resp = await client_with_task_context.get(f"{result_url.with_query(task_context)}") - await assert_status(resp, status.HTTP_201_CREATED) + await assert_status(resp, status.HTTP_200_OK) async def test_cancel_task( diff --git a/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks.py b/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks.py index 07809f928b0..5ebbdb744e0 100644 --- a/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks.py +++ b/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks.py @@ -12,7 +12,7 @@ import asyncio import json from collections.abc import AsyncIterator, Awaitable, Callable -from typing import Final +from typing import Annotated, Final import pytest from asgi_lifespan import LifespanManager @@ -25,13 +25,14 @@ get_long_running_manager, ) from servicelib.fastapi.long_running_tasks.server import setup as setup_server +from servicelib.long_running_tasks import lrt_api from servicelib.long_running_tasks.models import ( TaskGet, TaskId, TaskProgress, TaskStatus, ) -from servicelib.long_running_tasks.task import TaskContext, start_task +from servicelib.long_running_tasks.task import TaskContext, TaskRegistry from tenacity.asyncio import AsyncRetrying from tenacity.retry import retry_if_exception_type from tenacity.stop import stop_after_delay @@ -42,7 +43,7 @@ async def _string_list_task( - task_progress: TaskProgress, + 
progress: TaskProgress, num_strings: int, sleep_time: float, fail: bool, @@ -51,7 +52,7 @@ async def _string_list_task( for index in range(num_strings): generated_strings.append(f"{index}") await asyncio.sleep(sleep_time) - task_progress.update(message="generated item", percent=index / num_strings) + progress.update(message="generated item", percent=index / num_strings) if fail: msg = "We were asked to fail!!" raise RuntimeError(msg) @@ -59,6 +60,9 @@ async def _string_list_task( return generated_strings +TaskRegistry.register(_string_list_task) + + @pytest.fixture def server_routes() -> APIRouter: routes = APIRouter() @@ -69,19 +73,19 @@ def server_routes() -> APIRouter: async def create_string_list_task( num_strings: int, sleep_time: float, + long_running_manager: Annotated[ + FastAPILongRunningManager, Depends(get_long_running_manager) + ], + *, fail: bool = False, - long_running_manager: FastAPILongRunningManager = Depends( - get_long_running_manager - ), ) -> TaskId: - task_id = start_task( + return await lrt_api.start_task( long_running_manager.tasks_manager, - _string_list_task, + _string_list_task.__name__, num_strings=num_strings, sleep_time=sleep_time, fail=fail, ) - return task_id return routes diff --git a/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks_context_manager.py b/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks_context_manager.py index 32d884d1ff8..85b59eb1a35 100644 --- a/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks_context_manager.py +++ b/packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks_context_manager.py @@ -16,6 +16,7 @@ from servicelib.fastapi.long_running_tasks.client import setup as setup_client from servicelib.fastapi.long_running_tasks.server import get_long_running_manager from servicelib.fastapi.long_running_tasks.server import setup as setup_server +from servicelib.long_running_tasks import 
lrt_api from servicelib.long_running_tasks.errors import ( TaskClientTimeoutError, TaskExceptionError, @@ -26,7 +27,7 @@ TaskId, TaskProgress, ) -from servicelib.long_running_tasks.task import start_task +from servicelib.long_running_tasks.task import TaskRegistry TASK_SLEEP_INTERVAL: Final[PositiveFloat] = 0.1 @@ -40,17 +41,25 @@ async def _assert_task_removed( assert result.status_code == status.HTTP_404_NOT_FOUND -async def a_test_task(task_progress: TaskProgress) -> int: +async def a_test_task(progress: TaskProgress) -> int: + _ = progress await asyncio.sleep(TASK_SLEEP_INTERVAL) return 42 -async def a_failing_test_task(task_progress: TaskProgress) -> None: +TaskRegistry.register(a_test_task) + + +async def a_failing_test_task(progress: TaskProgress) -> None: + _ = progress await asyncio.sleep(TASK_SLEEP_INTERVAL) msg = "I am failing as requested" raise RuntimeError(msg) +TaskRegistry.register(a_failing_test_task) + + @pytest.fixture def user_routes() -> APIRouter: router = APIRouter() @@ -61,7 +70,9 @@ async def create_task_user_defined_route( FastAPILongRunningManager, Depends(get_long_running_manager) ], ) -> TaskId: - return start_task(long_running_manager.tasks_manager, task=a_test_task) + return await lrt_api.start_task( + long_running_manager.tasks_manager, a_test_task.__name__ + ) @router.get("/api/failing", status_code=status.HTTP_200_OK) async def create_task_which_fails( @@ -69,7 +80,9 @@ async def create_task_which_fails( FastAPILongRunningManager, Depends(get_long_running_manager) ], ) -> TaskId: - return start_task(long_running_manager.tasks_manager, task=a_failing_test_task) + return await lrt_api.start_task( + long_running_manager.tasks_manager, a_failing_test_task.__name__ + ) return router diff --git a/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py b/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py index 6e3ec1522e2..566be002808 100644 --- 
a/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py +++ b/packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py @@ -12,18 +12,16 @@ import pytest from faker import Faker +from servicelib.long_running_tasks import lrt_api from servicelib.long_running_tasks.errors import ( TaskAlreadyRunningError, TaskCancelledError, TaskNotCompletedError, TaskNotFoundError, + TaskNotRegisteredError, ) -from servicelib.long_running_tasks.models import ( - ProgressPercent, - TaskProgress, - TaskStatus, -) -from servicelib.long_running_tasks.task import TasksManager, start_task +from servicelib.long_running_tasks.models import TaskProgress, TaskStatus +from servicelib.long_running_tasks.task import TaskRegistry, TasksManager from tenacity import TryAgain from tenacity.asyncio import AsyncRetrying from tenacity.retry import retry_if_exception_type @@ -39,14 +37,14 @@ async def a_background_task( - task_progress: TaskProgress, + progress: TaskProgress, raise_when_finished: bool, total_sleep: int, ) -> int: """sleeps and raises an error or returns 42""" for i in range(total_sleep): await asyncio.sleep(1) - task_progress.update(percent=ProgressPercent((i + 1) / total_sleep)) + progress.update(percent=(i + 1) / total_sleep) if raise_when_finished: msg = "raised this error as instructed" raise RuntimeError(msg) @@ -54,17 +52,21 @@ async def a_background_task( return 42 -async def fast_background_task(task_progress: TaskProgress) -> int: +async def fast_background_task(progress: TaskProgress) -> int: """this task does nothing and returns a constant""" return 42 -async def failing_background_task(task_progress: TaskProgress): +async def failing_background_task(progress: TaskProgress): """this task does nothing and returns a constant""" msg = "failing asap" raise RuntimeError(msg) +TaskRegistry.register(a_background_task) +TaskRegistry.register(fast_background_task) +TaskRegistry.register(failing_background_task) + 
TEST_CHECK_STALE_INTERVAL_S: Final[float] = 1 @@ -83,9 +85,9 @@ async def tasks_manager() -> AsyncIterator[TasksManager]: async def test_task_is_auto_removed( tasks_manager: TasksManager, check_task_presence_before: bool ): - task_id = start_task( + task_id = await lrt_api.start_task( tasks_manager, - a_background_task, + a_background_task.__name__, raise_when_finished=False, total_sleep=10 * TEST_CHECK_STALE_INTERVAL_S, ) @@ -99,10 +101,9 @@ async def test_task_is_auto_removed( # meaning no calls via the manager methods are received async for attempt in AsyncRetrying(**_RETRY_PARAMS): with attempt: - for tasks in tasks_manager._tasks_groups.values(): # noqa: SLF001 - if task_id in tasks: - msg = "wait till no element is found any longer" - raise TryAgain(msg) + if task_id in tasks_manager._tracked_tasks: # noqa: SLF001 + msg = "wait till no element is found any longer" + raise TryAgain(msg) with pytest.raises(TaskNotFoundError): tasks_manager.get_task_status(task_id, with_task_context=None) @@ -111,9 +112,9 @@ async def test_task_is_auto_removed( async def test_checked_task_is_not_auto_removed(tasks_manager: TasksManager): - task_id = start_task( + task_id = await lrt_api.start_task( tasks_manager, - a_background_task, + a_background_task.__name__, raise_when_finished=False, total_sleep=5 * TEST_CHECK_STALE_INTERVAL_S, ) @@ -126,9 +127,9 @@ async def test_checked_task_is_not_auto_removed(tasks_manager: TasksManager): async def test_fire_and_forget_task_is_not_auto_removed(tasks_manager: TasksManager): - task_id = start_task( + task_id = await lrt_api.start_task( tasks_manager, - a_background_task, + a_background_task.__name__, raise_when_finished=False, total_sleep=5 * TEST_CHECK_STALE_INTERVAL_S, fire_and_forget=True, @@ -145,9 +146,9 @@ async def test_fire_and_forget_task_is_not_auto_removed(tasks_manager: TasksMana async def test_get_result_of_unfinished_task_raises(tasks_manager: TasksManager): - task_id = start_task( + task_id = await lrt_api.start_task( 
tasks_manager, - a_background_task, + a_background_task.__name__, raise_when_finished=False, total_sleep=5 * TEST_CHECK_STALE_INTERVAL_S, ) @@ -156,35 +157,45 @@ async def test_get_result_of_unfinished_task_raises(tasks_manager: TasksManager) async def test_unique_task_already_running(tasks_manager: TasksManager): - async def unique_task(task_progress: TaskProgress): + async def unique_task(progress: TaskProgress): + _ = progress await asyncio.sleep(1) - start_task(tasks_manager=tasks_manager, task=unique_task, unique=True) + TaskRegistry.register(unique_task) + + await lrt_api.start_task(tasks_manager, unique_task.__name__, unique=True) # ensure unique running task regardless of how many times it gets started with pytest.raises(TaskAlreadyRunningError) as exec_info: - start_task(tasks_manager=tasks_manager, task=unique_task, unique=True) + await lrt_api.start_task(tasks_manager, unique_task.__name__, unique=True) assert "must be unique, found: " in f"{exec_info.value}" + TaskRegistry.unregister(unique_task) + async def test_start_multiple_not_unique_tasks(tasks_manager: TasksManager): - async def not_unique_task(task_progress: TaskProgress): + async def not_unique_task(progress: TaskProgress): await asyncio.sleep(1) + TaskRegistry.register(not_unique_task) + for _ in range(5): - start_task(tasks_manager=tasks_manager, task=not_unique_task) + await lrt_api.start_task(tasks_manager, not_unique_task.__name__) + TaskRegistry.unregister(not_unique_task) -def test_get_task_id(faker): - obj1 = TasksManager.create_task_id(faker.word()) # noqa: SLF001 - obj2 = TasksManager.create_task_id(faker.word()) # noqa: SLF001 + +@pytest.mark.parametrize("is_unique", [True, False]) +def test_get_task_id(tasks_manager: TasksManager, faker: Faker, is_unique: bool): + obj1 = tasks_manager._get_task_id(faker.word(), is_unique=is_unique) # noqa: SLF001 + obj2 = tasks_manager._get_task_id(faker.word(), is_unique=is_unique) # noqa: SLF001 assert obj1 != obj2 async def 
test_get_status(tasks_manager: TasksManager): - task_id = start_task( - tasks_manager=tasks_manager, - task=a_background_task, + task_id = await lrt_api.start_task( + tasks_manager, + a_background_task.__name__, raise_when_finished=False, total_sleep=10, ) @@ -192,7 +203,7 @@ async def test_get_status(tasks_manager: TasksManager): assert isinstance(task_status, TaskStatus) assert task_status.task_progress.message == "" assert task_status.task_progress.percent == 0.0 - assert task_status.done == False + assert task_status.done is False assert isinstance(task_status.started, datetime) @@ -203,7 +214,7 @@ async def test_get_status_missing(tasks_manager: TasksManager): async def test_get_result(tasks_manager: TasksManager): - task_id = start_task(tasks_manager=tasks_manager, task=fast_background_task) + task_id = await lrt_api.start_task(tasks_manager, fast_background_task.__name__) await asyncio.sleep(0.1) result = tasks_manager.get_task_result(task_id, with_task_context=None) assert result == 42 @@ -216,7 +227,7 @@ async def test_get_result_missing(tasks_manager: TasksManager): async def test_get_result_finished_with_error(tasks_manager: TasksManager): - task_id = start_task(tasks_manager=tasks_manager, task=failing_background_task) + task_id = await lrt_api.start_task(tasks_manager, failing_background_task.__name__) # wait for result async for attempt in AsyncRetrying(**_RETRY_PARAMS): with attempt: @@ -229,9 +240,9 @@ async def test_get_result_finished_with_error(tasks_manager: TasksManager): async def test_get_result_task_was_cancelled_multiple_times( tasks_manager: TasksManager, ): - task_id = start_task( - tasks_manager=tasks_manager, - task=a_background_task, + task_id = await lrt_api.start_task( + tasks_manager, + a_background_task.__name__, raise_when_finished=False, total_sleep=10, ) @@ -245,9 +256,9 @@ async def test_get_result_task_was_cancelled_multiple_times( async def test_remove_task(tasks_manager: TasksManager): - task_id = start_task( - 
tasks_manager=tasks_manager, - task=a_background_task, + task_id = await lrt_api.start_task( + tasks_manager, + a_background_task.__name__, raise_when_finished=False, total_sleep=10, ) @@ -261,9 +272,9 @@ async def test_remove_task(tasks_manager: TasksManager): async def test_remove_task_with_task_context(tasks_manager: TasksManager): TASK_CONTEXT = {"some_context": "some_value"} - task_id = start_task( - tasks_manager=tasks_manager, - task=a_background_task, + task_id = await lrt_api.start_task( + tasks_manager, + a_background_task.__name__, raise_when_finished=False, total_sleep=10, task_context=TASK_CONTEXT, @@ -294,9 +305,9 @@ async def test_remove_unknown_task(tasks_manager: TasksManager): async def test_cancel_task_with_task_context(tasks_manager: TasksManager): TASK_CONTEXT = {"some_context": "some_value"} - task_id = start_task( - tasks_manager=tasks_manager, - task=a_background_task, + task_id = await lrt_api.start_task( + tasks_manager, + a_background_task.__name__, raise_when_finished=False, total_sleep=10, task_context=TASK_CONTEXT, @@ -321,9 +332,9 @@ async def test_list_tasks(tasks_manager: TasksManager): task_ids = [] for _ in range(NUM_TASKS): task_ids.append( # noqa: PERF401 - start_task( - tasks_manager=tasks_manager, - task=a_background_task, + await lrt_api.start_task( + tasks_manager, + a_background_task.__name__, raise_when_finished=False, total_sleep=10, ) @@ -337,22 +348,22 @@ async def test_list_tasks(tasks_manager: TasksManager): async def test_list_tasks_filtering(tasks_manager: TasksManager): - start_task( - tasks_manager=tasks_manager, - task=a_background_task, + await lrt_api.start_task( + tasks_manager, + a_background_task.__name__, raise_when_finished=False, total_sleep=10, ) - start_task( - tasks_manager=tasks_manager, - task=a_background_task, + await lrt_api.start_task( + tasks_manager, + a_background_task.__name__, raise_when_finished=False, total_sleep=10, task_context={"user_id": 213}, ) - start_task( - 
tasks_manager=tasks_manager, - task=a_background_task, + await lrt_api.start_task( + tasks_manager, + a_background_task.__name__, raise_when_finished=False, total_sleep=10, task_context={"user_id": 213, "product": "osparc"}, @@ -379,11 +390,16 @@ async def test_list_tasks_filtering(tasks_manager: TasksManager): async def test_define_task_name(tasks_manager: TasksManager, faker: Faker): task_name = faker.name() - task_id = start_task( - tasks_manager=tasks_manager, - task=a_background_task, + task_id = await lrt_api.start_task( + tasks_manager, + a_background_task.__name__, raise_when_finished=False, total_sleep=10, task_name=task_name, ) - assert task_id.startswith(urllib.parse.quote(task_name, safe="")) + assert urllib.parse.quote(task_name, safe="") in task_id + + +async def test_start_not_registered_task(tasks_manager: TasksManager): + with pytest.raises(TaskNotRegisteredError): + await lrt_api.start_task(tasks_manager, "not_registered_task") diff --git a/services/api-server/openapi.json b/services/api-server/openapi.json index 40b9594f3e5..4598a9b45d9 100644 --- a/services/api-server/openapi.json +++ b/services/api-server/openapi.json @@ -11480,7 +11480,6 @@ "type": "object", "required": [ "task_id", - "task_name", "status_href", "result_href", "abort_href" diff --git a/services/director-v2/src/simcore_service_director_v2/api/routes/dynamic_scheduler.py b/services/director-v2/src/simcore_service_director_v2/api/routes/dynamic_scheduler.py index 322ee491051..bbec37d8d45 100644 --- a/services/director-v2/src/simcore_service_director_v2/api/routes/dynamic_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/api/routes/dynamic_scheduler.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, PositiveInt from servicelib.fastapi.long_running_tasks._manager import FastAPILongRunningManager from servicelib.fastapi.long_running_tasks.server import get_long_running_manager +from servicelib.long_running_tasks import lrt_api from 
servicelib.long_running_tasks.errors import TaskAlreadyRunningError from servicelib.long_running_tasks.models import ( ProgressMessage, @@ -13,7 +14,7 @@ TaskId, TaskProgress, ) -from servicelib.long_running_tasks.task import start_task +from servicelib.long_running_tasks.task import TaskRegistry from tenacity import retry from tenacity.before_sleep import before_sleep_log from tenacity.retry import retry_if_result @@ -100,26 +101,30 @@ async def delete_service_containers( ], ): async def _task_remove_service_containers( - task_progress: TaskProgress, node_uuid: NodeID + progress: TaskProgress, node_uuid: NodeID ) -> None: async def _progress_callback( message: ProgressMessage, percent: ProgressPercent | None, _: TaskId ) -> None: - task_progress.update(message=message, percent=percent) + progress.update(message=message, percent=percent) await dynamic_sidecars_scheduler.remove_service_containers( node_uuid=node_uuid, progress_callback=_progress_callback ) + TaskRegistry.register(_task_remove_service_containers) + try: - return start_task( + return await lrt_api.start_task( long_running_manager.tasks_manager, - task=_task_remove_service_containers, # type: ignore[arg-type] + _task_remove_service_containers.__name__, unique=True, node_uuid=node_uuid, ) except TaskAlreadyRunningError as e: raise HTTPException(status.HTTP_409_CONFLICT, detail=f"{e}") from e + finally: + TaskRegistry.unregister(_task_remove_service_containers) @router.get( @@ -160,27 +165,31 @@ async def save_service_state( ], ): async def _task_save_service_state( - task_progress: TaskProgress, + progress: TaskProgress, node_uuid: NodeID, ) -> None: async def _progress_callback( message: ProgressMessage, percent: ProgressPercent | None, _: TaskId ) -> None: - task_progress.update(message=message, percent=percent) + progress.update(message=message, percent=percent) await dynamic_sidecars_scheduler.save_service_state( node_uuid=node_uuid, progress_callback=_progress_callback ) + 
TaskRegistry.register(_task_save_service_state) + try: - return start_task( + return await lrt_api.start_task( long_running_manager.tasks_manager, - task=_task_save_service_state, # type: ignore[arg-type] + _task_save_service_state.__name__, unique=True, node_uuid=node_uuid, ) except TaskAlreadyRunningError as e: raise HTTPException(status.HTTP_409_CONFLICT, detail=f"{e}") from e + finally: + TaskRegistry.unregister(_task_save_service_state) @router.post( @@ -204,26 +213,30 @@ async def push_service_outputs( ], ): async def _task_push_service_outputs( - task_progress: TaskProgress, node_uuid: NodeID + progress: TaskProgress, node_uuid: NodeID ) -> None: async def _progress_callback( message: ProgressMessage, percent: ProgressPercent | None, _: TaskId ) -> None: - task_progress.update(message=message, percent=percent) + progress.update(message=message, percent=percent) await dynamic_sidecars_scheduler.push_service_outputs( node_uuid=node_uuid, progress_callback=_progress_callback ) + TaskRegistry.register(_task_push_service_outputs) + try: - return start_task( + return await lrt_api.start_task( long_running_manager.tasks_manager, - task=_task_push_service_outputs, # type: ignore[arg-type] + _task_push_service_outputs.__name__, unique=True, node_uuid=node_uuid, ) except TaskAlreadyRunningError as e: raise HTTPException(status.HTTP_409_CONFLICT, detail=f"{e}") from e + finally: + TaskRegistry.unregister(_task_push_service_outputs) @router.delete( @@ -247,21 +260,25 @@ async def delete_service_docker_resources( ], ): async def _task_cleanup_service_docker_resources( - task_progress: TaskProgress, node_uuid: NodeID + progress: TaskProgress, node_uuid: NodeID ) -> None: await dynamic_sidecars_scheduler.remove_service_sidecar_proxy_docker_networks_and_volumes( - task_progress=task_progress, node_uuid=node_uuid + task_progress=progress, node_uuid=node_uuid ) + TaskRegistry.register(_task_cleanup_service_docker_resources) + try: - return start_task( + return await 
lrt_api.start_task( long_running_manager.tasks_manager, - task=_task_cleanup_service_docker_resources, # type: ignore[arg-type] + _task_cleanup_service_docker_resources.__name__, unique=True, node_uuid=node_uuid, ) except TaskAlreadyRunningError as e: raise HTTPException(status.HTTP_409_CONFLICT, detail=f"{e}") from e + finally: + TaskRegistry.unregister(_task_cleanup_service_docker_resources) @router.post( diff --git a/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/docker_states.py b/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/docker_states.py index afd44dc0f59..7978d6d57a2 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/docker_states.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/dynamic_sidecar/docker_states.py @@ -1,6 +1,7 @@ """ States from Docker Tasks and docker Containers are mapped to ServiceState. """ + import logging from models_library.generated_models.docker_rest_api import ContainerState @@ -8,7 +9,7 @@ from ...models.dynamic_services_scheduler import DockerContainerInspect -logger = logging.getLogger(__name__) +_logger = logging.getLogger(__name__) # For all available task states SEE # https://docs.docker.com/engine/swarm/how-swarm-mode-works/swarm-task-states/ @@ -62,7 +63,7 @@ def extract_task_state(task_status: dict[str, str]) -> tuple[ServiceState, str]: - last_task_error_msg = task_status["Err"] if "Err" in task_status else "" + last_task_error_msg = task_status.get("Err", "") task_state = _TASK_STATE_TO_SERVICE_STATE[task_status["State"]] return (task_state, last_task_error_msg) @@ -89,7 +90,7 @@ def extract_containers_minimum_statuses( the lowest (considered worst) state will be forwarded to the frontend. `ServiceState` defines the order of the states. 
""" - logger.info("containers_inspect=%s", containers_inspect) + _logger.debug("containers_inspect=%s", containers_inspect) remapped_service_statuses = { index: _extract_container_status(value.container_state) for index, value in enumerate(containers_inspect) diff --git a/services/director-v2/tests/unit/test_api_route_dynamic_scheduler.py b/services/director-v2/tests/unit/test_api_route_dynamic_scheduler.py index 41abda858bb..37f2b7b4965 100644 --- a/services/director-v2/tests/unit/test_api_route_dynamic_scheduler.py +++ b/services/director-v2/tests/unit/test_api_route_dynamic_scheduler.py @@ -201,8 +201,9 @@ async def test_409_response( ) assert response.status_code == status.HTTP_202_ACCEPTED task_id = response.json() - assert task_id.startswith( - f"simcore_service_director_v2.api.routes.dynamic_scheduler.{task_name}." + assert ( + f"simcore_service_director_v2.api.routes.dynamic_scheduler.{task_name}" + in task_id ) response = client.request( diff --git a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/api/rest/containers_long_running_tasks.py b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/api/rest/containers_long_running_tasks.py index 17ae8d9187f..7e9fdb3d0b8 100644 --- a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/api/rest/containers_long_running_tasks.py +++ b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/api/rest/containers_long_running_tasks.py @@ -5,9 +5,9 @@ from servicelib.fastapi.long_running_tasks._manager import FastAPILongRunningManager from servicelib.fastapi.long_running_tasks.server import get_long_running_manager from servicelib.fastapi.requests_decorators import cancel_on_disconnect +from servicelib.long_running_tasks import lrt_api from servicelib.long_running_tasks.errors import TaskAlreadyRunningError from servicelib.long_running_tasks.models import TaskId -from servicelib.long_running_tasks.task import start_task from ...core.settings import ApplicationSettings from 
...models.schemas.application_health import ApplicationHealth @@ -58,9 +58,9 @@ async def pull_user_servcices_docker_images( assert request # nosec try: - return start_task( + return await lrt_api.start_task( long_running_manager.tasks_manager, - task=task_pull_user_servcices_docker_images, + task_pull_user_servcices_docker_images.__name__, unique=True, app=app, shared_store=shared_store, @@ -99,9 +99,9 @@ async def create_service_containers_task( # pylint: disable=too-many-arguments assert request # nosec try: - return start_task( + return await lrt_api.start_task( long_running_manager.tasks_manager, - task=task_create_service_containers, + task_create_service_containers.__name__, unique=True, settings=settings, containers_create=containers_create, @@ -133,9 +133,9 @@ async def runs_docker_compose_down_task( assert request # nosec try: - return start_task( + return await lrt_api.start_task( long_running_manager.tasks_manager, - task=task_runs_docker_compose_down, + task_runs_docker_compose_down.__name__, unique=True, app=app, shared_store=shared_store, @@ -165,9 +165,9 @@ async def state_restore_task( assert request # nosec try: - return start_task( + return await lrt_api.start_task( long_running_manager.tasks_manager, - task=task_restore_state, + task_restore_state.__name__, unique=True, settings=settings, mounted_volumes=mounted_volumes, @@ -196,9 +196,9 @@ async def state_save_task( assert request # nosec try: - return start_task( + return await lrt_api.start_task( long_running_manager.tasks_manager, - task=task_save_state, + task_save_state.__name__, unique=True, settings=settings, mounted_volumes=mounted_volumes, @@ -229,9 +229,9 @@ async def ports_inputs_pull_task( assert request # nosec try: - return start_task( + return await lrt_api.start_task( long_running_manager.tasks_manager, - task=task_ports_inputs_pull, + task_ports_inputs_pull.__name__, unique=True, port_keys=port_keys, mounted_volumes=mounted_volumes, @@ -262,9 +262,9 @@ async def 
ports_outputs_pull_task( assert request # nosec try: - return start_task( + return await lrt_api.start_task( long_running_manager.tasks_manager, - task=task_ports_outputs_pull, + task_ports_outputs_pull.__name__, unique=True, port_keys=port_keys, mounted_volumes=mounted_volumes, @@ -292,9 +292,9 @@ async def ports_outputs_push_task( assert request # nosec try: - return start_task( + return await lrt_api.start_task( long_running_manager.tasks_manager, - task=task_ports_outputs_push, + task_ports_outputs_push.__name__, unique=True, outputs_manager=outputs_manager, app=app, @@ -322,9 +322,9 @@ async def containers_restart_task( assert request # nosec try: - return start_task( + return await lrt_api.start_task( long_running_manager.tasks_manager, - task=task_containers_restart, + task_containers_restart.__name__, unique=True, app=app, settings=settings, diff --git a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/modules/long_running_tasks.py b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/modules/long_running_tasks.py index 42412376d08..f14f90be1ef 100644 --- a/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/modules/long_running_tasks.py +++ b/services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/modules/long_running_tasks.py @@ -6,16 +6,14 @@ from typing import Final from fastapi import FastAPI -from models_library.api_schemas_long_running_tasks.base import ( - ProgressPercent, - TaskProgress, -) +from models_library.api_schemas_long_running_tasks.base import TaskProgress from models_library.generated_models.docker_rest_api import ContainerState from models_library.rabbitmq_messages import ProgressType, SimcorePlatformStatus from models_library.service_settings_labels import LegacyState from pydantic import PositiveInt from servicelib.file_utils import log_directory_changes from servicelib.logging_utils import log_context +from servicelib.long_running_tasks.task import TaskRegistry from servicelib.progress_bar import 
ProgressBarData from servicelib.utils import logged_gather from simcore_sdk.node_data import data_manager @@ -151,13 +149,11 @@ async def task_pull_user_servcices_docker_images( ) -> None: assert shared_store.compose_spec # nosec - progress.update(message="started pulling user services", percent=ProgressPercent(0)) + progress.update(message="started pulling user services", percent=0) await docker_compose_pull(app, shared_store.compose_spec) - progress.update( - message="finished pulling user services", percent=ProgressPercent(1) - ) + progress.update(message="finished pulling user services", percent=1) async def task_create_service_containers( @@ -168,7 +164,7 @@ async def task_create_service_containers( app: FastAPI, application_health: ApplicationHealth, ) -> list[str]: - progress.update(message="validating service spec", percent=ProgressPercent(0)) + progress.update(message="validating service spec", percent=0) assert shared_store.compose_spec # nosec @@ -194,18 +190,14 @@ async def task_create_service_containers( _raise_for_errors(result, "rm") await progress_bar.update() - progress.update( - message="creating and starting containers", percent=ProgressPercent(0.90) - ) + progress.update(message="creating and starting containers", percent=0.90) await post_sidecar_log_message( app, "starting service containers", log_level=logging.INFO ) await _retry_docker_compose_create(shared_store.compose_spec, settings) await progress_bar.update() - progress.update( - message="ensure containers are started", percent=ProgressPercent(0.95) - ) + progress.update(message="ensure containers are started", percent=0.95) compose_start_result = await _retry_docker_compose_start( shared_store.compose_spec, settings ) @@ -288,9 +280,7 @@ async def _send_resource_tracking_stop(platform_status: SimcorePlatformStatus): await send_service_stopped(app, simcore_platform_status) try: - progress.update( - message="running docker-compose-down", percent=ProgressPercent(0.1) - ) + 
progress.update(message="running docker-compose-down", percent=0.1) await run_before_shutdown_actions( shared_store, settings.DY_SIDECAR_CALLBACKS_MAPPING.before_shutdown @@ -303,13 +293,11 @@ async def _send_resource_tracking_stop(platform_status: SimcorePlatformStatus): result = await _retry_docker_compose_down(shared_store.compose_spec, settings) _raise_for_errors(result, "down") - progress.update(message="stopping logs", percent=ProgressPercent(0.9)) + progress.update(message="stopping logs", percent=0.9) for container_name in shared_store.container_names: await stop_log_fetching(app, container_name) - progress.update( - message="removing pending resources", percent=ProgressPercent(0.95) - ) + progress.update(message="removing pending resources", percent=0.95) result = await docker_compose_rm(shared_store.compose_spec, settings) _raise_for_errors(result, "rm") except Exception: @@ -326,7 +314,7 @@ async def _send_resource_tracking_stop(platform_status: SimcorePlatformStatus): async with shared_store: shared_store.compose_spec = None shared_store.container_names = [] - progress.update(message="done", percent=ProgressPercent(0.99)) + progress.update(message="done", percent=0.99) def _get_satate_folders_size(paths: list[Path]) -> int: @@ -389,7 +377,7 @@ async def task_restore_state( # NOTE: this implies that the legacy format will always be decompressed # until it is not removed. 
- progress.update(message="Downloading state", percent=ProgressPercent(0.05)) + progress.update(message="Downloading state", percent=0.05) state_paths = list(mounted_volumes.disk_state_paths_iter()) await post_sidecar_log_message( app, @@ -419,7 +407,7 @@ async def task_restore_state( await post_sidecar_log_message( app, "Finished state downloading", log_level=logging.INFO ) - progress.update(message="state restored", percent=ProgressPercent(0.99)) + progress.update(message="state restored", percent=0.99) return _get_satate_folders_size(state_paths) @@ -459,7 +447,7 @@ async def task_save_state( If a legacy archive is detected, it will be removed after saving the new format. """ - progress.update(message="starting state save", percent=ProgressPercent(0.0)) + progress.update(message="starting state save", percent=0.0) state_paths = list(mounted_volumes.disk_state_paths_iter()) async with ProgressBarData( num_steps=len(state_paths), @@ -485,7 +473,7 @@ async def task_save_state( ) await post_sidecar_log_message(app, "Finished state saving", log_level=logging.INFO) - progress.update(message="finished state saving", percent=ProgressPercent(0.99)) + progress.update(message="finished state saving", percent=0.99) return _get_satate_folders_size(state_paths) @@ -503,12 +491,12 @@ async def task_ports_inputs_pull( _logger.info("Received request to pull inputs but was ignored") return 0 - progress.update(message="starting inputs pulling", percent=ProgressPercent(0.0)) + progress.update(message="starting inputs pulling", percent=0.0) port_keys = [] if port_keys is None else port_keys await post_sidecar_log_message( app, f"Pulling inputs for {port_keys}", log_level=logging.INFO ) - progress.update(message="pulling inputs", percent=ProgressPercent(0.1)) + progress.update(message="pulling inputs", percent=0.1) async with ProgressBarData( num_steps=1, progress_report_cb=functools.partial( @@ -539,7 +527,7 @@ async def task_ports_inputs_pull( await post_sidecar_log_message( app, 
"Finished pulling inputs", log_level=logging.INFO ) - progress.update(message="finished inputs pulling", percent=ProgressPercent(0.99)) + progress.update(message="finished inputs pulling", percent=0.99) return int(transferred_bytes) @@ -549,7 +537,7 @@ async def task_ports_outputs_pull( mounted_volumes: MountedVolumes, app: FastAPI, ) -> int: - progress.update(message="starting outputs pulling", percent=ProgressPercent(0.0)) + progress.update(message="starting outputs pulling", percent=0.0) port_keys = [] if port_keys is None else port_keys await post_sidecar_log_message( app, f"Pulling output for {port_keys}", log_level=logging.INFO @@ -576,14 +564,14 @@ async def task_ports_outputs_pull( await post_sidecar_log_message( app, "Finished pulling outputs", log_level=logging.INFO ) - progress.update(message="finished outputs pulling", percent=ProgressPercent(0.99)) + progress.update(message="finished outputs pulling", percent=0.99) return int(transferred_bytes) async def task_ports_outputs_push( progress: TaskProgress, outputs_manager: OutputsManager, app: FastAPI ) -> None: - progress.update(message="starting outputs pushing", percent=ProgressPercent(0.0)) + progress.update(message="starting outputs pushing", percent=0.0) await post_sidecar_log_message( app, f"waiting for outputs {outputs_manager.outputs_context.file_type_port_keys} to be pushed", @@ -595,7 +583,7 @@ async def task_ports_outputs_push( await post_sidecar_log_message( app, "finished outputs pushing", log_level=logging.INFO ) - progress.update(message="finished outputs pushing", percent=ProgressPercent(0.99)) + progress.update(message="finished outputs pushing", percent=0.99) async def task_containers_restart( @@ -610,9 +598,7 @@ async def task_containers_restart( # or some other state, the service will get shutdown, to prevent this # blocking status while containers are being restarted. 
async with app.state.container_restart_lock: - progress.update( - message="starting containers restart", percent=ProgressPercent(0.0) - ) + progress.update(message="starting containers restart", percent=0.0) if shared_store.compose_spec is None: msg = "No spec for docker compose command was found" raise RuntimeError(msg) @@ -620,20 +606,34 @@ async def task_containers_restart( for container_name in shared_store.container_names: await stop_log_fetching(app, container_name) - progress.update(message="stopped log fetching", percent=ProgressPercent(0.1)) + progress.update(message="stopped log fetching", percent=0.1) result = await docker_compose_restart(shared_store.compose_spec, settings) _raise_for_errors(result, "restart") - progress.update(message="containers restarted", percent=ProgressPercent(0.8)) + progress.update(message="containers restarted", percent=0.8) for container_name in shared_store.container_names: await start_log_fetching(app, container_name) - progress.update(message="started log fetching", percent=ProgressPercent(0.9)) + progress.update(message="started log fetching", percent=0.9) await post_sidecar_log_message( app, "Service was restarted please reload the UI", log_level=logging.INFO ) await post_event_reload_iframe(app) - progress.update(message="started log fetching", percent=ProgressPercent(0.99)) + progress.update(message="started log fetching", percent=0.99) + + +for task in ( + task_pull_user_servcices_docker_images, + task_create_service_containers, + task_runs_docker_compose_down, + task_restore_state, + task_save_state, + task_ports_inputs_pull, + task_ports_outputs_pull, + task_ports_outputs_push, + task_containers_restart, +): + TaskRegistry.register(task) diff --git a/services/dynamic-sidecar/tests/unit/test_api_rest_containers_long_running_tasks.py b/services/dynamic-sidecar/tests/unit/test_api_rest_containers_long_running_tasks.py index b2b1d005266..932a381fac6 100644 --- 
a/services/dynamic-sidecar/tests/unit/test_api_rest_containers_long_running_tasks.py +++ b/services/dynamic-sidecar/tests/unit/test_api_rest_containers_long_running_tasks.py @@ -32,6 +32,7 @@ from servicelib.fastapi.long_running_tasks.client import setup as client_setup from servicelib.long_running_tasks.errors import TaskExceptionError from servicelib.long_running_tasks.models import TaskId +from servicelib.long_running_tasks.task import TaskRegistry from simcore_sdk.node_ports_common.exceptions import NodeNotFound from simcore_service_dynamic_sidecar._meta import API_VTAG from simcore_service_dynamic_sidecar.api.rest import containers_long_running_tasks @@ -73,6 +74,8 @@ def mock_tasks(mocker: MockerFixture) -> Iterator[None]: async def _just_log_task(*args, **kwargs) -> None: print(f"Called mocked function with {args}, {kwargs}") + TaskRegistry.register(_just_log_task) + # searching by name since all start with _task tasks_names = [ x[0] @@ -87,6 +90,8 @@ async def _just_log_task(*args, **kwargs) -> None: yield None + TaskRegistry.unregister(_just_log_task) + @asynccontextmanager async def auto_remove_task(client: Client, task_id: TaskId) -> AsyncIterator[None]: @@ -506,13 +511,15 @@ def _get_awaitable() -> Awaitable: with mock_tasks(mocker): task_id = await _get_awaitable() + assert task_id.endswith("unique") async with auto_remove_task(client, task_id): assert await _get_awaitable() == task_id # since the previous task was already removed it is again possible - # to create a task + # to create a task and it will share the same task_id new_task_id = await _get_awaitable() - assert new_task_id != task_id + assert new_task_id.endswith("unique") + assert new_task_id == task_id async with auto_remove_task(client, task_id): pass diff --git a/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml b/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml index ce122ca58cc..759adedfa02 100644 --- 
a/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml +++ b/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml @@ -17294,7 +17294,6 @@ components: type: object required: - task_id - - task_name - status_href - result_href - abort_href diff --git a/services/web/server/src/simcore_service_webserver/projects/_controller/nodes_rest.py b/services/web/server/src/simcore_service_webserver/projects/_controller/nodes_rest.py index b42be8077c3..2e2c6eb12d3 100644 --- a/services/web/server/src/simcore_service_webserver/projects/_controller/nodes_rest.py +++ b/services/web/server/src/simcore_service_webserver/projects/_controller/nodes_rest.py @@ -43,6 +43,7 @@ X_SIMCORE_USER_AGENT, ) from servicelib.long_running_tasks.models import TaskProgress +from servicelib.long_running_tasks.task import TaskRegistry from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON from servicelib.rabbitmq import RPCServerError from servicelib.rabbitmq.rpc_interfaces.dynamic_scheduler.errors import ( @@ -289,11 +290,12 @@ async def start_node(request: web.Request) -> web.Response: async def _stop_dynamic_service_task( - _task_progress: TaskProgress, + progress: TaskProgress, *, app: web.Application, dynamic_service_stop: DynamicServiceStop, ): + _ = progress # NOTE: _handle_project_nodes_exceptions only decorate handlers try: await dynamic_scheduler_service.stop_dynamic_service( @@ -310,6 +312,9 @@ async def _stop_dynamic_service_task( return web.json_response(status=status.HTTP_204_NO_CONTENT) +TaskRegistry.register(_stop_dynamic_service_task) + + @routes.post( f"/{VTAG}/projects/{{project_id}}/nodes/{{node_id}}:stop", name="stop_node" ) @@ -334,7 +339,7 @@ async def stop_node(request: web.Request) -> web.Response: return await start_long_running_task( request, - _stop_dynamic_service_task, # type: ignore[arg-type] # @GitHK, @pcrespov this one I don't know how to fix + _stop_dynamic_service_task.__name__, task_context=jsonable_encoder(req_ctx), # 
task arguments from here on --- app=request.app, diff --git a/services/web/server/src/simcore_service_webserver/projects/_controller/projects_rest.py b/services/web/server/src/simcore_service_webserver/projects/_controller/projects_rest.py index 0db92ca68c5..5edaa322b8b 100644 --- a/services/web/server/src/simcore_service_webserver/projects/_controller/projects_rest.py +++ b/services/web/server/src/simcore_service_webserver/projects/_controller/projects_rest.py @@ -97,7 +97,7 @@ async def create_project(request: web.Request): return await start_long_running_task( request, - _crud_api_create.create_project, # type: ignore[arg-type] # @GitHK, @pcrespov this one I don't know how to fix + _crud_api_create.create_project.__name__, fire_and_forget=True, task_context=jsonable_encoder(req_ctx), # arguments @@ -414,7 +414,7 @@ async def clone_project(request: web.Request): return await start_long_running_task( request, - _crud_api_create.create_project, # type: ignore[arg-type] # @GitHK, @pcrespov this one I don't know how to fix + _crud_api_create.create_project.__name__, fire_and_forget=True, task_context=jsonable_encoder(req_ctx), # arguments diff --git a/services/web/server/src/simcore_service_webserver/projects/_crud_api_create.py b/services/web/server/src/simcore_service_webserver/projects/_crud_api_create.py index 1116e8f2709..506312ef10b 100644 --- a/services/web/server/src/simcore_service_webserver/projects/_crud_api_create.py +++ b/services/web/server/src/simcore_service_webserver/projects/_crud_api_create.py @@ -18,6 +18,7 @@ from models_library.workspaces import UserWorkspaceWithAccessRights from pydantic import TypeAdapter from servicelib.long_running_tasks.models import TaskProgress +from servicelib.long_running_tasks.task import TaskRegistry from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON from servicelib.redis import with_project_locked from servicelib.rest_constants import RESPONSE_MODEL_POLICY @@ -250,7 +251,7 @@ async def 
_compose_project_data( async def create_project( # pylint: disable=too-many-arguments,too-many-branches,too-many-statements # noqa: C901, PLR0913 - task_progress: TaskProgress, + progress: TaskProgress, *, request: web.Request, new_project_was_hidden_before_data_was_copied: bool, @@ -299,7 +300,7 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche copy_file_coro = None project_nodes = None try: - task_progress.update(message="creating new study...") + progress.update(message="creating new study...") workspace_id = None folder_id = None @@ -337,7 +338,7 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche src_project_uuid=from_study, as_template=as_template, deep_copy=copy_data, - task_progress=task_progress, + task_progress=progress, ) if project_node_coro: project_nodes = await project_node_coro @@ -387,7 +388,7 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche parent_project_uuid=parent_project_uuid, parent_node_id=parent_node_id, ) - task_progress.update() + progress.update() # 3.2 move project to proper folder if folder_id: @@ -415,7 +416,7 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche await dynamic_scheduler_service.update_projects_networks( request.app, project_id=ProjectID(new_project["uuid"]) ) - task_progress.update() + progress.update() # This is a new project and every new graph needs to be reflected in the pipeline tables await director_v2_service.create_or_update_pipeline( @@ -436,7 +437,7 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche is_template=as_template, app=request.app, ) - task_progress.update() + progress.update() # Adds permalink await update_or_pop_permalink_in_project(request, new_project) @@ -518,3 +519,6 @@ async def create_project( # pylint: disable=too-many-arguments,too-many-branche simcore_user_agent=simcore_user_agent, ) raise + + +TaskRegistry.register(create_project) 
diff --git a/services/web/server/src/simcore_service_webserver/storage/_rest.py b/services/web/server/src/simcore_service_webserver/storage/_rest.py index 4fe428ea8a6..6836b5f80dc 100644 --- a/services/web/server/src/simcore_service_webserver/storage/_rest.py +++ b/services/web/server/src/simcore_service_webserver/storage/_rest.py @@ -183,7 +183,6 @@ def _create_data_response_from_async_job( return create_data_response( TaskGet( task_id=async_job_id, - task_name=async_job_id, status_href=f"{request.url.with_path(str(request.app.router['get_async_job_status'].url_for(task_id=async_job_id)))}", abort_href=f"{request.url.with_path(str(request.app.router['cancel_async_job'].url_for(task_id=async_job_id)))}", result_href=f"{request.url.with_path(str(request.app.router['get_async_job_result'].url_for(task_id=async_job_id)))}", @@ -504,7 +503,6 @@ def allow_only_simcore(cls, v: int) -> int: return create_data_response( TaskGet( task_id=_job_id, - task_name=_job_id, status_href=f"{request.url.with_path(str(request.app.router['get_async_job_status'].url_for(task_id=_job_id)))}", abort_href=f"{request.url.with_path(str(request.app.router['cancel_async_job'].url_for(task_id=_job_id)))}", result_href=f"{request.url.with_path(str(request.app.router['get_async_job_result'].url_for(task_id=_job_id)))}", diff --git a/services/web/server/src/simcore_service_webserver/tasks/_rest.py b/services/web/server/src/simcore_service_webserver/tasks/_rest.py index 13176af833b..bc4e8d08f4f 100644 --- a/services/web/server/src/simcore_service_webserver/tasks/_rest.py +++ b/services/web/server/src/simcore_service_webserver/tasks/_rest.py @@ -28,7 +28,7 @@ parse_request_path_parameters_as, ) from servicelib.aiohttp.rest_responses import create_data_response -from servicelib.long_running_tasks import http_endpoint_responses +from servicelib.long_running_tasks import lrt_api from servicelib.rabbitmq.rpc_interfaces.async_jobs import async_jobs from .._meta import API_VTAG @@ -58,7 +58,7 @@ 
@webserver_request_context_decorator async def get_async_jobs(request: web.Request) -> web.Response: inprocess_long_running_manager = get_long_running_manager(request.app) - inprocess_tracked_tasks = http_endpoint_responses.list_tasks( + inprocess_tracked_tasks = lrt_api.list_tasks( inprocess_long_running_manager.tasks_manager, inprocess_long_running_manager.get_task_context(request), ) @@ -91,7 +91,6 @@ async def get_async_jobs(request: web.Request) -> web.Response: + [ TaskGet( task_id=f"{task.task_id}", - task_name=task.task_name, status_href=f"{request.app.router['get_task_status'].url_for(task_id=task.task_id)}", abort_href=f"{request.app.router['cancel_and_delete_task'].url_for(task_id=task.task_id)}", result_href=f"{request.app.router['get_task_result'].url_for(task_id=task.task_id)}",