diff --git a/packages/celery-library/src/celery_library/signals.py b/packages/celery-library/src/celery_library/signals.py index 09a29f3b0dc..dd5bf047e65 100644 --- a/packages/celery-library/src/celery_library/signals.py +++ b/packages/celery-library/src/celery_library/signals.py @@ -4,7 +4,7 @@ from celery import Celery # type: ignore[import-untyped] from celery.worker.worker import WorkController # type: ignore[import-untyped] -from servicelib.celery.app_server import STARTUP_TIMEOUT, BaseAppServer +from servicelib.celery.app_server import BaseAppServer from servicelib.logging_utils import log_context from settings_library.celery import CelerySettings @@ -26,8 +26,6 @@ def _init(startup_complete_event: threading.Event) -> None: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - shutdown_event = asyncio.Event() - async def _setup_task_manager(): assert sender.app # nosec assert isinstance(sender.app, Celery) # nosec @@ -42,9 +40,7 @@ async def _setup_task_manager(): app_server.event_loop = loop loop.run_until_complete(_setup_task_manager()) - loop.run_until_complete( - app_server.startup(startup_complete_event, shutdown_event) - ) + loop.run_until_complete(app_server.lifespan(startup_complete_event)) thread = threading.Thread( group=None, @@ -55,7 +51,7 @@ async def _setup_task_manager(): ) thread.start() - startup_complete_event.wait(STARTUP_TIMEOUT * 1.1) + startup_complete_event.wait() def on_worker_shutdown(sender, **_kwargs) -> None: @@ -63,4 +59,4 @@ def on_worker_shutdown(sender, **_kwargs) -> None: assert isinstance(sender.app, Celery) app_server = get_app_server(sender.app) - app_server.event_loop.run_until_complete(app_server.shutdown()) + app_server.shutdown_event.set() diff --git a/packages/celery-library/tests/conftest.py b/packages/celery-library/tests/conftest.py index c513d262d7f..5ecde6333f9 100644 --- a/packages/celery-library/tests/conftest.py +++ b/packages/celery-library/tests/conftest.py @@ -2,6 +2,7 @@ # pylint: disable=unused-argument import datetime +import threading from collections.abc import AsyncIterator, Callable from functools import partial from typing import Any @@ -30,10 +31,7 @@ class FakeAppServer(BaseAppServer): - async def on_startup(self) -> None: - pass - - async def on_shutdown(self) -> None: + async def lifespan(self, startup_completed_event: threading.Event) -> None: pass diff --git a/packages/service-library/src/servicelib/celery/app_server.py b/packages/service-library/src/servicelib/celery/app_server.py index 0c55c0ed919..dbddd9bfee2 100644 --- a/packages/service-library/src/servicelib/celery/app_server.py +++ b/packages/service-library/src/servicelib/celery/app_server.py @@ -1,21 +1,18 @@ import asyncio -import datetime import threading from abc import ABC, abstractmethod from asyncio import AbstractEventLoop -from typing import Final, Generic, TypeVar +from typing import Generic, TypeVar from servicelib.celery.task_manager import TaskManager -STARTUP_TIMEOUT: Final[float] = datetime.timedelta(minutes=1).total_seconds() - T = TypeVar("T") class BaseAppServer(ABC, Generic[T]): def __init__(self, app: T) -> None: self._app: T = app - self._shutdown_event: asyncio.Event | None = None + self._shutdown_event: asyncio.Event = asyncio.Event() @property def app(self) -> T: @@ -29,6 +26,10 @@ def event_loop(self) -> AbstractEventLoop: def event_loop(self, loop: AbstractEventLoop) -> None: self._event_loop = loop + @property + def shutdown_event(self) -> asyncio.Event: + return self._shutdown_event + @property def task_manager(self) -> TaskManager: return self._task_manager @@ -38,23 +39,8 @@ def task_manager(self, manager: TaskManager) -> None: self._task_manager = manager @abstractmethod - async def on_startup(self) -> None: - raise NotImplementedError - - async def startup( - self, completed_event: threading.Event, shutdown_event: asyncio.Event + async def lifespan( + self, + startup_completed_event: threading.Event, ) -> None: - self._shutdown_event = shutdown_event - completed_event.set() - await self.on_startup() - await self._shutdown_event.wait() - - @abstractmethod - async def on_shutdown(self) -> None: raise NotImplementedError - - async def shutdown(self) -> None: - if self._shutdown_event is not None: - self._shutdown_event.set() - - await self.on_shutdown() diff --git a/packages/service-library/src/servicelib/fastapi/celery/app_server.py b/packages/service-library/src/servicelib/fastapi/celery/app_server.py index 3a5ab7303de..e1a1d3255ac 100644 --- a/packages/service-library/src/servicelib/fastapi/celery/app_server.py +++ b/packages/service-library/src/servicelib/fastapi/celery/app_server.py @@ -1,4 +1,6 @@ -from datetime import timedelta +import datetime +import logging +import threading from typing import Final from asgi_lifespan import LifespanManager @@ -6,8 +8,9 @@ from ...celery.app_server import BaseAppServer -_SHUTDOWN_TIMEOUT: Final[float] = timedelta(seconds=10).total_seconds() -_STARTUP_TIMEOUT: Final[float] = timedelta(minutes=1).total_seconds() +_SHUTDOWN_TIMEOUT: Final[float] = datetime.timedelta(seconds=10).total_seconds() + +_logger = logging.getLogger(__name__) class FastAPIAppServer(BaseAppServer[FastAPI]): @@ -15,15 +18,12 @@ def __init__(self, app: FastAPI): super().__init__(app) self._lifespan_manager: LifespanManager | None = None - async def on_startup(self) -> None: - self._lifespan_manager = LifespanManager( + async def lifespan(self, startup_completed_event: threading.Event) -> None: + async with LifespanManager( self.app, - startup_timeout=_STARTUP_TIMEOUT, + startup_timeout=None, # waits for full app initialization (DB migrations, etc.) shutdown_timeout=_SHUTDOWN_TIMEOUT, - ) - await self._lifespan_manager.__aenter__() - - async def on_shutdown(self) -> None: - if self._lifespan_manager is None: - return - await self._lifespan_manager.__aexit__(None, None, None) + ): + _logger.info("fastapi app initialized") + startup_completed_event.set() + await self.shutdown_event.wait() # NOTE: wait here until shutdown is requested