diff --git a/packages/service-library/src/servicelib/progress_bar.py b/packages/service-library/src/servicelib/progress_bar.py index 5d7f40421d31..1f65e44790be 100644 --- a/packages/service-library/src/servicelib/progress_bar.py +++ b/packages/service-library/src/servicelib/progress_bar.py @@ -22,14 +22,12 @@ @runtime_checkable class AsyncReportCB(Protocol): - async def __call__(self, report: ProgressReport) -> None: - ... + async def __call__(self, report: ProgressReport) -> None: ... @runtime_checkable class ReportCB(Protocol): - def __call__(self, report: ProgressReport) -> None: - ... + def __call__(self, report: ProgressReport) -> None: ... def _normalize_weights(steps: int, weights: list[float]) -> list[float]: @@ -88,7 +86,7 @@ async def main_fct(): progress_unit: ProgressUnit | None = None progress_report_cb: AsyncReportCB | ReportCB | None = None _current_steps: float = _INITIAL_VALUE - _currnet_attempt: int = 0 + _current_attempt: int = 0 _children: list["ProgressBarData"] = field(default_factory=list) _parent: Optional["ProgressBarData"] = None _continuous_value_lock: asyncio.Lock = field(init=False) @@ -148,7 +146,7 @@ async def _report_external(self, value: float) -> None: # NOTE: here we convert back to actual value since this is possibly weighted actual_value=value * self.num_steps, total=self.num_steps, - attempt=self._currnet_attempt, + attempt=self._current_attempt, unit=self.progress_unit, message=self.compute_report_message_stuct(), ), @@ -200,7 +198,7 @@ async def update(self, steps: float = 1) -> None: await self._report_external(new_progress_value) def reset(self) -> None: - self._currnet_attempt += 1 + self._current_attempt += 1 self._current_steps = _INITIAL_VALUE self._last_report_value = _INITIAL_VALUE diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py index db81b8d9f58d..c50799bda05b 100644 --- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py +++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py @@ -1,5 +1,10 @@ -from typing import Final +import datetime +import logging +from asyncio import CancelledError +from collections.abc import AsyncGenerator, Awaitable +from typing import Any, Final +from attr import dataclass from models_library.api_schemas_rpc_async_jobs.async_jobs import ( AsyncJobGet, AsyncJobId, @@ -9,12 +14,25 @@ ) from models_library.rabbitmq_basic_types import RPCMethodName, RPCNamespace from pydantic import NonNegativeInt, TypeAdapter +from tenacity import ( + AsyncRetrying, + TryAgain, + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_delay, + wait_fixed, + wait_random_exponential, +) +from ....rabbitmq import RemoteMethodNotRegisteredError from ... import RabbitMQRPCClient _DEFAULT_TIMEOUT_S: Final[NonNegativeInt] = 30 _RPC_METHOD_NAME_ADAPTER = TypeAdapter(RPCMethodName) +_DEFAULT_POLL_INTERVAL_S: Final[float] = 0.1 +_logger = logging.getLogger(__name__) async def cancel( @@ -103,3 +121,117 @@ async def submit( ) assert isinstance(_result, AsyncJobGet) # nosec return _result + + +_DEFAULT_RPC_RETRY_POLICY: dict[str, Any] = { + "retry": retry_if_exception_type(RemoteMethodNotRegisteredError), + "wait": wait_random_exponential(max=20), + "stop": stop_after_delay(60), + "reraise": True, + "before_sleep": before_sleep_log(_logger, logging.INFO), +} + + +@retry(**_DEFAULT_RPC_RETRY_POLICY) +async def _wait_for_completion( + rabbitmq_rpc_client: RabbitMQRPCClient, + *, + rpc_namespace: RPCNamespace, + method_name: RPCMethodName, + job_id: AsyncJobId, + job_id_data: AsyncJobNameData, + client_timeout: datetime.timedelta, +) -> AsyncGenerator[AsyncJobStatus, None]: + try: + async for attempt in AsyncRetrying( + stop=stop_after_delay(client_timeout.total_seconds()), + reraise=True, + retry=retry_if_exception_type(TryAgain), + before_sleep=before_sleep_log(_logger, logging.DEBUG), + wait=wait_fixed(_DEFAULT_POLL_INTERVAL_S), + ): + with attempt: + job_status = await status( + rabbitmq_rpc_client, + rpc_namespace=rpc_namespace, + job_id=job_id, + job_id_data=job_id_data, + ) + yield job_status + if not job_status.done: + msg = f"{job_status.job_id=}: '{job_status.progress=}'" + raise TryAgain(msg) # noqa: TRY301 + + except TryAgain as exc: + # this is a timeout + msg = f"Async job {job_id=}, calling to '{method_name}' timed-out after {client_timeout}" + raise TimeoutError(msg) from exc + + +@dataclass(frozen=True) +class AsyncJobComposedResult: + status: AsyncJobStatus + _result: Awaitable[Any] | None = None + + @property + def done(self) -> bool: + return self._result is not None + + async def result(self) -> Any: + if not self._result: + msg = "No result ready!" + raise ValueError(msg) + return await self._result + + +async def submit_and_wait( + rabbitmq_rpc_client: RabbitMQRPCClient, + *, + rpc_namespace: RPCNamespace, + method_name: str, + job_id_data: AsyncJobNameData, + client_timeout: datetime.timedelta, + **kwargs, +) -> AsyncGenerator[AsyncJobComposedResult, None]: + async_job_rpc_get = None + try: + async_job_rpc_get = await submit( + rabbitmq_rpc_client, + rpc_namespace=rpc_namespace, + method_name=method_name, + job_id_data=job_id_data, + **kwargs, + ) + job_status: AsyncJobStatus | None = None + async for job_status in _wait_for_completion( + rabbitmq_rpc_client, + rpc_namespace=rpc_namespace, + method_name=method_name, + job_id=async_job_rpc_get.job_id, + job_id_data=job_id_data, + client_timeout=client_timeout, + ): + assert job_status is not None # nosec + yield AsyncJobComposedResult(job_status) + if job_status: + yield AsyncJobComposedResult( + job_status, + result( + rabbitmq_rpc_client, + rpc_namespace=rpc_namespace, + job_id=async_job_rpc_get.job_id, + job_id_data=job_id_data, + ), + ) + except (TimeoutError, CancelledError) as error: + if async_job_rpc_get is not None: + try: + await cancel( + rabbitmq_rpc_client, + rpc_namespace=rpc_namespace, + job_id=async_job_rpc_get.job_id, + job_id_data=job_id_data, + ) + except Exception as exc: + raise exc from error + raise diff --git a/packages/service-library/tests/rabbitmq/conftest.py b/packages/service-library/tests/rabbitmq/conftest.py index 79f1c0cdb32d..e107d848daa3 100644 --- a/packages/service-library/tests/rabbitmq/conftest.py +++ b/packages/service-library/tests/rabbitmq/conftest.py @@ -1,10 +1,31 @@ -from collections.abc import AsyncIterator, Callable, Coroutine +from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine from typing import cast import aiodocker import arrow import pytest from faker import Faker +from models_library.rabbitmq_basic_types import RPCNamespace +from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient + + +@pytest.fixture +async def rpc_client( + rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]], +) -> RabbitMQRPCClient: + return await rabbitmq_rpc_client("pytest_rpc_client") + + +@pytest.fixture +async def rpc_server( + rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]], +) -> RabbitMQRPCClient: + return await rabbitmq_rpc_client("pytest_rpc_server") + + +@pytest.fixture +def namespace() -> RPCNamespace: + return RPCNamespace.from_entries({f"test{i}": f"test{i}" for i in range(8)}) @pytest.fixture(autouse=True) diff --git a/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc.py b/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc.py index 46588de6e87d..40417c4d4c34 100644 --- a/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc.py +++ b/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc.py @@ -2,7 +2,7 @@ # pylint:disable=unused-argument import asyncio -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable from typing import Any, Final import pytest @@ -23,11 +23,6 @@ MULTIPLE_REQUESTS_COUNT: Final[NonNegativeInt] = 100 -@pytest.fixture -def namespace() -> RPCNamespace: - return RPCNamespace.from_entries({f"test{i}": f"test{i}" for i in range(8)}) - - async def add_me(*, x: Any, y: Any) -> Any: return x + y # NOTE: types are not enforced @@ -49,20 +44,6 @@ def __add__(self, other: "CustomClass") -> "CustomClass": return CustomClass(x=self.x + other.x, y=self.y + other.y) -@pytest.fixture -async def rpc_client( - rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]], -) -> RabbitMQRPCClient: - return await rabbitmq_rpc_client("pytest_rpc_client") - - -@pytest.fixture -async def rpc_server( - rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]], -) -> RabbitMQRPCClient: - return await rabbitmq_rpc_client("pytest_rpc_server") - - @pytest.mark.parametrize( "x,y,expected_result,expected_type", [ diff --git a/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py b/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py new file mode 100644 index 000000000000..2522764a4ca9 --- /dev/null +++ b/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py @@ -0,0 +1,251 @@ +import asyncio +import datetime +from collections.abc import AsyncIterator +from dataclasses import dataclass, field + +import pytest +from faker import Faker +from models_library.api_schemas_rpc_async_jobs.async_jobs import ( + AsyncJobGet, + AsyncJobId, + AsyncJobNameData, + AsyncJobResult, + AsyncJobStatus, +) +from models_library.api_schemas_rpc_async_jobs.exceptions import JobMissingError +from models_library.progress_bar import ProgressReport +from models_library.rabbitmq_basic_types import RPCMethodName, RPCNamespace +from pydantic import TypeAdapter +from servicelib.async_utils import cancel_wait_task +from servicelib.rabbitmq import RabbitMQRPCClient, RemoteMethodNotRegisteredError +from servicelib.rabbitmq.rpc_interfaces.async_jobs.async_jobs import ( + list_jobs, + submit, + submit_and_wait, +) + +pytest_simcore_core_services_selection = [ + "rabbit", +] + + +@pytest.fixture +def method_name(faker: Faker) -> RPCMethodName: + return TypeAdapter(RPCMethodName).validate_python(faker.word()) + + +@pytest.fixture +def job_id_data(faker: Faker) -> AsyncJobNameData: + return AsyncJobNameData( + user_id=faker.pyint(min_value=1), + product_name=faker.word(), + ) + + +@pytest.fixture +def job_id(faker: Faker) -> AsyncJobId: + return faker.uuid4(cast_to=None) + + +@pytest.fixture +async def async_job_rpc_server( # noqa: C901 + rpc_server: RabbitMQRPCClient, + faker: Faker, + namespace: RPCNamespace, + method_name: RPCMethodName, +) -> AsyncIterator[None]: + async def _slow_task() -> None: + await asyncio.sleep(2) + + @dataclass + class FakeServer: + tasks: list[asyncio.Task] = field(default_factory=list) + + def _get_task(self, job_id: AsyncJobId) -> asyncio.Task: + for task in self.tasks: + if task.get_name() == f"{job_id}": + return task + raise JobMissingError(job_id=f"{job_id}") + + async def status( + self, job_id: AsyncJobId, job_id_data: AsyncJobNameData + ) -> AsyncJobStatus: + assert job_id_data + task = self._get_task(job_id) + return AsyncJobStatus( + job_id=job_id, + progress=ProgressReport(actual_value=1 if task.done() else 0.3), + done=task.done(), + ) + + async def cancel( + self, job_id: AsyncJobId, job_id_data: AsyncJobNameData + ) -> None: + assert job_id + assert job_id_data + task = self._get_task(job_id) + task.cancel() + + async def result( + self, job_id: AsyncJobId, job_id_data: AsyncJobNameData + ) -> AsyncJobResult: + assert job_id_data + task = self._get_task(job_id) + assert task.done() + return AsyncJobResult( + result={ + "data": task.result(), + "job_id": job_id, + "job_id_data": job_id_data, + } + ) + + async def list_jobs( + self, filter_: str, job_id_data: AsyncJobNameData + ) -> list[AsyncJobGet]: + assert job_id_data + assert filter_ is not None + + return [ + AsyncJobGet( + job_id=TypeAdapter(AsyncJobId).validate_python(t.get_name()) + ) + for t in self.tasks + ] + + async def submit(self, job_id_data: AsyncJobNameData) -> AsyncJobGet: + assert job_id_data + job_id = faker.uuid4(cast_to=None) + self.tasks.append(asyncio.create_task(_slow_task(), name=f"{job_id}")) + return AsyncJobGet(job_id=job_id) + + async def setup(self) -> None: + for m in (self.status, self.cancel, self.result): + await rpc_server.register_handler( + namespace, RPCMethodName(m.__name__), m + ) + await rpc_server.register_handler( + namespace, RPCMethodName(self.list_jobs.__name__), self.list_jobs + ) + + await rpc_server.register_handler(namespace, method_name, self.submit) + + fake_server = FakeServer() + await fake_server.setup() + + yield + + for task in fake_server.tasks: + await cancel_wait_task(task) + + +@pytest.mark.parametrize("method", ["result", "status", "cancel"]) +async def test_async_jobs_methods( + async_job_rpc_server: RabbitMQRPCClient, + rpc_client: RabbitMQRPCClient, + namespace: RPCNamespace, + job_id_data: AsyncJobNameData, + job_id: AsyncJobId, + method: str, +): + from servicelib.rabbitmq.rpc_interfaces.async_jobs import async_jobs + + async_jobs_method = getattr(async_jobs, method) + with pytest.raises(JobMissingError): + await async_jobs_method( + rpc_client, + rpc_namespace=namespace, + job_id=job_id, + job_id_data=job_id_data, + ) + + +async def test_list_jobs( + async_job_rpc_server: RabbitMQRPCClient, + rpc_client: RabbitMQRPCClient, + namespace: RPCNamespace, + method_name: RPCMethodName, + job_id_data: AsyncJobNameData, +): + await list_jobs( + rpc_client, + rpc_namespace=namespace, + filter_="", + job_id_data=job_id_data, + ) + + +async def test_submit( + async_job_rpc_server: RabbitMQRPCClient, + rpc_client: RabbitMQRPCClient, + namespace: RPCNamespace, + method_name: RPCMethodName, + job_id_data: AsyncJobNameData, +): + await submit( + rpc_client, + rpc_namespace=namespace, + method_name=method_name, + job_id_data=job_id_data, + ) + + +async def test_submit_with_invalid_method_name( + async_job_rpc_server: RabbitMQRPCClient, + rpc_client: RabbitMQRPCClient, + namespace: RPCNamespace, + job_id_data: AsyncJobNameData, +): + with pytest.raises(RemoteMethodNotRegisteredError): + await submit( + rpc_client, + rpc_namespace=namespace, + method_name=RPCMethodName("invalid_method_name"), + job_id_data=job_id_data, + ) + + +async def test_submit_and_wait_properly_timesout( + async_job_rpc_server: RabbitMQRPCClient, + rpc_client: RabbitMQRPCClient, + namespace: RPCNamespace, + method_name: RPCMethodName, + job_id_data: AsyncJobNameData, +): + with pytest.raises(TimeoutError): # noqa: PT012 + async for _job_composed_result in submit_and_wait( + rpc_client, + rpc_namespace=namespace, + method_name=method_name, + job_id_data=job_id_data, + client_timeout=datetime.timedelta(seconds=0.1), + ): + pass + + +async def test_submit_and_wait( + async_job_rpc_server: RabbitMQRPCClient, + rpc_client: RabbitMQRPCClient, + namespace: RPCNamespace, + method_name: RPCMethodName, + job_id_data: AsyncJobNameData, +): + async for job_composed_result in submit_and_wait( + rpc_client, + rpc_namespace=namespace, + method_name=method_name, + job_id_data=job_id_data, + client_timeout=datetime.timedelta(seconds=10), + ): + if not job_composed_result.done: + with pytest.raises(ValueError, match="No result ready!"): + await job_composed_result.result() + assert job_composed_result.done + assert job_composed_result.status.progress.actual_value == 1 + assert await job_composed_result.result() == AsyncJobResult( + result={ + "data": None, + "job_id": job_composed_result.status.job_id, + "job_id_data": job_id_data, + } + ) diff --git a/services/storage/src/simcore_service_storage/modules/celery/_celery_types.py b/services/storage/src/simcore_service_storage/modules/celery/_celery_types.py index e3742f6fc641..1ad45342a248 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/_celery_types.py +++ b/services/storage/src/simcore_service_storage/modules/celery/_celery_types.py @@ -1,6 +1,15 @@ +from functools import partial from pathlib import Path +from typing import Any from kombu.utils.json import register_type # type: ignore[import-untyped] +from models_library.api_schemas_storage.storage_schemas import ( + FileUploadCompletionBody, + FoldersBody, +) +from pydantic import BaseModel + +from ...models import FileMetaData def _path_encoder(obj): @@ -20,6 +29,24 @@ def _class_full_name(clz: type) -> str: return ".".join([clz.__module__, clz.__qualname__]) +def _encoder(obj: BaseModel, *args, **kwargs) -> dict[str, Any]: + return obj.model_dump(*args, **kwargs, mode="json") + + +def _decoder(clz: type[BaseModel], data: dict[str, Any]) -> BaseModel: + return clz(**data) + + +def _register_pydantic_types(*models: type[BaseModel]) -> None: + for model in models: + register_type( + model, + _class_full_name(model), + encoder=_encoder, + decoder=partial(_decoder, model), + ) + + def register_celery_types() -> None: register_type( Path, @@ -27,3 +54,6 @@ def register_celery_types() -> None: _path_encoder, _path_decoder, ) + _register_pydantic_types(FileUploadCompletionBody) + _register_pydantic_types(FileMetaData) + _register_pydantic_types(FoldersBody) diff --git a/services/storage/src/simcore_service_storage/modules/celery/_task.py b/services/storage/src/simcore_service_storage/modules/celery/_task.py index f89ca963c864..02d5c6a5edd6 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/_task.py +++ b/services/storage/src/simcore_service_storage/modules/celery/_task.py @@ -1,9 +1,10 @@ import asyncio +import inspect import logging import traceback from collections.abc import Callable, Coroutine from functools import wraps -from typing import Any, ParamSpec, TypeVar +from typing import Any, Concatenate, ParamSpec, TypeVar, overload from celery import ( # type: ignore[import-untyped] Celery, @@ -13,7 +14,7 @@ from celery.exceptions import Ignore # type: ignore[import-untyped] from . import get_event_loop -from .models import TaskError, TaskState +from .models import TaskError, TaskId, TaskState from .utils import get_fastapi_app _logger = logging.getLogger(__name__) @@ -55,13 +56,22 @@ def wrapper(task: Task, *args: Any, **kwargs: Any) -> Any: def _async_task_wrapper( app: Celery, -) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, R]]: - def decorator(coro: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, R]: +) -> Callable[ + [Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]]], + Callable[Concatenate[Task, P], R], +]: + def decorator( + coro: Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]], + ) -> Callable[Concatenate[Task, P], R]: @wraps(coro) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R: fastapi_app = get_fastapi_app(app) + _logger.debug("task id: %s", task.request.id) + # NOTE: task.request is a thread local object, so we need to pass the id explicitly + assert task.request.id is not None # nosec return asyncio.run_coroutine_threadsafe( - coro(*args, **kwargs), get_event_loop(fastapi_app) + coro(task, task.request.id, *args, **kwargs), + get_event_loop(fastapi_app), ).result() return wrapper @@ -69,11 +79,39 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: return decorator -def define_task(app: Celery, fn: Callable, task_name: str | None = None): - wrapped_fn = error_handling(fn) +@overload +def define_task( + app: Celery, + fn: Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]], + task_name: str | None = None, +) -> None: ... + + +@overload +def define_task( + app: Celery, + fn: Callable[Concatenate[Task, P], R], + task_name: str | None = None, +) -> None: ... + + +def define_task( # type: ignore[misc] + app: Celery, + fn: ( + Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]] + | Callable[Concatenate[Task, P], R] + ), + task_name: str | None = None, +) -> None: + """Decorator to define a celery task with error handling and abortable support""" + wrapped_fn: Callable[Concatenate[Task, P], R] if asyncio.iscoroutinefunction(fn): wrapped_fn = _async_task_wrapper(app)(fn) + else: + assert inspect.isfunction(fn) # nosec + wrapped_fn = fn + wrapped_fn = error_handling(wrapped_fn) app.task( name=task_name or fn.__name__, bind=True, diff --git a/services/storage/src/simcore_service_storage/modules/celery/models.py b/services/storage/src/simcore_service_storage/modules/celery/models.py index 2f04c5b81329..6b72a6e00198 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/models.py +++ b/services/storage/src/simcore_service_storage/modules/celery/models.py @@ -55,3 +55,6 @@ def _check_consistency(self) -> Self: class TaskError(BaseModel): exc_type: str exc_msg: str + + +TaskId: TypeAlias = str diff --git a/services/web/server/docker/boot.sh b/services/web/server/docker/boot.sh index a65425c1155f..6b42600e91a0 100755 --- a/services/web/server/docker/boot.sh +++ b/services/web/server/docker/boot.sh @@ -19,7 +19,7 @@ if [ "${SC_BUILD_TARGET}" = "development" ]; then command -v python | sed 's/^/ /' cd services/web/server - uv pip --quiet --no-cache-dir sync requirements/dev.txt + uv pip --quiet sync requirements/dev.txt cd - echo "$INFO" "PIP :" uv pip list