diff --git a/packages/celery-library/src/celery_library/common.py b/packages/celery-library/src/celery_library/app.py similarity index 100% rename from packages/celery-library/src/celery_library/common.py rename to packages/celery-library/src/celery_library/app.py diff --git a/packages/celery-library/src/celery_library/signals.py b/packages/celery-library/src/celery_library/signals.py deleted file mode 100644 index 02f1a56f0ec2..000000000000 --- a/packages/celery-library/src/celery_library/signals.py +++ /dev/null @@ -1,52 +0,0 @@ -import asyncio -import logging -import threading - -from celery import Celery # type: ignore[import-untyped] -from celery.worker.worker import WorkController # type: ignore[import-untyped] -from servicelib.celery.app_server import BaseAppServer -from servicelib.logging_utils import log_context - -from .utils import get_app_server, set_app_server - -_logger = logging.getLogger(__name__) - - -def on_worker_init( - sender: WorkController, - app_server: BaseAppServer, - **_kwargs, -) -> None: - startup_complete_event = threading.Event() - - def _init(startup_complete_event: threading.Event) -> None: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - assert sender.app # nosec - assert isinstance(sender.app, Celery) # nosec - - set_app_server(sender.app, app_server) - - app_server.event_loop = loop - - loop.run_until_complete(app_server.run_until_shutdown(startup_complete_event)) - - thread = threading.Thread( - group=None, - target=_init, - name="app_server_init", - args=(startup_complete_event,), - daemon=True, - ) - thread.start() - - startup_complete_event.wait() - - -def on_worker_shutdown(sender, **_kwargs) -> None: - with log_context(_logger, logging.INFO, "Worker shutdown"): - assert isinstance(sender.app, Celery) - app_server = get_app_server(sender.app) - - app_server.shutdown_event.set() diff --git a/packages/celery-library/src/celery_library/task.py b/packages/celery-library/src/celery_library/task.py index 
f01bf301347c..dfcc9f889a0b 100644 --- a/packages/celery-library/src/celery_library/task.py +++ b/packages/celery-library/src/celery_library/task.py @@ -13,7 +13,7 @@ from servicelib.celery.models import TaskKey from .errors import encode_celery_transferrable_error -from .utils import get_app_server +from .worker.app_server import get_app_server _logger = logging.getLogger(__name__) diff --git a/services/api-server/src/simcore_service_api_server/celery_worker/__init__.py b/packages/celery-library/src/celery_library/worker/__init__.py similarity index 100% rename from services/api-server/src/simcore_service_api_server/celery_worker/__init__.py rename to packages/celery-library/src/celery_library/worker/__init__.py diff --git a/packages/celery-library/src/celery_library/worker/app.py b/packages/celery-library/src/celery_library/worker/app.py new file mode 100644 index 000000000000..9c243f0e3961 --- /dev/null +++ b/packages/celery-library/src/celery_library/worker/app.py @@ -0,0 +1,20 @@ +from collections.abc import Callable + +from celery import Celery # type: ignore[import-untyped] +from servicelib.celery.app_server import BaseAppServer +from settings_library.celery import CelerySettings + +from ..app import create_app +from .signals import register_worker_signals + + +def create_worker_app( + settings: CelerySettings, + register_worker_tasks_cb: Callable[[Celery], None], + app_server_factory_cb: Callable[[], BaseAppServer], +) -> Celery: + app = create_app(settings) + register_worker_tasks_cb(app) + register_worker_signals(app, settings, app_server_factory_cb) + + return app diff --git a/packages/celery-library/src/celery_library/utils.py b/packages/celery-library/src/celery_library/worker/app_server.py similarity index 100% rename from packages/celery-library/src/celery_library/utils.py rename to packages/celery-library/src/celery_library/worker/app_server.py diff --git a/packages/celery-library/src/celery_library/worker/signals.py 
b/packages/celery-library/src/celery_library/worker/signals.py new file mode 100644 index 000000000000..042122d9586f --- /dev/null +++ b/packages/celery-library/src/celery_library/worker/signals.py @@ -0,0 +1,71 @@ +import asyncio +import threading +from collections.abc import Callable + +from celery import Celery # type: ignore[import-untyped] +from celery.signals import ( # type: ignore[import-untyped] + worker_init, + worker_process_init, + worker_process_shutdown, + worker_shutdown, +) +from servicelib.celery.app_server import BaseAppServer +from settings_library.celery import CeleryPoolType, CelerySettings + +from .app_server import get_app_server, set_app_server + + +def _worker_init_wrapper( + app: Celery, app_server_factory: Callable[[], BaseAppServer] +) -> Callable[..., None]: + def _worker_init_handler(**_kwargs) -> None: + startup_complete_event = threading.Event() + + def _init(startup_complete_event: threading.Event) -> None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + app_server = app_server_factory() + app_server.event_loop = loop + + set_app_server(app, app_server) + + loop.run_until_complete( + app_server.run_until_shutdown(startup_complete_event) + ) + + thread = threading.Thread( + group=None, + target=_init, + name="app_server_init", + args=(startup_complete_event,), + daemon=True, + ) + thread.start() + + startup_complete_event.wait() + + return _worker_init_handler + + +def _worker_shutdown_wrapper(app: Celery) -> Callable[..., None]: + def _worker_shutdown_handler(**_kwargs) -> None: + get_app_server(app).shutdown_event.set() + + return _worker_shutdown_handler + + +def register_worker_signals( + app: Celery, + settings: CelerySettings, + app_server_factory: Callable[[], BaseAppServer], +) -> None: + match settings.CELERY_POOL: + case CeleryPoolType.PREFORK: + worker_process_init.connect( + _worker_init_wrapper(app, app_server_factory), weak=False + ) + worker_process_shutdown.connect(_worker_shutdown_wrapper(app), 
weak=False) + case _: + worker_init.connect(_worker_init_wrapper(app, app_server_factory), weak=False) + worker_shutdown.connect(_worker_shutdown_wrapper(app), weak=False) diff --git a/packages/celery-library/tests/conftest.py b/packages/celery-library/tests/conftest.py index e37f7d003f1e..9fd16d217cc2 100644 --- a/packages/celery-library/tests/conftest.py +++ b/packages/celery-library/tests/conftest.py @@ -14,17 +14,16 @@ start_worker, ) from celery.signals import worker_init, worker_shutdown -from celery.worker.worker import WorkController from celery_library.backends.redis import RedisTaskStore -from celery_library.signals import on_worker_init, on_worker_shutdown from celery_library.task_manager import CeleryTaskManager from celery_library.types import register_celery_types +from celery_library.worker.signals import _worker_init_wrapper, _worker_shutdown_wrapper from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict from pytest_simcore.helpers.typing_env import EnvVarsDict from servicelib.celery.app_server import BaseAppServer from servicelib.celery.task_manager import TaskManager from servicelib.redis import RedisClientSDK -from settings_library.celery import CelerySettings +from settings_library.celery import CeleryPoolType, CelerySettings from settings_library.redis import RedisDatabase, RedisSettings pytest_plugins = [ @@ -104,11 +103,6 @@ def celery_settings( return CelerySettings.create_from_envs() -@pytest.fixture -def app_server(celery_app: Celery, celery_settings: CelerySettings) -> BaseAppServer: - return FakeAppServer(app=celery_app, settings=celery_settings) - - @pytest.fixture(scope="session") def celery_config() -> dict[str, Any]: return { @@ -128,21 +122,25 @@ def celery_config() -> dict[str, Any]: @pytest.fixture async def with_celery_worker( celery_app: Celery, - app_server: BaseAppServer, + celery_settings: CelerySettings, register_celery_tasks: Callable[[Celery], None], ) -> AsyncIterator[TestWorkController]: - def 
_on_worker_init_wrapper(sender: WorkController, **_kwargs): - return on_worker_init(sender, app_server, **_kwargs) - worker_init.connect(_on_worker_init_wrapper) - worker_shutdown.connect(on_worker_shutdown) + def _app_server_factory() -> BaseAppServer: + return FakeAppServer(app=celery_app, settings=celery_settings) + + # NOTE: explicitly connect the signals in tests + worker_init.connect( + _worker_init_wrapper(celery_app, _app_server_factory), weak=False + ) + worker_shutdown.connect(_worker_shutdown_wrapper(celery_app), weak=False) register_celery_tasks(celery_app) with start_worker( celery_app, concurrency=1, - pool="threads", + pool=CeleryPoolType.THREADS, loglevel="info", perform_ping_check=False, queues="default", diff --git a/packages/celery-library/tests/unit/test_task_manager.py b/packages/celery-library/tests/unit/test_task_manager.py index 040c0541ed53..8a79adf9d7a1 100644 --- a/packages/celery-library/tests/unit/test_task_manager.py +++ b/packages/celery-library/tests/unit/test_task_manager.py @@ -16,7 +16,7 @@ from celery_library.errors import TaskNotFoundError, TransferrableCeleryError from celery_library.task import register_task from celery_library.task_manager import CeleryTaskManager -from celery_library.utils import get_app_server +from celery_library.worker.app_server import get_app_server from common_library.errors_classes import OsparcErrorMixin from faker import Faker from models_library.progress_bar import ProgressReport diff --git a/packages/service-library/src/servicelib/fastapi/celery/app_server.py b/packages/service-library/src/servicelib/fastapi/celery/app_server.py index 3c42aa9144d0..afc47d050f8c 100644 --- a/packages/service-library/src/servicelib/fastapi/celery/app_server.py +++ b/packages/service-library/src/servicelib/fastapi/celery/app_server.py @@ -9,6 +9,7 @@ from ...celery.app_server import BaseAppServer from ...celery.task_manager import TaskManager +_STARTUP_TIMEOUT: Final[float] = 
datetime.timedelta(minutes=5).total_seconds() _SHUTDOWN_TIMEOUT: Final[float] = datetime.timedelta(seconds=10).total_seconds() _logger = logging.getLogger(__name__) @@ -27,9 +28,10 @@ async def run_until_shutdown( ) -> None: async with LifespanManager( self.app, - startup_timeout=None, # waits for full app initialization (DB migrations, etc.) + startup_timeout=_STARTUP_TIMEOUT, shutdown_timeout=_SHUTDOWN_TIMEOUT, ): - _logger.info("fastapi app initialized") + _logger.info("FastAPI initialized: %s", self.app) startup_completed_event.set() await self.shutdown_event.wait() # NOTE: wait here until shutdown is requested + _logger.info("FastAPI shutdown completed: %s", self.app) diff --git a/packages/settings-library/src/settings_library/celery.py b/packages/settings-library/src/settings_library/celery.py index 168b86c6745c..d6c9e837d157 100644 --- a/packages/settings-library/src/settings_library/celery.py +++ b/packages/settings-library/src/settings_library/celery.py @@ -1,4 +1,5 @@ from datetime import timedelta +from enum import StrEnum from typing import Annotated from pydantic import Field @@ -9,6 +10,13 @@ from .base import BaseCustomSettings +class CeleryPoolType(StrEnum): + PREFORK = "prefork" + EVENTLET = "eventlet" + GEVENT = "gevent" + THREADS = "threads" + + class CelerySettings(BaseCustomSettings): CELERY_RABBIT_BROKER: Annotated[ RabbitSettings, Field(json_schema_extra={"auto_default_from_env": True}) @@ -35,6 +43,13 @@ class CelerySettings(BaseCustomSettings): ), ] = True + CELERY_POOL: Annotated[ + CeleryPoolType, + Field( + description="Type of pool to use. One of: prefork, eventlet, gevent, threads. 
See https://docs.celeryq.dev/en/stable/userguide/concurrency/index.html for details.", + ), + ] = CeleryPoolType.PREFORK + model_config = SettingsConfigDict( json_schema_extra={ "examples": [ diff --git a/services/api-server/docker/boot.sh b/services/api-server/docker/boot.sh index 227be9c56b96..ef5bf4fd42e2 100755 --- a/services/api-server/docker/boot.sh +++ b/services/api-server/docker/boot.sh @@ -48,18 +48,16 @@ if [ "${API_SERVER_WORKER_MODE}" = "true" ]; then --recursive \ -- \ celery \ - --app=boot_celery_worker:app \ - --workdir=services/api-server/docker \ - worker --pool=threads \ + --app=simcore_service_api_server.modules.celery.worker.main:app \ + worker --pool="${CELERY_POOL}" \ --loglevel="${API_SERVER_LOGLEVEL}" \ --concurrency="${CELERY_CONCURRENCY}" \ --hostname="${API_SERVER_WORKER_NAME}" \ --queues="${CELERY_QUEUES:-default}" else exec celery \ - --app=boot_celery_worker:app \ - --workdir=services/api-server/docker \ - worker --pool=threads \ + --app=simcore_service_api_server.modules.celery.worker.main:app \ + worker --pool="${CELERY_POOL}" \ --loglevel="${API_SERVER_LOGLEVEL}" \ --concurrency="${CELERY_CONCURRENCY}" \ --hostname="${API_SERVER_WORKER_NAME}" \ diff --git a/services/api-server/docker/boot_celery_worker.py b/services/api-server/docker/boot_celery_worker.py deleted file mode 100644 index e0c7e119ced8..000000000000 --- a/services/api-server/docker/boot_celery_worker.py +++ /dev/null @@ -1,13 +0,0 @@ -from celery.signals import worker_init, worker_shutdown # type: ignore[import-untyped] -from celery_library.signals import ( - on_worker_shutdown, -) -from simcore_service_api_server.celery_worker.worker_main import ( - get_app, - worker_init_wrapper, -) - -app = get_app() - -worker_init.connect(worker_init_wrapper) -worker_shutdown.connect(on_worker_shutdown) diff --git a/services/api-server/src/simcore_service_api_server/celery_worker/worker_main.py b/services/api-server/src/simcore_service_api_server/celery_worker/worker_main.py 
deleted file mode 100644 index c9e99cda269f..000000000000 --- a/services/api-server/src/simcore_service_api_server/celery_worker/worker_main.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Main application to be deployed in for example uvicorn.""" - -from functools import partial - -from celery_library.common import create_app as create_celery_app -from celery_library.signals import ( - on_worker_init, -) -from servicelib.fastapi.celery.app_server import FastAPIAppServer -from servicelib.logging_utils import setup_loggers -from servicelib.tracing import TracingConfig - -from ..core.application import create_app -from ..core.settings import ApplicationSettings -from .worker_tasks.tasks import setup_worker_tasks - - -def get_app(): - _settings = ApplicationSettings.create_from_envs() - _tracing_settings = _settings.API_SERVER_TRACING - _tracing_config = TracingConfig.create( - tracing_settings=_tracing_settings, - service_name="api-server-celery-worker", - ) - setup_loggers( - log_format_local_dev_enabled=_settings.API_SERVER_LOG_FORMAT_LOCAL_DEV_ENABLED, - logger_filter_mapping=_settings.API_SERVER_LOG_FILTER_MAPPING, - tracing_config=_tracing_config, - log_base_level=_settings.log_level, - noisy_loggers=None, - ) - - assert _settings.API_SERVER_CELERY # nosec - app = create_celery_app(_settings.API_SERVER_CELERY) - setup_worker_tasks(app) - - return app - - -def worker_init_wrapper(sender, **_kwargs): - _settings = ApplicationSettings.create_from_envs() - assert _settings.API_SERVER_CELERY # nosec - app_server = FastAPIAppServer(app=create_app(_settings)) - - return partial(on_worker_init, app_server=app_server)(sender, **_kwargs) diff --git a/services/api-server/src/simcore_service_api_server/clients/celery_task_manager.py b/services/api-server/src/simcore_service_api_server/clients/celery_task_manager.py index 13829cfdd303..fcfbbb35678f 100644 --- a/services/api-server/src/simcore_service_api_server/clients/celery_task_manager.py +++ 
b/services/api-server/src/simcore_service_api_server/clients/celery_task_manager.py @@ -1,7 +1,7 @@ import logging +from celery_library.app import create_app from celery_library.backends.redis import RedisTaskStore -from celery_library.common import create_app from celery_library.task_manager import CeleryTaskManager from celery_library.types import register_celery_types, register_pydantic_types from fastapi import FastAPI diff --git a/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/__init__.py b/services/api-server/src/simcore_service_api_server/modules/__init__.py similarity index 100% rename from services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/__init__.py rename to services/api-server/src/simcore_service_api_server/modules/__init__.py diff --git a/services/api-server/src/simcore_service_api_server/modules/celery/__init__.py b/services/api-server/src/simcore_service_api_server/modules/celery/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/services/api-server/src/simcore_service_api_server/modules/celery/worker/__init__.py b/services/api-server/src/simcore_service_api_server/modules/celery/worker/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/functions_tasks.py b/services/api-server/src/simcore_service_api_server/modules/celery/worker/_functions_tasks.py similarity index 86% rename from services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/functions_tasks.py rename to services/api-server/src/simcore_service_api_server/modules/celery/worker/_functions_tasks.py index e0687f6ab89a..bfe4eb8e63cb 100644 --- a/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/functions_tasks.py +++ b/services/api-server/src/simcore_service_api_server/modules/celery/worker/_functions_tasks.py @@ -1,7 +1,7 @@ from celery import ( # type: 
ignore[import-untyped] # pylint: disable=no-name-in-module Task, ) -from celery_library.utils import get_app_server # pylint: disable=no-name-in-module +from celery_library.worker.app_server import get_app_server from fastapi import FastAPI from models_library.functions import RegisteredFunction, RegisteredFunctionJob from models_library.projects import ProjectID @@ -9,9 +9,9 @@ from servicelib.celery.models import TaskKey from simcore_service_api_server._service_function_jobs import FunctionJobService -from ...api.dependencies.authentication import Identity -from ...api.dependencies.rabbitmq import get_rabbitmq_rpc_client -from ...api.dependencies.services import ( +from ....api.dependencies.authentication import Identity +from ....api.dependencies.rabbitmq import get_rabbitmq_rpc_client +from ....api.dependencies.services import ( get_catalog_service, get_directorv2_service, get_function_job_service, @@ -20,13 +20,16 @@ get_solver_service, get_storage_service, ) -from ...api.dependencies.webserver_http import get_session_cookie, get_webserver_session -from ...api.dependencies.webserver_rpc import get_wb_api_rpc_client -from ...models.api_resources import JobLinks -from ...models.domain.functions import PreRegisteredFunctionJobData -from ...models.schemas.jobs import JobPricingSpecification -from ...services_http.director_v2 import DirectorV2Api -from ...services_http.storage import StorageApi +from ....api.dependencies.webserver_http import ( + get_session_cookie, + get_webserver_session, +) +from ....api.dependencies.webserver_rpc import get_wb_api_rpc_client +from ....models.api_resources import JobLinks +from ....models.domain.functions import PreRegisteredFunctionJobData +from ....models.schemas.jobs import JobPricingSpecification +from ....services_http.director_v2 import DirectorV2Api +from ....services_http.storage import StorageApi async def _assemble_function_job_service( diff --git 
a/services/api-server/src/simcore_service_api_server/modules/celery/worker/main.py b/services/api-server/src/simcore_service_api_server/modules/celery/worker/main.py new file mode 100644 index 000000000000..826802f4d34b --- /dev/null +++ b/services/api-server/src/simcore_service_api_server/modules/celery/worker/main.py @@ -0,0 +1,39 @@ +from celery_library.worker.app import create_worker_app +from servicelib.fastapi.celery.app_server import FastAPIAppServer +from servicelib.logging_utils import setup_loggers +from servicelib.tracing import TracingConfig + +from ....core.application import create_app +from ....core.settings import ApplicationSettings +from .tasks import register_worker_tasks + +_settings = ApplicationSettings.create_from_envs() +_tracing_settings = _settings.API_SERVER_TRACING +_tracing_config = TracingConfig.create( + tracing_settings=_tracing_settings, + service_name="api-server-celery-worker", +) + + +def get_app(): + setup_loggers( + log_format_local_dev_enabled=_settings.API_SERVER_LOG_FORMAT_LOCAL_DEV_ENABLED, + logger_filter_mapping=_settings.API_SERVER_LOG_FILTER_MAPPING, + tracing_config=_tracing_config, + log_base_level=_settings.log_level, + noisy_loggers=None, + ) + + def _app_server_factory() -> FastAPIAppServer: + fastapi_app = create_app(_settings, tracing_config=_tracing_config) + return FastAPIAppServer(app=fastapi_app) + + assert _settings.API_SERVER_CELERY # nosec + return create_worker_app( + _settings.API_SERVER_CELERY, + register_worker_tasks, + _app_server_factory, + ) + + +app = get_app() diff --git a/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/tasks.py b/services/api-server/src/simcore_service_api_server/modules/celery/worker/tasks.py similarity index 76% rename from services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/tasks.py rename to services/api-server/src/simcore_service_api_server/modules/celery/worker/tasks.py index cef6ad06d18c..b567d1cc63dc 100644 --- 
a/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/tasks.py +++ b/services/api-server/src/simcore_service_api_server/modules/celery/worker/tasks.py @@ -7,13 +7,13 @@ from celery_library.types import register_celery_types, register_pydantic_types from servicelib.logging_utils import log_context -from ...models.domain.celery_models import pydantic_types_to_register -from .functions_tasks import run_function +from ....models.domain.celery_models import pydantic_types_to_register +from ._functions_tasks import run_function _logger = logging.getLogger(__name__) -def setup_worker_tasks(app: Celery) -> None: +def register_worker_tasks(app: Celery) -> None: register_celery_types() register_pydantic_types(*pydantic_types_to_register) diff --git a/services/api-server/tests/unit/api_functions/celery/conftest.py b/services/api-server/tests/unit/api_functions/celery/conftest.py index 75c3a32123a9..5774596c91d2 100644 --- a/services/api-server/tests/unit/api_functions/celery/conftest.py +++ b/services/api-server/tests/unit/api_functions/celery/conftest.py @@ -14,21 +14,18 @@ TestWorkController, start_worker, ) -from celery.signals import ( # pylint: disable=no-name-in-module - worker_init, - worker_shutdown, -) -from celery.worker.worker import WorkController # pylint: disable=no-name-in-module -from celery_library.signals import on_worker_init, on_worker_shutdown +from celery.signals import worker_init, worker_shutdown +from celery_library.worker.signals import _worker_init_wrapper, _worker_shutdown_wrapper from pytest_mock import MockerFixture from pytest_simcore.helpers.monkeypatch_envs import delenvs_from_dict, setenvs_from_dict from pytest_simcore.helpers.typing_env import EnvVarsDict from servicelib.fastapi.celery.app_server import FastAPIAppServer +from servicelib.tracing import TracingConfig from settings_library.redis import RedisSettings -from simcore_service_api_server.celery_worker.worker_main import setup_worker_tasks from 
simcore_service_api_server.clients import celery_task_manager from simcore_service_api_server.core.application import create_app from simcore_service_api_server.core.settings import ApplicationSettings +from simcore_service_api_server.modules.celery.worker.tasks import register_worker_tasks @@ -112,26 +109,28 @@ def add_worker_tasks() -> bool: @pytest.fixture async def with_api_server_celery_worker( - app_environment: EnvVarsDict, celery_app: Celery, - monkeypatch: pytest.MonkeyPatch, register_celery_tasks: Callable[[Celery], None], add_worker_tasks: bool, + monkeypatch: pytest.MonkeyPatch, ) -> AsyncIterator[TestWorkController]: - # Signals must be explicitily connected + tracing_config = TracingConfig.create( + tracing_settings=None, # disable tracing in tests + service_name="api-server-worker-test", + ) + # Signals must be explicitly connected monkeypatch.setenv("API_SERVER_WORKER_MODE", "true") app_settings = ApplicationSettings.create_from_envs() - app_server = FastAPIAppServer(app=create_app(app_settings)) - - def _on_worker_init_wrapper(sender: WorkController, **kwargs): - return on_worker_init(sender, app_server=app_server, **kwargs) + app_server = FastAPIAppServer(app=create_app(app_settings, tracing_config)) - worker_init.connect(_on_worker_init_wrapper) - worker_shutdown.connect(on_worker_shutdown) + _init_wrapper = _worker_init_wrapper(celery_app, lambda: app_server) + _shutdown_wrapper = _worker_shutdown_wrapper(celery_app) + worker_init.connect(_init_wrapper) + worker_shutdown.connect(_shutdown_wrapper) if add_worker_tasks: - setup_worker_tasks(celery_app) + register_worker_tasks(celery_app) register_celery_tasks(celery_app) with start_worker( diff --git a/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py b/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py index 75088e96df30..7d5d197d1774 100644 --- 
a/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py +++ b/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py @@ -52,9 +52,6 @@ from simcore_service_api_server.api.dependencies.celery import ( get_task_manager, ) -from simcore_service_api_server.celery_worker.worker_tasks.functions_tasks import ( - run_function as run_function_task, -) from simcore_service_api_server.exceptions.backend_errors import BaseBackEndError from simcore_service_api_server.models.api_resources import JobLinks from simcore_service_api_server.models.domain.celery_models import ( @@ -67,6 +64,9 @@ JobPricingSpecification, NodeID, ) +from simcore_service_api_server.modules.celery.worker._functions_tasks import ( + run_function as run_function_task, +) from tenacity import ( AsyncRetrying, retry_if_exception_type, diff --git a/services/api-server/tests/unit/api_functions/test_api_routers_functions.py b/services/api-server/tests/unit/api_functions/test_api_routers_functions.py index fcc8b8e4fafc..a3ae47a0fcd2 100644 --- a/services/api-server/tests/unit/api_functions/test_api_routers_functions.py +++ b/services/api-server/tests/unit/api_functions/test_api_routers_functions.py @@ -46,12 +46,12 @@ from servicelib.rabbitmq import RabbitMQRPCClient from simcore_service_api_server._meta import API_VTAG from simcore_service_api_server.api.dependencies.authentication import Identity -from simcore_service_api_server.celery_worker.worker_tasks import functions_tasks from simcore_service_api_server.models.api_resources import JobLinks from simcore_service_api_server.models.domain.functions import ( PreRegisteredFunctionJobData, ) from simcore_service_api_server.models.schemas.jobs import JobInputs +from simcore_service_api_server.modules.celery.worker import _functions_tasks from simcore_service_api_server.services_rpc.wb_api_server import WbApiRpcClient _faker = Faker() @@ -426,13 +426,13 @@ def _get_app_server(celery_app: Any) -> FastAPI: app_server.app 
= app return app_server - mocker.patch.object(functions_tasks, "get_app_server", _get_app_server) + mocker.patch.object(_functions_tasks, "get_app_server", _get_app_server) def _get_rabbitmq_rpc_client(app: FastAPI) -> RabbitMQRPCClient: return mocker.MagicMock(spec=RabbitMQRPCClient) mocker.patch.object( - functions_tasks, "get_rabbitmq_rpc_client", _get_rabbitmq_rpc_client + _functions_tasks, "get_rabbitmq_rpc_client", _get_rabbitmq_rpc_client ) async def _get_wb_api_rpc_client(app: FastAPI) -> WbApiRpcClient: @@ -442,7 +442,7 @@ async def _get_wb_api_rpc_client(app: FastAPI) -> WbApiRpcClient: return WbApiRpcClient.get_from_app_state(app) mocker.patch.object( - functions_tasks, "get_wb_api_rpc_client", _get_wb_api_rpc_client + _functions_tasks, "get_wb_api_rpc_client", _get_wb_api_rpc_client ) def _default_side_effect( @@ -492,7 +492,7 @@ def _default_side_effect( function_job_id=fake_registered_project_function.uid, ) - job = await functions_tasks.run_function( + job = await _functions_tasks.run_function( task=MagicMock(spec=Task), task_key=TaskKey(_faker.uuid4()), user_identity=user_identity, diff --git a/services/docker-compose.yml b/services/docker-compose.yml index c326ac596bd4..8919cecf3e7c 100644 --- a/services/docker-compose.yml +++ b/services/docker-compose.yml @@ -168,7 +168,8 @@ services: API_SERVER_WORKER_NAME: "api-worker-{{.Node.Hostname}}-{{.Task.Slot}}-{{.Task.ID}}" API_SERVER_WORKER_MODE: "true" CELERY_CONCURRENCY: ${API_SERVER_CELERY_CONCURRENCY} - CELERY_QUEUES: "api_worker_queue" + CELERY_POOL: threads + CELERY_QUEUES: api_worker_queue networks: *api_server_networks @@ -1260,6 +1261,7 @@ services: STORAGE_WORKER_NAME: "sto-worker-{{.Node.Hostname}}-{{.Task.Slot}}-{{.Task.ID}}" STORAGE_WORKER_MODE: "true" CELERY_CONCURRENCY: 100 + CELERY_POOL: threads networks: *storage_networks sto-worker-cpu-bound: @@ -1272,7 +1274,8 @@ services: STORAGE_WORKER_NAME: "sto-worker-cpu-bound-{{.Node.Hostname}}-{{.Task.Slot}}-{{.Task.ID}}" STORAGE_WORKER_MODE: 
"true" CELERY_CONCURRENCY: 1 - CELERY_QUEUES: "cpu_bound" + CELERY_QUEUES: cpu_bound + CELERY_POOL: prefork networks: *storage_networks rabbit: diff --git a/services/storage/docker/boot.sh b/services/storage/docker/boot.sh index 6dd4e72f8e6c..2c9803f1b0ac 100755 --- a/services/storage/docker/boot.sh +++ b/services/storage/docker/boot.sh @@ -55,16 +55,16 @@ if [ "${STORAGE_WORKER_MODE}" = "true" ]; then --recursive \ -- \ celery \ - --app=simcore_service_storage.modules.celery.worker_main:app \ - worker --pool=threads \ + --app=simcore_service_storage.modules.celery.worker.main:app \ + worker --pool="${CELERY_POOL}" \ --loglevel="${SERVER_LOG_LEVEL}" \ --concurrency="${CELERY_CONCURRENCY}" \ --hostname="${STORAGE_WORKER_NAME}" \ --queues="${CELERY_QUEUES:-default}" else exec celery \ - --app=simcore_service_storage.modules.celery.worker_main:app \ - worker --pool=threads \ + --app=simcore_service_storage.modules.celery.worker.main:app \ + worker --pool="${CELERY_POOL}" \ --loglevel="${SERVER_LOG_LEVEL}" \ --concurrency="${CELERY_CONCURRENCY}" \ --hostname="${STORAGE_WORKER_NAME}" \ diff --git a/services/storage/src/simcore_service_storage/api/_worker_tasks/_files.py b/services/storage/src/simcore_service_storage/api/_worker_tasks/_files.py index 29c4cb72857a..4bf12b3dca25 100644 --- a/services/storage/src/simcore_service_storage/api/_worker_tasks/_files.py +++ b/services/storage/src/simcore_service_storage/api/_worker_tasks/_files.py @@ -1,7 +1,7 @@ import logging from celery import Task # type: ignore[import-untyped] -from celery_library.utils import get_app_server +from celery_library.worker.app_server import get_app_server from models_library.api_schemas_storage.storage_schemas import ( FileUploadCompletionBody, ) diff --git a/services/storage/src/simcore_service_storage/api/_worker_tasks/_paths.py b/services/storage/src/simcore_service_storage/api/_worker_tasks/_paths.py index 0e6ea4c3749f..0401d2400c89 100644 --- 
a/services/storage/src/simcore_service_storage/api/_worker_tasks/_paths.py +++ b/services/storage/src/simcore_service_storage/api/_worker_tasks/_paths.py @@ -2,7 +2,7 @@ from pathlib import Path from celery import Task # type: ignore[import-untyped] -from celery_library.utils import get_app_server +from celery_library.worker.app_server import get_app_server from models_library.projects_nodes_io import LocationID, StorageFileID from models_library.users import UserID from pydantic import ByteSize, TypeAdapter diff --git a/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py b/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py index d40a90b084c8..85b3912c0221 100644 --- a/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py +++ b/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py @@ -5,7 +5,7 @@ from aws_library.s3._models import S3ObjectKey from celery import Task # type: ignore[import-untyped] -from celery_library.utils import get_app_server +from celery_library.worker.app_server import get_app_server from models_library.api_schemas_storage.search_async_jobs import SearchResultItem from models_library.api_schemas_storage.storage_schemas import ( FoldersBody, diff --git a/services/storage/src/simcore_service_storage/api/_worker_tasks/tasks.py b/services/storage/src/simcore_service_storage/api/_worker_tasks/tasks.py index 5475b8eed8a4..52c81bc96213 100644 --- a/services/storage/src/simcore_service_storage/api/_worker_tasks/tasks.py +++ b/services/storage/src/simcore_service_storage/api/_worker_tasks/tasks.py @@ -25,7 +25,7 @@ _logger = logging.getLogger(__name__) -def setup_worker_tasks(app: Celery) -> None: +def register_worker_tasks(app: Celery) -> None: register_celery_types() register_pydantic_types( FileUploadCompletionBody, diff --git a/services/storage/src/simcore_service_storage/modules/celery/__init__.py 
b/services/storage/src/simcore_service_storage/modules/celery/__init__.py index 48e30e60a4f3..bdb434580b85 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/__init__.py +++ b/services/storage/src/simcore_service_storage/modules/celery/__init__.py @@ -1,7 +1,7 @@ import logging +from celery_library.app import create_app from celery_library.backends.redis import RedisTaskStore -from celery_library.common import create_app from celery_library.task_manager import CeleryTaskManager from celery_library.types import register_celery_types, register_pydantic_types from fastapi import FastAPI diff --git a/services/storage/src/simcore_service_storage/modules/celery/worker/__init__.py b/services/storage/src/simcore_service_storage/modules/celery/worker/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/services/storage/src/simcore_service_storage/modules/celery/worker/main.py b/services/storage/src/simcore_service_storage/modules/celery/worker/main.py new file mode 100644 index 000000000000..e6a455209706 --- /dev/null +++ b/services/storage/src/simcore_service_storage/modules/celery/worker/main.py @@ -0,0 +1,38 @@ +from celery_library.worker.app import create_worker_app +from servicelib.fastapi.celery.app_server import FastAPIAppServer +from servicelib.logging_utils import setup_loggers +from servicelib.tracing import TracingConfig + +from ....api._worker_tasks.tasks import register_worker_tasks +from ....core.application import create_app +from ....core.settings import ApplicationSettings + + +def get_app(): + _settings = ApplicationSettings.create_from_envs() + _tracing_config = TracingConfig.create( + tracing_settings=_settings.STORAGE_TRACING, + service_name="storage-celery-worker", + ) + + setup_loggers( + log_format_local_dev_enabled=_settings.STORAGE_LOG_FORMAT_LOCAL_DEV_ENABLED, + logger_filter_mapping=_settings.STORAGE_LOG_FILTER_MAPPING, + tracing_config=_tracing_config, + log_base_level=_settings.log_level, + 
noisy_loggers=None, + ) + + def _app_server_factory() -> FastAPIAppServer: + fastapi_app = create_app(_settings, tracing_config=_tracing_config) + return FastAPIAppServer(app=fastapi_app) + + assert _settings.STORAGE_CELERY # nosec + return create_worker_app( + _settings.STORAGE_CELERY, + register_worker_tasks_cb=register_worker_tasks, + app_server_factory_cb=_app_server_factory, + ) + + +app = get_app() diff --git a/services/storage/src/simcore_service_storage/modules/celery/worker_main.py b/services/storage/src/simcore_service_storage/modules/celery/worker_main.py deleted file mode 100644 index b82fff90abde..000000000000 --- a/services/storage/src/simcore_service_storage/modules/celery/worker_main.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Main application to be deployed in for example uvicorn.""" - -from celery.signals import worker_init, worker_shutdown # type: ignore[import-untyped] -from celery_library.common import create_app as create_celery_app -from celery_library.signals import ( - on_worker_init, - on_worker_shutdown, -) -from servicelib.fastapi.celery.app_server import FastAPIAppServer -from servicelib.logging_utils import setup_loggers -from servicelib.tracing import TracingConfig - -from ...api._worker_tasks.tasks import setup_worker_tasks -from ...core.application import create_app -from ...core.settings import ApplicationSettings - -_settings = ApplicationSettings.create_from_envs() -_tracing_config = TracingConfig.create( - tracing_settings=_settings.STORAGE_TRACING, - service_name="storage-celery-worker", -) - -setup_loggers( - log_format_local_dev_enabled=_settings.STORAGE_LOG_FORMAT_LOCAL_DEV_ENABLED, - logger_filter_mapping=_settings.STORAGE_LOG_FILTER_MAPPING, - tracing_config=_tracing_config, - log_base_level=_settings.log_level, - noisy_loggers=None, -) - - -assert _settings.STORAGE_CELERY # nosec -app = create_celery_app(_settings.STORAGE_CELERY) - -app_server = FastAPIAppServer(app=create_app(_settings, tracing_config=_tracing_config)) - - 
-def worker_init_wrapper(sender, **kwargs): - return on_worker_init(sender, app_server, **kwargs) - - -worker_init.connect(worker_init_wrapper) -worker_shutdown.connect(on_worker_shutdown) - - -setup_worker_tasks(app) diff --git a/services/storage/tests/conftest.py b/services/storage/tests/conftest.py index 1831fab085a9..5c02fa3c4a42 100644 --- a/services/storage/tests/conftest.py +++ b/services/storage/tests/conftest.py @@ -24,8 +24,7 @@ from celery import Celery from celery.contrib.testing.worker import TestWorkController, start_worker from celery.signals import worker_init, worker_shutdown -from celery.worker.worker import WorkController -from celery_library.signals import on_worker_init, on_worker_shutdown +from celery_library.worker.signals import _worker_init_wrapper, _worker_shutdown_wrapper from faker import Faker from fakeredis.aioredis import FakeRedis from fastapi import FastAPI @@ -70,7 +69,7 @@ from settings_library.rabbit import RabbitSettings from simcore_postgres_database.models.tokens import tokens from simcore_postgres_database.storage_models import file_meta_data, projects, users -from simcore_service_storage.api._worker_tasks.tasks import setup_worker_tasks +from simcore_service_storage.api._worker_tasks.tasks import register_worker_tasks from simcore_service_storage.core.application import create_app from simcore_service_storage.core.settings import ApplicationSettings from simcore_service_storage.datcore_dsm import DatCoreDataManager @@ -1008,32 +1007,41 @@ def _(celery_app: Celery) -> None: ... 
return _ +@pytest.fixture +def worker_app_settings( + app_settings: ApplicationSettings, +) -> ApplicationSettings: + worker_test_app_settings = app_settings.model_copy( + update={"STORAGE_WORKER_MODE": True}, deep=True + ) + print(f"{worker_test_app_settings.model_dump_json(indent=2)=}") + return worker_test_app_settings + + +_logger = logging.getLogger(__name__) + + @pytest.fixture async def with_storage_celery_worker( - app_environment: EnvVarsDict, celery_app: Celery, + worker_app_settings: ApplicationSettings, monkeypatch: pytest.MonkeyPatch, register_celery_tasks: Callable[[Celery], None], ) -> AsyncIterator[TestWorkController]: # Signals must be explicitily connected - monkeypatch.setenv("STORAGE_WORKER_MODE", "true") - app_settings = ApplicationSettings.create_from_envs() tracing_config = TracingConfig.create( tracing_settings=None, # disable tracing in tests service_name="storage-api", ) - app_server = FastAPIAppServer( - app=create_app(app_settings, tracing_config=tracing_config) - ) - - def _on_worker_init_wrapper(sender: WorkController, **_kwargs): - return on_worker_init(sender, app_server, **_kwargs) + app_server = FastAPIAppServer(app=create_app(worker_app_settings, tracing_config)) - worker_init.connect(_on_worker_init_wrapper) - worker_shutdown.connect(on_worker_shutdown) + init_wrapper = _worker_init_wrapper(celery_app, lambda: app_server) + worker_init.connect(init_wrapper, weak=False) + shutdown_wrapper = _worker_shutdown_wrapper(celery_app) + worker_shutdown.connect(shutdown_wrapper, weak=False) - setup_worker_tasks(celery_app) + register_worker_tasks(celery_app) register_celery_tasks(celery_app) with start_worker( @@ -1046,6 +1054,9 @@ def _on_worker_init_wrapper(sender: WorkController, **_kwargs): ) as worker: yield worker + worker_init.disconnect(init_wrapper) + worker_shutdown.disconnect(shutdown_wrapper) + @pytest.fixture async def storage_rabbitmq_rpc_client( diff --git 
a/services/web/server/src/simcore_service_webserver/celery/_task_manager.py b/services/web/server/src/simcore_service_webserver/celery/_task_manager.py index adc06a8e219a..ef136098826c 100644 --- a/services/web/server/src/simcore_service_webserver/celery/_task_manager.py +++ b/services/web/server/src/simcore_service_webserver/celery/_task_manager.py @@ -2,8 +2,8 @@ from typing import Final from aiohttp import web +from celery_library.app import create_app from celery_library.backends.redis import RedisTaskStore -from celery_library.common import create_app from celery_library.task_manager import CeleryTaskManager from celery_library.types import register_celery_types from servicelib.celery.task_manager import TaskManager