diff --git a/packages/pytest-simcore/src/pytest_simcore/helpers/playwright.py b/packages/pytest-simcore/src/pytest_simcore/helpers/playwright.py index 5ae58db0d877..0461e438ad81 100644 --- a/packages/pytest-simcore/src/pytest_simcore/helpers/playwright.py +++ b/packages/pytest-simcore/src/pytest_simcore/helpers/playwright.py @@ -137,9 +137,7 @@ def on_framereceived(payload: str | bytes) -> None: ctx.logger.debug("⬆️ Frame received: %s", payload) def on_close(_: WebSocket) -> None: - ctx.logger.warning( - "⚠️ WebSocket closed. Attempting to reconnect..." - ) + ctx.logger.warning("⚠️ WebSocket closed. Attempting to reconnect...") self._attempt_reconnect(ctx.logger) def on_socketerror(error_msg: str) -> None: @@ -320,9 +318,9 @@ def __call__(self, message: str) -> bool: new_progress != self._current_progress[node_progress_event.progress_type] ): - self._current_progress[ - node_progress_event.progress_type - ] = new_progress + self._current_progress[node_progress_event.progress_type] = ( + new_progress + ) self.logger.info( "Current startup progress [expected number of node-progress-types=%d]: %s", @@ -343,29 +341,30 @@ def __call__(self, message: str) -> bool: url = ( f"https://{self.node_id}.services.{self.get_partial_product_url()}" ) - response = self.api_request_context.get(url, timeout=1000) - level = logging.DEBUG - if (response.status >= 400) and (response.status not in (502, 503)): - level = logging.ERROR - self.logger.log( - level, - "Querying service endpoint in case we missed some websocket messages. Url: %s Response: '%s' TIP: %s", - url, - f"{response.status}: {response.text()}", - ( - "We are emulating the frontend; a 5XX response is acceptable if the service is not yet ready." - ), - ) + with contextlib.suppress(PlaywrightTimeoutError): + response = self.api_request_context.get(url, timeout=1000) + level = logging.DEBUG + if (response.status >= 400) and (response.status not in (502, 503)): + level = logging.ERROR + self.logger.log( + level, + "Querying service endpoint in case we missed some websocket messages. Url: %s Response: '%s' TIP: %s", + url, + f"{response.status}: {response.text()}", + ( + "We are emulating the frontend; a 5XX response is acceptable if the service is not yet ready." + ), + ) - if response.status <= 400: - # NOTE: If the response status is less than 400, it means that the backend is ready (There are some services that respond with a 3XX) - if self.got_expected_node_progress_types(): - self.logger.warning( - "⚠️ Progress bar didn't receive 100 percent but service is already running: %s. TIP: we missed some websocket messages! ⚠️", # https://github.com/ITISFoundation/osparc-simcore/issues/6449 - self.get_current_progress(), - ) - return True - self._last_poll_timestamp = datetime.now(UTC) + if response.status <= 400: + # NOTE: If the response status is less than 400, it means that the backend is ready (There are some services that respond with a 3XX) + if self.got_expected_node_progress_types(): + self.logger.warning( + "⚠️ Progress bar didn't receive 100 percent but service is already running: %s. TIP: we missed some websocket messages! 
⚠️", # https://github.com/ITISFoundation/osparc-simcore/issues/6449 + self.get_current_progress(), + ) + return True + self._last_poll_timestamp = datetime.now(UTC) return False @@ -511,19 +510,13 @@ def app_mode_trigger_next_app(page: Page) -> None: def wait_for_label_text( - page: Page, - locator: str, - substring: str, - timeout: int = 10000 + page: Page, locator: str, substring: str, timeout: int = 10000 ) -> Locator: - page.locator(locator).wait_for( - state="visible", - timeout=timeout - ) + page.locator(locator).wait_for(state="visible", timeout=timeout) page.wait_for_function( f"() => document.querySelector('{locator}').innerText.includes('{substring}')", - timeout=timeout + timeout=timeout, ) return page.locator(locator) diff --git a/packages/pytest-simcore/src/pytest_simcore/simcore_storage_data_models.py b/packages/pytest-simcore/src/pytest_simcore/simcore_storage_data_models.py index 4ca55f24bd6c..e897b9ced75e 100644 --- a/packages/pytest-simcore/src/pytest_simcore/simcore_storage_data_models.py +++ b/packages/pytest-simcore/src/pytest_simcore/simcore_storage_data_models.py @@ -18,7 +18,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine -from .helpers.faker_factories import random_project, random_user +from .helpers.faker_factories import DEFAULT_FAKER, random_project, random_user @asynccontextmanager @@ -62,7 +62,7 @@ async def other_user_id(sqlalchemy_async_engine: AsyncEngine) -> AsyncIterator[U @pytest.fixture async def create_project( user_id: UserID, sqlalchemy_async_engine: AsyncEngine -) -> AsyncIterator[Callable[[], Awaitable[dict[str, Any]]]]: +) -> AsyncIterator[Callable[..., Awaitable[dict[str, Any]]]]: created_project_uuids = [] async def _creator(**kwargs) -> dict[str, Any]: @@ -71,7 +71,7 @@ async def _creator(**kwargs) -> dict[str, Any]: async with sqlalchemy_async_engine.begin() as conn: result = await conn.execute( projects.insert() - .values(**random_project(**prj_config)) + .values(**random_project(DEFAULT_FAKER, **prj_config)) .returning(sa.literal_column("*")) ) row = result.one() diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py index c50799bda05b..ff51c59c4dbd 100644 --- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py +++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py @@ -12,6 +12,7 @@ AsyncJobResult, AsyncJobStatus, ) +from models_library.api_schemas_rpc_async_jobs.exceptions import JobMissingError from models_library.rabbitmq_basic_types import RPCMethodName, RPCNamespace from pydantic import NonNegativeInt, TypeAdapter from tenacity import ( @@ -20,6 +21,7 @@ before_sleep_log, retry, retry_if_exception_type, + stop_after_attempt, stop_after_delay, wait_fixed, wait_random_exponential, @@ -124,11 +126,11 @@ async def submit( _DEFAULT_RPC_RETRY_POLICY: dict[str, Any] = { - "retry": retry_if_exception_type(RemoteMethodNotRegisteredError), + "retry": retry_if_exception_type((RemoteMethodNotRegisteredError,)), "wait": wait_random_exponential(max=20), - "stop": stop_after_delay(60), + "stop": stop_after_attempt(30), "reraise": True, - "before_sleep": before_sleep_log(_logger, logging.INFO), + "before_sleep": before_sleep_log(_logger, logging.WARNING), } @@ -146,7 +148,7 @@ async def _wait_for_completion( async for attempt in AsyncRetrying( 
stop=stop_after_delay(client_timeout.total_seconds()), reraise=True, - retry=retry_if_exception_type(TryAgain), + retry=retry_if_exception_type((TryAgain, JobMissingError)), before_sleep=before_sleep_log(_logger, logging.DEBUG), wait=wait_fixed(_DEFAULT_POLL_INTERVAL_S), ): @@ -184,45 +186,72 @@ async def result(self) -> Any: return await self._result -async def submit_and_wait( +async def wait_and_get_result( rabbitmq_rpc_client: RabbitMQRPCClient, *, rpc_namespace: RPCNamespace, method_name: str, + job_id: AsyncJobId, job_id_data: AsyncJobNameData, client_timeout: datetime.timedelta, - **kwargs, ) -> AsyncGenerator[AsyncJobComposedResult, None]: - async_job_rpc_get = None + """when a job is already submitted this will wait for its completion + and return the composed result""" try: - async_job_rpc_get = await submit( - rabbitmq_rpc_client, - rpc_namespace=rpc_namespace, - method_name=method_name, - job_id_data=job_id_data, - **kwargs, - ) - job_status: AsyncJobStatus | None = None + job_status = None async for job_status in _wait_for_completion( rabbitmq_rpc_client, rpc_namespace=rpc_namespace, method_name=method_name, - job_id=async_job_rpc_get.job_id, + job_id=job_id, job_id_data=job_id_data, client_timeout=client_timeout, ): assert job_status is not None # nosec yield AsyncJobComposedResult(job_status) + + # return the result if job_status: yield AsyncJobComposedResult( job_status, result( rabbitmq_rpc_client, rpc_namespace=rpc_namespace, - job_id=async_job_rpc_get.job_id, + job_id=job_id, job_id_data=job_id_data, ), ) + except (TimeoutError, CancelledError) as error: + try: + await cancel( + rabbitmq_rpc_client, + rpc_namespace=rpc_namespace, + job_id=job_id, + job_id_data=job_id_data, + ) + except Exception as exc: + raise exc from error # NOSONAR + raise + + +async def submit_and_wait( + rabbitmq_rpc_client: RabbitMQRPCClient, + *, + rpc_namespace: RPCNamespace, + method_name: str, + job_id_data: AsyncJobNameData, + client_timeout: datetime.timedelta, + **kwargs, +) -> AsyncGenerator[AsyncJobComposedResult, None]: + async_job_rpc_get = None + try: + async_job_rpc_get = await submit( + rabbitmq_rpc_client, + rpc_namespace=rpc_namespace, + method_name=method_name, + job_id_data=job_id_data, + **kwargs, + ) except (TimeoutError, CancelledError) as error: if async_job_rpc_get is not None: try: @@ -235,3 +264,13 @@ async def submit_and_wait( except Exception as exc: raise exc from error raise + + async for wait_and_ in wait_and_get_result( + rabbitmq_rpc_client, + rpc_namespace=rpc_namespace, + method_name=method_name, + job_id=async_job_rpc_get.job_id, + job_id_data=job_id_data, + client_timeout=client_timeout, + ): + yield wait_and_ diff --git a/services/storage/src/simcore_service_storage/api/rest/dependencies/celery.py b/services/storage/src/simcore_service_storage/api/rest/dependencies/celery.py new file mode 100644 index 000000000000..e6d013a42b0f --- /dev/null +++ b/services/storage/src/simcore_service_storage/api/rest/dependencies/celery.py @@ -0,0 +1,13 @@ +from typing import Annotated + +from fastapi import Depends, FastAPI +from servicelib.fastapi.dependencies import get_app + +from ....modules.celery import get_celery_client as _get_celery_client_from_app +from ....modules.celery.client import CeleryTaskQueueClient + + +def get_celery_client( + app: Annotated[FastAPI, Depends(get_app)], +) -> CeleryTaskQueueClient: + return _get_celery_client_from_app(app) diff --git a/services/storage/src/simcore_service_storage/api/rpc/_paths.py 
b/services/storage/src/simcore_service_storage/api/rpc/_paths.py index b91042cf43bd..e975191bc57a 100644 --- a/services/storage/src/simcore_service_storage/api/rpc/_paths.py +++ b/services/storage/src/simcore_service_storage/api/rpc/_paths.py @@ -24,7 +24,6 @@ async def compute_path_size( location_id: LocationID, path: Path, ) -> AsyncJobGet: - assert app # nosec task_uuid = await get_celery_client(app).send_task( remote_compute_path_size.__name__, task_context=job_id_data.model_dump(), diff --git a/services/storage/src/simcore_service_storage/dsm.py b/services/storage/src/simcore_service_storage/dsm.py index 64d72a294047..2c2faa8fcd4b 100644 --- a/services/storage/src/simcore_service_storage/dsm.py +++ b/services/storage/src/simcore_service_storage/dsm.py @@ -13,7 +13,7 @@ def setup_dsm(app: FastAPI) -> None: async def _on_startup() -> None: - dsm_provider = DataManagerProvider(app) + dsm_provider = DataManagerProvider(app=app) dsm_provider.register_builder( SimcoreS3DataManager.get_location_id(), create_simcore_s3_data_manager, @@ -38,7 +38,7 @@ async def _on_shutdown() -> None: def get_dsm_provider(app: FastAPI) -> DataManagerProvider: - if not app.state.dsm_provider: + if not hasattr(app.state, "dsm_provider"): raise ConfigurationError( msg="DSM provider not available. Please check the configuration." ) diff --git a/services/storage/src/simcore_service_storage/main.py b/services/storage/src/simcore_service_storage/main.py index abf943386628..a37ead2cefc7 100644 --- a/services/storage/src/simcore_service_storage/main.py +++ b/services/storage/src/simcore_service_storage/main.py @@ -17,6 +17,4 @@ tracing_settings=_settings.STORAGE_TRACING, ) -_logger = logging.getLogger(__name__) - app = create_app(_settings) diff --git a/services/storage/src/simcore_service_storage/modules/celery/_common.py b/services/storage/src/simcore_service_storage/modules/celery/_common.py index b479408053fc..38e3235c7d95 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/_common.py +++ b/services/storage/src/simcore_service_storage/modules/celery/_common.py @@ -1,5 +1,6 @@ import logging import ssl +from typing import Any from celery import Celery # type: ignore[import-untyped] from settings_library.celery import CelerySettings @@ -8,24 +9,28 @@ _logger = logging.getLogger(__name__) +def _celery_configure(celery_settings: CelerySettings) -> dict[str, Any]: + base_config = { + "broker_connection_retry_on_startup": True, + "result_expires": celery_settings.CELERY_RESULT_EXPIRES, + "result_extended": True, + "result_serializer": "json", + "task_send_sent_event": True, + "task_track_started": True, + "worker_send_task_events": True, + } + if celery_settings.CELERY_REDIS_RESULT_BACKEND.REDIS_SECURE: + base_config["redis_backend_use_ssl"] = {"ssl_cert_reqs": ssl.CERT_NONE} + return base_config + + def create_app(celery_settings: CelerySettings) -> Celery: assert celery_settings - app = Celery( + return Celery( broker=celery_settings.CELERY_RABBIT_BROKER.dsn, backend=celery_settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn( RedisDatabase.CELERY_TASKS, ), + **_celery_configure(celery_settings), ) - app.conf.broker_connection_retry_on_startup = True - # NOTE: disable SSL cert validation (https://github.com/ITISFoundation/osparc-simcore/pull/7407) - if celery_settings.CELERY_REDIS_RESULT_BACKEND.REDIS_SECURE: - app.conf.redis_backend_use_ssl = {"ssl_cert_reqs": ssl.CERT_NONE} - app.conf.result_expires = celery_settings.CELERY_RESULT_EXPIRES - app.conf.result_extended = True # original args are 
included in the results - app.conf.result_serializer = "json" - app.conf.task_send_sent_event = True - app.conf.task_track_started = True - app.conf.worker_send_task_events = True # enable tasks monitoring - - return app diff --git a/services/storage/src/simcore_service_storage/modules/celery/worker_main.py b/services/storage/src/simcore_service_storage/modules/celery/worker_main.py index 99b9a53676ed..58febcb61f68 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/worker_main.py +++ b/services/storage/src/simcore_service_storage/modules/celery/worker_main.py @@ -23,7 +23,6 @@ tracing_settings=_settings.STORAGE_TRACING, ) -_logger = logging.getLogger(__name__) assert _settings.STORAGE_CELERY app = create_celery_app(_settings.STORAGE_CELERY) diff --git a/services/storage/tests/conftest.py b/services/storage/tests/conftest.py index 03c73f5bfd0d..cfd4917703a5 100644 --- a/services/storage/tests/conftest.py +++ b/services/storage/tests/conftest.py @@ -7,6 +7,7 @@ import asyncio +import datetime import logging import random import sys @@ -20,6 +21,9 @@ import simcore_service_storage from asgi_lifespan import LifespanManager from aws_library.s3 import SimcoreS3API +from celery import Celery +from celery.contrib.testing.worker import TestWorkController, start_worker +from celery.signals import worker_init, worker_shutdown from faker import Faker from fakeredis.aioredis import FakeRedis from fastapi import FastAPI @@ -56,15 +60,22 @@ ) from pytest_simcore.helpers.typing_env import EnvVarsDict from servicelib.aiohttp import status +from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient from servicelib.utils import limited_gather from settings_library.rabbit import RabbitSettings from simcore_postgres_database.models.tokens import tokens from simcore_postgres_database.storage_models import file_meta_data, projects, users +from simcore_service_storage.api._worker_tasks.tasks import setup_worker_tasks from simcore_service_storage.core.application import create_app from simcore_service_storage.core.settings import ApplicationSettings from simcore_service_storage.datcore_dsm import DatCoreDataManager from simcore_service_storage.dsm import get_dsm_provider from simcore_service_storage.models import FileMetaData, FileMetaDataAtDB, S3BucketName +from simcore_service_storage.modules.celery.signals import ( + on_worker_init, + on_worker_shutdown, +) +from simcore_service_storage.modules.celery.worker import CeleryTaskQueueWorker from simcore_service_storage.modules.long_running_tasks import ( get_completed_upload_tasks, ) @@ -89,7 +100,6 @@ "pytest_simcore.environment_configs", "pytest_simcore.file_extra", "pytest_simcore.httpbin_service", - "pytest_simcore.minio_service", "pytest_simcore.openapi_specs", "pytest_simcore.postgres_service", "pytest_simcore.pytest_global_environs", @@ -188,6 +198,12 @@ def enabled_rabbitmq( return rabbit_service +@pytest.fixture +async def mocked_redis_server(mocker: MockerFixture) -> None: + mock_redis = FakeRedis() + mocker.patch("redis.asyncio.from_url", return_value=mock_redis) + + @pytest.fixture def app_settings( app_environment: EnvVarsDict, @@ -196,26 +212,22 @@ def app_settings( postgres_host_config: dict[str, str], mocked_s3_server_envs: EnvVarsDict, datcore_adapter_service_mock: respx.MockRouter, - mocked_redis_server, + mocked_redis_server: None, ) -> ApplicationSettings: test_app_settings = ApplicationSettings.create_from_envs() print(f"{test_app_settings.model_dump_json(indent=2)=}") return test_app_settings -@pytest.fixture -async 
def mocked_redis_server(mocker: MockerFixture) -> None: - mock_redis = FakeRedis() - mocker.patch("redis.asyncio.from_url", return_value=mock_redis) - - _LIFESPAN_TIMEOUT: Final[int] = 10 @pytest.fixture -async def initialized_app(app_settings: ApplicationSettings) -> AsyncIterator[FastAPI]: - settings = ApplicationSettings.create_from_envs() - app = create_app(settings) +async def initialized_app( + mock_celery_app: None, + app_settings: ApplicationSettings, +) -> AsyncIterator[FastAPI]: + app = create_app(app_settings) # NOTE: the timeout is sometime too small for CI machines, and even larger machines async with LifespanManager( app, startup_timeout=_LIFESPAN_TIMEOUT, shutdown_timeout=_LIFESPAN_TIMEOUT @@ -349,13 +361,13 @@ def upload_file( sqlalchemy_async_engine: AsyncEngine, storage_s3_client: SimcoreS3API, storage_s3_bucket: S3BucketName, - initialized_app: FastAPI, client: httpx.AsyncClient, project_id: ProjectID, node_id: NodeID, create_upload_file_link_v2: Callable[..., Awaitable[FileUploadSchema]], create_file_of_size: Callable[[ByteSize, str | None], Path], create_simcore_file_id: Callable[[ProjectID, NodeID, str], SimcoreS3FileID], + with_storage_celery_worker: CeleryTaskQueueWorker, ) -> Callable[ [ByteSize, str, SimcoreS3FileID | None], Awaitable[tuple[Path, SimcoreS3FileID]] ]: @@ -893,7 +905,9 @@ async def output_file( bucket=TypeAdapter(S3BucketName).validate_python("master-simcore"), location_id=SimcoreS3DataManager.get_location_id(), location_name=SimcoreS3DataManager.get_location_name(), - sha256_checksum=faker.sha256(), + sha256_checksum=TypeAdapter(SHA256Str).validate_python( + faker.sha256(raw_output=False) + ), ) file.entity_tag = "df9d868b94e53d18009066ca5cd90e9f" file.file_size = ByteSize(12) @@ -945,3 +959,100 @@ async def fake_datcore_tokens( await conn.execute( tokens.delete().where(tokens.c.token_id.in_(created_token_ids)) ) + + +@pytest.fixture(scope="session") +def celery_config() -> dict[str, Any]: + return { + "broker_connection_retry_on_startup": True, + "broker_url": "memory://localhost//", + "result_backend": "cache+memory://localhost//", + "result_expires": datetime.timedelta(days=7), + "result_extended": True, + "pool": "threads", + "worker_send_task_events": True, + "task_track_started": True, + "task_send_sent_event": True, + } + + +@pytest.fixture +def mock_celery_app(mocker: MockerFixture, celery_config: dict[str, Any]) -> Celery: + celery_app = Celery(**celery_config) + + for module in ( + "simcore_service_storage.modules.celery._common.create_app", + "simcore_service_storage.modules.celery.create_app", + ): + mocker.patch(module, return_value=celery_app) + + return celery_app + + +@pytest.fixture +def register_celery_tasks() -> Callable[[Celery], None]: + """override if tasks are needed""" + + def _(celery_app: Celery) -> None: ... 
+ + return _ + + +@pytest.fixture +async def with_storage_celery_worker_controller( + app_environment: EnvVarsDict, + celery_app: Celery, + monkeypatch: pytest.MonkeyPatch, + register_celery_tasks: Callable[[Celery], None], +) -> AsyncIterator[TestWorkController]: + # Signals must be explicitily connected + worker_init.connect(on_worker_init) + worker_shutdown.connect(on_worker_shutdown) + + setup_worker_tasks(celery_app) + register_celery_tasks(celery_app) + + monkeypatch.setenv("STORAGE_WORKER_MODE", "true") + with start_worker( + celery_app, + pool="threads", + concurrency=1, + loglevel="info", + perform_ping_check=False, + worker_kwargs={"hostname": "celery@worker1"}, + ) as worker: + worker_init.send(sender=worker) + + # NOTE: wait for worker to be ready (sic) + await asyncio.sleep(1) + yield worker + + worker_shutdown.send(sender=worker) + + +@pytest.fixture +def with_storage_celery_worker( + with_storage_celery_worker_controller: TestWorkController, +) -> CeleryTaskQueueWorker: + assert isinstance(with_storage_celery_worker_controller.app, Celery) + return CeleryTaskQueueWorker(with_storage_celery_worker_controller.app) + + +@pytest.fixture +async def storage_rabbitmq_rpc_client( + rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]], +) -> RabbitMQRPCClient: + rpc_client = await rabbitmq_rpc_client("pytest_storage_rpc_client") + assert rpc_client + return rpc_client + + +@pytest.fixture +def product_name(faker: Faker) -> str: + return faker.name() + + +@pytest.fixture +def set_log_levels_for_noisy_libraries() -> None: + # Reduce the log level for 'werkzeug' + logging.getLogger("werkzeug").setLevel(logging.WARNING) diff --git a/services/storage/tests/unit/modules/celery/conftest.py b/services/storage/tests/unit/modules/celery/conftest.py deleted file mode 100644 index 8bbb621ef0bd..000000000000 --- a/services/storage/tests/unit/modules/celery/conftest.py +++ /dev/null @@ -1,102 +0,0 @@ -from collections.abc import Callable, Iterable -from datetime import timedelta -from typing import Any - -import pytest -from celery import Celery -from celery.contrib.testing.worker import TestWorkController, start_worker -from celery.signals import worker_init, worker_shutdown -from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict -from pytest_simcore.helpers.typing_env import EnvVarsDict -from simcore_service_storage.modules.celery.client import CeleryTaskQueueClient -from simcore_service_storage.modules.celery.signals import ( - on_worker_init, - on_worker_shutdown, -) -from simcore_service_storage.modules.celery.worker import CeleryTaskQueueWorker - - -@pytest.fixture -def app_environment( - monkeypatch: pytest.MonkeyPatch, - app_environment: EnvVarsDict, -) -> EnvVarsDict: - return setenvs_from_dict( - monkeypatch, - { - **app_environment, - "SC_BOOT_MODE": "local-development", - "RABBIT_HOST": "localhost", - "RABBIT_PORT": "5672", - "RABBIT_USER": "mock", - "RABBIT_SECURE": True, - "RABBIT_PASSWORD": "", - }, - ) - - -@pytest.fixture -def celery_conf() -> dict[str, Any]: - return { - "broker_url": "memory://", - "result_backend": "cache+memory://", - "result_expires": timedelta(days=7), - "result_extended": True, - "pool": "threads", - "worker_send_task_events": True, - "task_track_started": True, - "task_send_sent_event": True, - } - - -@pytest.fixture -def celery_app(celery_conf: dict[str, Any]): - return Celery(**celery_conf) - - -@pytest.fixture -def register_celery_tasks() -> Callable[[Celery], None]: - msg = "please define a callback that registers the 
tasks" - raise NotImplementedError(msg) - - -@pytest.fixture -def celery_client( - app_environment: EnvVarsDict, celery_app: Celery -) -> CeleryTaskQueueClient: - return CeleryTaskQueueClient(celery_app) - - -@pytest.fixture -def celery_worker_controller( - app_environment: EnvVarsDict, - register_celery_tasks: Callable[[Celery], None], - celery_app: Celery, -) -> Iterable[TestWorkController]: - - # Signals must be explicitily connected - worker_init.connect(on_worker_init) - worker_shutdown.connect(on_worker_shutdown) - - register_celery_tasks(celery_app) - - with start_worker( - celery_app, - pool="threads", - loglevel="info", - perform_ping_check=False, - worker_kwargs={"hostname": "celery@worker1"}, - ) as worker: - worker_init.send(sender=worker) - - yield worker - - worker_shutdown.send(sender=worker) - - -@pytest.fixture -def celery_worker( - celery_worker_controller: TestWorkController, -) -> CeleryTaskQueueWorker: - assert isinstance(celery_worker_controller.app, Celery) - return CeleryTaskQueueWorker(celery_worker_controller.app) diff --git a/services/storage/tests/unit/test__worker_tasks_paths.py b/services/storage/tests/unit/test__worker_tasks_paths.py deleted file mode 100644 index cca024d97d48..000000000000 --- a/services/storage/tests/unit/test__worker_tasks_paths.py +++ /dev/null @@ -1,228 +0,0 @@ -# pylint:disable=no-name-in-module -# pylint:disable=protected-access -# pylint:disable=redefined-outer-name -# pylint:disable=too-many-arguments -# pylint:disable=too-many-positional-arguments -# pylint:disable=unused-argument -# pylint:disable=unused-variable - - -import random -from pathlib import Path -from typing import Any, TypeAlias - -import httpx -import pytest -from celery import Celery, Task -from faker import Faker -from fastapi import FastAPI -from models_library.projects_nodes_io import LocationID, NodeID, SimcoreS3FileID -from models_library.users import UserID -from pydantic import ByteSize, TypeAdapter -from pytest_simcore.helpers.storage_utils import FileIDDict, ProjectWithFilesParams -from simcore_service_storage.api._worker_tasks._paths import compute_path_size -from simcore_service_storage.modules.celery.models import TaskId -from simcore_service_storage.modules.celery.utils import set_fastapi_app -from simcore_service_storage.simcore_s3_dsm import SimcoreS3DataManager - -pytest_simcore_core_services_selection = ["postgres"] -pytest_simcore_ops_services_selection = ["adminer"] - -_IsFile: TypeAlias = bool - - -def _filter_and_group_paths_one_level_deeper( - paths: list[Path], prefix: Path -) -> list[tuple[Path, _IsFile]]: - relative_paths = (path for path in paths if path.is_relative_to(prefix)) - return sorted( - { - ( - (path, len(path.relative_to(prefix).parts) == 1) - if len(path.relative_to(prefix).parts) == 1 - else (prefix / path.relative_to(prefix).parts[0], False) - ) - for path in relative_paths - }, - key=lambda x: x[0], - ) - - -async def _assert_compute_path_size( - *, - celery_task: Task, - task_id: TaskId, - location_id: LocationID, - user_id: UserID, - path: Path, - expected_total_size: int, -) -> ByteSize: - response = await compute_path_size( - celery_task, - task_id=task_id, - user_id=user_id, - location_id=location_id, - path=path, - ) - assert isinstance(response, ByteSize) - assert response == expected_total_size - return response - - -@pytest.fixture -def fake_celery_task(celery_app: Celery, initialized_app: FastAPI) -> Task: - celery_task = Task() - celery_task.app = celery_app - set_fastapi_app(celery_app, initialized_app) - return 
celery_task - - -@pytest.mark.parametrize( - "location_id", - [SimcoreS3DataManager.get_location_id()], - ids=[SimcoreS3DataManager.get_location_name()], - indirect=True, -) -@pytest.mark.parametrize( - "project_params", - [ - ProjectWithFilesParams( - num_nodes=5, - allowed_file_sizes=(TypeAdapter(ByteSize).validate_python("1b"),), - workspace_files_count=10, - ) - ], - ids=str, -) -async def test_path_compute_size( - fake_celery_task: Task, - location_id: LocationID, - user_id: UserID, - with_random_project_with_files: tuple[ - dict[str, Any], - dict[NodeID, dict[SimcoreS3FileID, FileIDDict]], - ], - project_params: ProjectWithFilesParams, -): - assert ( - len(project_params.allowed_file_sizes) == 1 - ), "test preconditions are not filled! allowed file sizes should have only 1 option for this test" - project, list_of_files = with_random_project_with_files - - total_num_files = sum( - len(files_in_node) for files_in_node in list_of_files.values() - ) - - # get size of a full project - expected_total_size = project_params.allowed_file_sizes[0] * total_num_files - path = Path(project["uuid"]) - await _assert_compute_path_size( - celery_task=fake_celery_task, - task_id=TaskId("fake_task"), - location_id=location_id, - user_id=user_id, - path=path, - expected_total_size=expected_total_size, - ) - - # get size of one of the nodes - selected_node_id = NodeID(random.choice(list(project["workbench"]))) # noqa: S311 - path = Path(project["uuid"]) / f"{selected_node_id}" - selected_node_s3_keys = [ - Path(s3_object_id) for s3_object_id in list_of_files[selected_node_id] - ] - expected_total_size = project_params.allowed_file_sizes[0] * len( - selected_node_s3_keys - ) - await _assert_compute_path_size( - celery_task=fake_celery_task, - task_id=TaskId("fake_task"), - location_id=location_id, - user_id=user_id, - path=path, - expected_total_size=expected_total_size, - ) - - # get size of the outputs of one of the nodes - path = Path(project["uuid"]) / f"{selected_node_id}" / "outputs" - selected_node_s3_keys = [ - Path(s3_object_id) - for s3_object_id in list_of_files[selected_node_id] - if s3_object_id.startswith(f"{path}") - ] - expected_total_size = project_params.allowed_file_sizes[0] * len( - selected_node_s3_keys - ) - await _assert_compute_path_size( - celery_task=fake_celery_task, - task_id=TaskId("fake_task"), - location_id=location_id, - user_id=user_id, - path=path, - expected_total_size=expected_total_size, - ) - - # get size of workspace in one of the nodes (this is semi-cached in the DB) - path = Path(project["uuid"]) / f"{selected_node_id}" / "workspace" - selected_node_s3_keys = [ - Path(s3_object_id) - for s3_object_id in list_of_files[selected_node_id] - if s3_object_id.startswith(f"{path}") - ] - expected_total_size = project_params.allowed_file_sizes[0] * len( - selected_node_s3_keys - ) - workspace_total_size = await _assert_compute_path_size( - celery_task=fake_celery_task, - task_id=TaskId("fake_task"), - location_id=location_id, - user_id=user_id, - path=path, - expected_total_size=expected_total_size, - ) - - # get size of folders inside the workspace - folders_inside_workspace = [ - p[0] - for p in _filter_and_group_paths_one_level_deeper(selected_node_s3_keys, path) - if p[1] is False - ] - accumulated_subfolder_size = 0 - for workspace_subfolder in folders_inside_workspace: - selected_node_s3_keys = [ - Path(s3_object_id) - for s3_object_id in list_of_files[selected_node_id] - if s3_object_id.startswith(f"{workspace_subfolder}") - ] - expected_total_size = 
project_params.allowed_file_sizes[0] * len( - selected_node_s3_keys - ) - accumulated_subfolder_size += await _assert_compute_path_size( - celery_task=fake_celery_task, - task_id=TaskId("fake_task"), - location_id=location_id, - user_id=user_id, - path=workspace_subfolder, - expected_total_size=expected_total_size, - ) - - assert workspace_total_size == accumulated_subfolder_size - - -async def test_path_compute_size_inexistent_path( - fake_celery_task: Task, - initialized_app: FastAPI, - client: httpx.AsyncClient, - location_id: LocationID, - user_id: UserID, - faker: Faker, - fake_datcore_tokens: tuple[str, str], -): - await _assert_compute_path_size( - celery_task=fake_celery_task, - task_id=TaskId("fake_task"), - location_id=location_id, - user_id=user_id, - path=Path(faker.file_path(absolute=False)), - expected_total_size=0, - ) diff --git a/services/storage/tests/unit/test_data_export.py b/services/storage/tests/unit/test_data_export.py index 6e4c2ca412a8..982ec1cabd1f 100644 --- a/services/storage/tests/unit/test_data_export.py +++ b/services/storage/tests/unit/test_data_export.py @@ -1,7 +1,6 @@ # pylint: disable=W0621 # pylint: disable=W0613 # pylint: disable=R6301 -from collections.abc import Awaitable, Callable from dataclasses import dataclass from pathlib import Path from typing import Any, Literal, NamedTuple @@ -34,15 +33,11 @@ from models_library.users import UserID from pydantic import ByteSize, TypeAdapter from pytest_mock import MockerFixture -from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict from pytest_simcore.helpers.storage_utils import FileIDDict, ProjectWithFilesParams -from pytest_simcore.helpers.typing_env import EnvVarsDict from servicelib.rabbitmq import RabbitMQRPCClient from servicelib.rabbitmq.rpc_interfaces.async_jobs import async_jobs from servicelib.rabbitmq.rpc_interfaces.storage.data_export import start_data_export -from settings_library.rabbit import RabbitSettings from simcore_service_storage.api.rpc._data_export import AccessRightError -from simcore_service_storage.core.settings import ApplicationSettings from simcore_service_storage.modules.celery.client import TaskUUID from simcore_service_storage.modules.celery.models import TaskState, TaskStatus from simcore_service_storage.simcore_s3_dsm import SimcoreS3DataManager @@ -122,38 +117,6 @@ async def mock_celery_client( return _celery_client -@pytest.fixture -async def app_environment( - app_environment: EnvVarsDict, - rabbit_service: RabbitSettings, - monkeypatch: pytest.MonkeyPatch, -): - new_envs = setenvs_from_dict( - monkeypatch, - { - **app_environment, - "RABBIT_HOST": rabbit_service.RABBIT_HOST, - "RABBIT_PORT": f"{rabbit_service.RABBIT_PORT}", - "RABBIT_USER": rabbit_service.RABBIT_USER, - "RABBIT_SECURE": f"{rabbit_service.RABBIT_SECURE}", - "RABBIT_PASSWORD": rabbit_service.RABBIT_PASSWORD.get_secret_value(), - }, - ) - - settings = ApplicationSettings.create_from_envs() - assert settings.STORAGE_RABBITMQ - - return new_envs - - -@pytest.fixture -async def rpc_client( - initialized_app: FastAPI, - rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]], -) -> RabbitMQRPCClient: - return await rabbitmq_rpc_client("client") - - class UserWithFile(NamedTuple): user: UserID file: Path @@ -195,7 +158,8 @@ class UserWithFile(NamedTuple): indirect=True, ) async def test_start_data_export_success( - rpc_client: RabbitMQRPCClient, + initialized_app: FastAPI, + storage_rabbitmq_rpc_client: RabbitMQRPCClient, mock_celery_client: _MockCeleryClient, 
with_random_project_with_files: tuple[ dict[str, Any], @@ -223,7 +187,7 @@ async def test_start_data_export_success( pytest.fail(f"invalid parameter: {selection_type=}") result = await start_data_export( - rpc_client, + storage_rabbitmq_rpc_client, job_id_data=AsyncJobNameData(user_id=user_id, product_name="osparc"), data_export_start=DataExportTaskStartInput( location_id=0, @@ -258,7 +222,8 @@ async def test_start_data_export_success( indirect=True, ) async def test_start_data_export_scheduler_error( - rpc_client: RabbitMQRPCClient, + initialized_app: FastAPI, + storage_rabbitmq_rpc_client: RabbitMQRPCClient, mock_celery_client: _MockCeleryClient, with_random_project_with_files: tuple[ dict[str, Any], @@ -266,7 +231,6 @@ async def test_start_data_export_scheduler_error( ], user_id: UserID, ): - _, list_of_files = with_random_project_with_files workspace_files = [ p for p in list(list_of_files.values())[0].keys() if "/workspace/" in p @@ -276,7 +240,7 @@ async def test_start_data_export_scheduler_error( with pytest.raises(JobSchedulerError): _ = await start_data_export( - rpc_client, + storage_rabbitmq_rpc_client, job_id_data=AsyncJobNameData(user_id=user_id, product_name="osparc"), data_export_start=DataExportTaskStartInput( location_id=0, @@ -293,14 +257,15 @@ async def test_start_data_export_scheduler_error( indirect=True, ) async def test_start_data_export_access_error( - rpc_client: RabbitMQRPCClient, + initialized_app: FastAPI, + storage_rabbitmq_rpc_client: RabbitMQRPCClient, mock_celery_client: _MockCeleryClient, user_id: UserID, faker: Faker, ): with pytest.raises(AccessRightError): _ = await async_jobs.submit( - rpc_client, + storage_rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, method_name="start_data_export", job_id_data=AsyncJobNameData(user_id=user_id, product_name="osparc"), @@ -324,13 +289,14 @@ async def test_start_data_export_access_error( indirect=True, ) async def test_abort_data_export_success( - rpc_client: RabbitMQRPCClient, + initialized_app: FastAPI, + storage_rabbitmq_rpc_client: RabbitMQRPCClient, mock_celery_client: _MockCeleryClient, ): assert mock_celery_client.get_task_uuids_object is not None assert not isinstance(mock_celery_client.get_task_uuids_object, Exception) await async_jobs.cancel( - rpc_client, + storage_rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, job_id_data=AsyncJobNameData( user_id=_faker.pyint(min_value=1, max_value=100), product_name="osparc" @@ -353,7 +319,8 @@ async def test_abort_data_export_success( indirect=["mock_celery_client"], ) async def test_abort_data_export_error( - rpc_client: RabbitMQRPCClient, + initialized_app: FastAPI, + storage_rabbitmq_rpc_client: RabbitMQRPCClient, mock_celery_client: _MockCeleryClient, expected_exception_type: type[Exception], ): @@ -363,7 +330,7 @@ async def test_abort_data_export_error( _job_id = next(iter(job_ids)) if len(job_ids) > 0 else AsyncJobId(_faker.uuid4()) with pytest.raises(expected_exception_type): await async_jobs.cancel( - rpc_client, + storage_rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, job_id_data=AsyncJobNameData( user_id=_faker.pyint(min_value=1, max_value=100), product_name="osparc" @@ -395,7 +362,8 @@ async def test_abort_data_export_error( indirect=True, ) async def test_get_data_export_status( - rpc_client: RabbitMQRPCClient, + initialized_app: FastAPI, + storage_rabbitmq_rpc_client: RabbitMQRPCClient, mock_celery_client: _MockCeleryClient, ): job_ids = mock_celery_client.get_task_uuids_object @@ -403,7 +371,7 @@ async def 
test_get_data_export_status( assert not isinstance(job_ids, Exception) _job_id = next(iter(job_ids)) if len(job_ids) > 0 else AsyncJobId(_faker.uuid4()) result = await async_jobs.status( - rpc_client, + storage_rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, job_id=_job_id, job_id_data=AsyncJobNameData( @@ -428,7 +396,8 @@ async def test_get_data_export_status( indirect=["mock_celery_client"], ) async def test_get_data_export_status_error( - rpc_client: RabbitMQRPCClient, + initialized_app: FastAPI, + storage_rabbitmq_rpc_client: RabbitMQRPCClient, mock_celery_client: _MockCeleryClient, expected_exception_type: type[Exception], ): @@ -438,7 +407,7 @@ async def test_get_data_export_status_error( _job_id = next(iter(job_ids)) if len(job_ids) > 0 else AsyncJobId(_faker.uuid4()) with pytest.raises(expected_exception_type): _ = await async_jobs.status( - rpc_client, + storage_rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, job_id=_job_id, job_id_data=AsyncJobNameData( @@ -463,7 +432,8 @@ async def test_get_data_export_status_error( indirect=True, ) async def test_get_data_export_result_success( - rpc_client: RabbitMQRPCClient, + initialized_app: FastAPI, + storage_rabbitmq_rpc_client: RabbitMQRPCClient, mock_celery_client: _MockCeleryClient, ): job_ids = mock_celery_client.get_task_uuids_object @@ -471,7 +441,7 @@ async def test_get_data_export_result_success( assert not isinstance(job_ids, Exception) _job_id = next(iter(job_ids)) if len(job_ids) > 0 else AsyncJobId(_faker.uuid4()) result = await async_jobs.result( - rpc_client, + storage_rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, job_id=_job_id, job_id_data=AsyncJobNameData( @@ -543,7 +513,8 @@ async def test_get_data_export_result_success( indirect=["mock_celery_client"], ) async def test_get_data_export_result_error( - rpc_client: RabbitMQRPCClient, + initialized_app: FastAPI, + storage_rabbitmq_rpc_client: RabbitMQRPCClient, mock_celery_client: _MockCeleryClient, expected_exception: type[Exception], ): @@ -554,7 +525,7 @@ async def test_get_data_export_result_error( with pytest.raises(expected_exception): _ = await async_jobs.result( - rpc_client, + storage_rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, job_id=_job_id, job_id_data=AsyncJobNameData( @@ -571,11 +542,12 @@ async def test_get_data_export_result_error( indirect=True, ) async def test_list_jobs_success( - rpc_client: RabbitMQRPCClient, + initialized_app: FastAPI, + storage_rabbitmq_rpc_client: RabbitMQRPCClient, mock_celery_client: MockerFixture, ): result = await async_jobs.list_jobs( - rpc_client, + storage_rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, job_id_data=AsyncJobNameData( user_id=_faker.pyint(min_value=1, max_value=100), product_name="osparc" @@ -594,12 +566,13 @@ async def test_list_jobs_success( indirect=True, ) async def test_list_jobs_error( - rpc_client: RabbitMQRPCClient, + initialized_app: FastAPI, + storage_rabbitmq_rpc_client: RabbitMQRPCClient, mock_celery_client: MockerFixture, ): with pytest.raises(JobSchedulerError): _ = await async_jobs.list_jobs( - rpc_client, + storage_rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, job_id_data=AsyncJobNameData( user_id=_faker.pyint(min_value=1, max_value=100), product_name="osparc" diff --git a/services/storage/tests/unit/test_dsm.py b/services/storage/tests/unit/test_dsm.py deleted file mode 100644 index e99dfda19161..000000000000 --- a/services/storage/tests/unit/test_dsm.py +++ /dev/null @@ -1,44 +0,0 @@ -# pylint: disable=unused-variable -# 
pylint: disable=unused-argument -# pylint: disable=redefined-outer-name -# pylint: disable=protected-access - -from collections.abc import Awaitable, Callable -from pathlib import Path - -import pytest -from faker import Faker -from models_library.projects_nodes_io import SimcoreS3FileID -from models_library.users import UserID -from pydantic import ByteSize, TypeAdapter -from servicelib.utils import limited_gather -from simcore_service_storage.models import FileMetaData -from simcore_service_storage.simcore_s3_dsm import SimcoreS3DataManager - -pytest_simcore_core_services_selection = ["postgres"] -pytest_simcore_ops_services_selection = ["adminer"] - - -@pytest.fixture -async def dsm_mockup_complete_db( - simcore_s3_dsm: SimcoreS3DataManager, - user_id: UserID, - upload_file: Callable[ - [ByteSize, str, SimcoreS3FileID | None], - Awaitable[tuple[Path, SimcoreS3FileID]], - ], - cleanup_user_projects_file_metadata: None, - faker: Faker, -) -> tuple[FileMetaData, FileMetaData]: - file_size = TypeAdapter(ByteSize).validate_python("10Mib") - uploaded_files = await limited_gather( - *(upload_file(file_size, faker.file_name(), None) for _ in range(2)), - limit=2, - ) - fmds = await limited_gather( - *(simcore_s3_dsm.get_file(user_id, file_id) for _, file_id in uploaded_files), - limit=0, - ) - assert len(fmds) == 2 - - return (fmds[0], fmds[1]) diff --git a/services/storage/tests/unit/test_handlers_paths.py b/services/storage/tests/unit/test_handlers_paths.py index 31cb5c850617..6997bb5bf7dd 100644 --- a/services/storage/tests/unit/test_handlers_paths.py +++ b/services/storage/tests/unit/test_handlers_paths.py @@ -34,7 +34,7 @@ from simcore_service_storage.simcore_s3_dsm import SimcoreS3DataManager from sqlalchemy.ext.asyncio import AsyncEngine -pytest_simcore_core_services_selection = ["postgres"] +pytest_simcore_core_services_selection = ["postgres", "rabbit"] pytest_simcore_ops_services_selection = ["adminer"] _IsFile: TypeAlias = bool diff --git a/services/storage/tests/unit/test_handlers_simcore_s3.py b/services/storage/tests/unit/test_handlers_simcore_s3.py index d3768fd09eb7..ff43db81f48d 100644 --- a/services/storage/tests/unit/test_handlers_simcore_s3.py +++ b/services/storage/tests/unit/test_handlers_simcore_s3.py @@ -41,7 +41,6 @@ assert_file_meta_data_in_db, ) from pytest_simcore.helpers.storage_utils_project import clone_project_data -from pytest_simcore.helpers.typing_env import EnvVarsDict from servicelib.aiohttp import status from servicelib.fastapi.long_running_tasks.client import long_running_task_request from settings_library.s3 import S3Settings @@ -540,7 +539,6 @@ async def test_create_and_delete_folders_from_project( @pytest.mark.parametrize("num_concurrent_calls", [50]) async def test_create_and_delete_folders_from_project_burst( set_log_levels_for_noisy_libraries: None, - minio_s3_settings_envs: EnvVarsDict, initialized_app: FastAPI, client: httpx.AsyncClient, user_id: UserID, diff --git a/services/storage/tests/unit/modules/celery/test_celery.py b/services/storage/tests/unit/test_modules_celery.py similarity index 86% rename from services/storage/tests/unit/modules/celery/test_celery.py rename to services/storage/tests/unit/test_modules_celery.py index c6c4e53135f6..6cbf63b05685 100644 --- a/services/storage/tests/unit/modules/celery/test_celery.py +++ b/services/storage/tests/unit/test_modules_celery.py @@ -1,3 +1,9 @@ +# pylint: disable=protected-access +# pylint: disable=redefined-outer-name +# pylint: disable=too-many-arguments +# pylint: disable=unused-argument +# 
pylint: disable=unused-variable + import asyncio import logging import time @@ -8,10 +14,11 @@ from celery import Celery, Task from celery.contrib.abortable import AbortableTask from common_library.errors_classes import OsparcErrorMixin +from fastapi import FastAPI from models_library.progress_bar import ProgressReport from pydantic import TypeAdapter, ValidationError from servicelib.logging_utils import log_context -from simcore_service_storage.modules.celery import get_event_loop +from simcore_service_storage.modules.celery import get_celery_client, get_event_loop from simcore_service_storage.modules.celery._task import define_task from simcore_service_storage.modules.celery.client import CeleryTaskQueueClient from simcore_service_storage.modules.celery.models import ( @@ -23,10 +30,21 @@ get_celery_worker, get_fastapi_app, ) +from simcore_service_storage.modules.celery.worker import CeleryTaskQueueWorker from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed _logger = logging.getLogger(__name__) +pytest_simcore_core_services_selection = ["postgres", "rabbit"] +pytest_simcore_ops_services_selection = [] + + +@pytest.fixture +def celery_client( + initialized_app: FastAPI, +) -> CeleryTaskQueueClient: + return get_celery_client(initialized_app) + async def _async_archive( celery_app: Celery, task_name: str, task_id: str, files: list[str] @@ -62,6 +80,7 @@ class MyError(OsparcErrorMixin, Exception): def failure_task(task: Task): + assert task msg = "BOOM!" raise MyError(msg=msg) @@ -73,7 +92,7 @@ def dreamer_task(task: AbortableTask) -> list[int]: _logger.warning("Alarm clock") return numbers numbers.append(randint(1, 90)) # noqa: S311 - time.sleep(1) + time.sleep(0.1) return numbers @@ -87,9 +106,9 @@ def _(celery_app: Celery) -> None: return _ -@pytest.mark.usefixtures("celery_worker") async def test_submitting_task_calling_async_function_results_with_success_state( celery_client: CeleryTaskQueueClient, + with_storage_celery_worker: CeleryTaskQueueWorker, ): task_context = TaskContext(user_id=42) @@ -116,9 +135,9 @@ async def test_submitting_task_calling_async_function_results_with_success_state ) == "archive.zip" -@pytest.mark.usefixtures("celery_worker") async def test_submitting_task_with_failure_results_with_error( celery_client: CeleryTaskQueueClient, + with_storage_celery_worker: CeleryTaskQueueWorker, ): task_context = TaskContext(user_id=42) @@ -142,9 +161,9 @@ async def test_submitting_task_with_failure_results_with_error( assert f"{result.exc_msg}" == "Something strange happened: BOOM!" 
-@pytest.mark.usefixtures("celery_worker") async def test_aborting_task_results_with_aborted_state( celery_client: CeleryTaskQueueClient, + with_storage_celery_worker: CeleryTaskQueueWorker, ): task_context = TaskContext(user_id=42) @@ -169,9 +188,9 @@ async def test_aborting_task_results_with_aborted_state( ).task_state == TaskState.ABORTED -@pytest.mark.usefixtures("celery_worker") async def test_listing_task_uuids_contains_submitted_task( celery_client: CeleryTaskQueueClient, + with_storage_celery_worker: CeleryTaskQueueWorker, ): task_context = TaskContext(user_id=42) @@ -182,7 +201,7 @@ async def test_listing_task_uuids_contains_submitted_task( for attempt in Retrying( retry=retry_if_exception_type(AssertionError), - wait=wait_fixed(1), + wait=wait_fixed(0.1), stop=stop_after_delay(10), ): with attempt: diff --git a/services/storage/tests/unit/test_rpc_handlers_paths.py b/services/storage/tests/unit/test_rpc_handlers_paths.py index ee6787b22a35..ef345c723e1c 100644 --- a/services/storage/tests/unit/test_rpc_handlers_paths.py +++ b/services/storage/tests/unit/test_rpc_handlers_paths.py @@ -7,42 +7,90 @@ # pylint:disable=unused-variable -from collections.abc import Awaitable, Callable +import asyncio +import datetime +import random from pathlib import Path -from unittest import mock +from typing import Any, TypeAlias import pytest from faker import Faker from fastapi import FastAPI -from models_library.projects_nodes_io import LocationID +from models_library.api_schemas_rpc_async_jobs.async_jobs import ( + AsyncJobNameData, + AsyncJobResult, +) +from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE +from models_library.products import ProductName +from models_library.projects_nodes_io import LocationID, NodeID, SimcoreS3FileID +from models_library.rabbitmq_basic_types import RPCMethodName from models_library.users import UserID -from pytest_mock import MockerFixture +from pydantic import ByteSize, TypeAdapter +from pytest_simcore.helpers.storage_utils import FileIDDict, ProjectWithFilesParams from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient +from servicelib.rabbitmq.rpc_interfaces.async_jobs.async_jobs import ( + wait_and_get_result, +) from servicelib.rabbitmq.rpc_interfaces.storage.paths import compute_path_size +from simcore_service_storage.modules.celery.worker import CeleryTaskQueueWorker from simcore_service_storage.simcore_s3_dsm import SimcoreS3DataManager pytest_simcore_core_services_selection = ["postgres", "rabbit"] pytest_simcore_ops_services_selection = ["adminer"] +_IsFile: TypeAlias = bool -@pytest.fixture -async def storage_rabbitmq_rpc_client( - rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]], -) -> RabbitMQRPCClient: - rpc_client = await rabbitmq_rpc_client("pytest_storage_rpc_client") - assert rpc_client - return rpc_client +def _filter_and_group_paths_one_level_deeper( + paths: list[Path], prefix: Path +) -> list[tuple[Path, _IsFile]]: + relative_paths = (path for path in paths if path.is_relative_to(prefix)) + return sorted( + { + ( + (path, len(path.relative_to(prefix).parts) == 1) + if len(path.relative_to(prefix).parts) == 1 + else (prefix / path.relative_to(prefix).parts[0], False) + ) + for path in relative_paths + }, + key=lambda x: x[0], + ) -@pytest.fixture -async def mock_celery_send_task(mocker: MockerFixture, faker: Faker) -> mock.AsyncMock: - def mocked_send_task(*args, **kwargs): - return faker.uuid4() - return mocker.patch( - "simcore_service_storage.modules.celery.client.CeleryTaskQueueClient.send_task", 
- side_effect=mocked_send_task, +async def _assert_compute_path_size( + storage_rpc_client: RabbitMQRPCClient, + location_id: LocationID, + user_id: UserID, + product_name: ProductName, + *, + path: Path, + expected_total_size: int, +) -> ByteSize: + async_job, async_job_name = await compute_path_size( + storage_rpc_client, + product_name=product_name, + user_id=user_id, + location_id=location_id, + path=path, ) + await asyncio.sleep(1) + async for job_composed_result in wait_and_get_result( + storage_rpc_client, + rpc_namespace=STORAGE_RPC_NAMESPACE, + method_name=RPCMethodName(compute_path_size.__name__), + job_id=async_job.job_id, + job_id_data=AsyncJobNameData(user_id=user_id, product_name=product_name), + client_timeout=datetime.timedelta(seconds=120), + ): + if job_composed_result.done: + response = await job_composed_result.result() + assert isinstance(response, AsyncJobResult) + received_size = TypeAdapter(ByteSize).validate_python(response.result) + assert received_size == expected_total_size + return received_size + + pytest.fail("Job did not finish") @pytest.mark.parametrize( @@ -51,21 +99,150 @@ def mocked_send_task(*args, **kwargs): ids=[SimcoreS3DataManager.get_location_name()], indirect=True, ) -async def test_path_compute_size_calls_in_celery( +@pytest.mark.parametrize( + "project_params", + [ + ProjectWithFilesParams( + num_nodes=5, + allowed_file_sizes=(TypeAdapter(ByteSize).validate_python("1b"),), + workspace_files_count=10, + ) + ], + ids=str, +) +async def test_path_compute_size( + initialized_app: FastAPI, + storage_rabbitmq_rpc_client: RabbitMQRPCClient, + user_id: UserID, + location_id: LocationID, + with_random_project_with_files: tuple[ + dict[str, Any], + dict[NodeID, dict[SimcoreS3FileID, FileIDDict]], + ], + project_params: ProjectWithFilesParams, + product_name: ProductName, +): + assert ( + len(project_params.allowed_file_sizes) == 1 + ), "test preconditions are not filled! 
allowed file sizes should have only 1 option for this test" + project, list_of_files = with_random_project_with_files + + total_num_files = sum( + len(files_in_node) for files_in_node in list_of_files.values() + ) + + # get size of a full project + expected_total_size = project_params.allowed_file_sizes[0] * total_num_files + path = Path(project["uuid"]) + await _assert_compute_path_size( + storage_rabbitmq_rpc_client, + location_id, + user_id, + path=path, + expected_total_size=expected_total_size, + product_name=product_name, + ) + + # get size of one of the nodes + selected_node_id = NodeID(random.choice(list(project["workbench"]))) # noqa: S311 + path = Path(project["uuid"]) / f"{selected_node_id}" + selected_node_s3_keys = [ + Path(s3_object_id) for s3_object_id in list_of_files[selected_node_id] + ] + expected_total_size = project_params.allowed_file_sizes[0] * len( + selected_node_s3_keys + ) + await _assert_compute_path_size( + storage_rabbitmq_rpc_client, + location_id, + user_id, + path=path, + expected_total_size=expected_total_size, + product_name=product_name, + ) + + # get size of the outputs of one of the nodes + path = Path(project["uuid"]) / f"{selected_node_id}" / "outputs" + selected_node_s3_keys = [ + Path(s3_object_id) + for s3_object_id in list_of_files[selected_node_id] + if s3_object_id.startswith(f"{path}") + ] + expected_total_size = project_params.allowed_file_sizes[0] * len( + selected_node_s3_keys + ) + await _assert_compute_path_size( + storage_rabbitmq_rpc_client, + location_id, + user_id, + path=path, + expected_total_size=expected_total_size, + product_name=product_name, + ) + + # get size of workspace in one of the nodes (this is semi-cached in the DB) + path = Path(project["uuid"]) / f"{selected_node_id}" / "workspace" + selected_node_s3_keys = [ + Path(s3_object_id) + for s3_object_id in list_of_files[selected_node_id] + if s3_object_id.startswith(f"{path}") + ] + expected_total_size = project_params.allowed_file_sizes[0] * len( + selected_node_s3_keys + ) + workspace_total_size = await _assert_compute_path_size( + storage_rabbitmq_rpc_client, + location_id, + user_id, + path=path, + expected_total_size=expected_total_size, + product_name=product_name, + ) + + # get size of folders inside the workspace + folders_inside_workspace = [ + p[0] + for p in _filter_and_group_paths_one_level_deeper(selected_node_s3_keys, path) + if p[1] is False + ] + accumulated_subfolder_size = 0 + for workspace_subfolder in folders_inside_workspace: + selected_node_s3_keys = [ + Path(s3_object_id) + for s3_object_id in list_of_files[selected_node_id] + if s3_object_id.startswith(f"{workspace_subfolder}") + ] + expected_total_size = project_params.allowed_file_sizes[0] * len( + selected_node_s3_keys + ) + accumulated_subfolder_size += await _assert_compute_path_size( + storage_rabbitmq_rpc_client, + location_id, + user_id, + path=workspace_subfolder, + expected_total_size=expected_total_size, + product_name=product_name, + ) + + assert workspace_total_size == accumulated_subfolder_size + + +async def test_path_compute_size_inexistent_path( + mock_celery_app: None, initialized_app: FastAPI, storage_rabbitmq_rpc_client: RabbitMQRPCClient, + with_storage_celery_worker: CeleryTaskQueueWorker, location_id: LocationID, user_id: UserID, faker: Faker, - mock_celery_send_task: mock.AsyncMock, + fake_datcore_tokens: tuple[str, str], + product_name: ProductName, ): - received, job_id_data = await compute_path_size( + await _assert_compute_path_size( storage_rabbitmq_rpc_client, - 
user_id=user_id, - product_name=faker.name(), - location_id=location_id, + location_id, + user_id, path=Path(faker.file_path(absolute=False)), + expected_total_size=0, + product_name=product_name, ) - mock_celery_send_task.assert_called_once() - assert received - assert job_id_data
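
Usage note (not part of the patch above): a minimal sketch of how a caller is expected to combine the split submit / wait_and_get_result helpers introduced in servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py. The method name, kwargs and timeout are illustrative assumptions; a real caller passes the kwargs of whichever RPC method it submits (see test_data_export.py and test_rpc_handlers_paths.py above), and cancellation/timeout handling is omitted for brevity.

import datetime
from typing import Any

from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobNameData
from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE
from servicelib.rabbitmq import RabbitMQRPCClient
from servicelib.rabbitmq.rpc_interfaces.async_jobs.async_jobs import (
    submit,
    wait_and_get_result,
)


async def run_async_job(rpc_client: RabbitMQRPCClient, user_id: int) -> Any:
    # Illustrative values only; product_name/user_id come from the caller's context.
    job_id_data = AsyncJobNameData(user_id=user_id, product_name="osparc")

    # 1. submit: schedules the job and returns an AsyncJobGet carrying its job_id
    async_job_get = await submit(
        rpc_client,
        rpc_namespace=STORAGE_RPC_NAMESPACE,
        method_name="start_data_export",  # illustrative method name
        job_id_data=job_id_data,
        # ...plus the submitted method's own kwargs (assumed, omitted here)
    )

    # 2. wait_and_get_result: yields intermediate statuses, then the composed result
    async for composed in wait_and_get_result(
        rpc_client,
        rpc_namespace=STORAGE_RPC_NAMESPACE,
        method_name="start_data_export",
        job_id=async_job_get.job_id,
        job_id_data=job_id_data,
        client_timeout=datetime.timedelta(minutes=2),
    ):
        if composed.done:
            return await composed.result()
    return None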