|
| 1 | +import logging |
| 2 | +from asyncio import AbstractEventLoop |
| 3 | + |
| 4 | +from fastapi import FastAPI |
| 5 | +from servicelib.redis._client import RedisClientSDK |
| 6 | +from settings_library.redis import RedisDatabase |
| 7 | + |
| 8 | +from ..._meta import APP_NAME |
| 9 | +from ...core.settings import get_application_settings |
| 10 | +from ._celery_types import register_celery_types |
| 11 | +from ._common import create_app |
| 12 | +from .backends._redis import RedisTaskInfoStore |
| 13 | +from .client import CeleryTaskClient |
| 14 | + |
| 15 | +_logger = logging.getLogger(__name__) |
| 16 | + |
| 17 | + |
| 18 | +def setup_celery_client(app: FastAPI) -> None: |
| 19 | + async def on_startup() -> None: |
| 20 | + application_settings = get_application_settings(app) |
| 21 | + celery_settings = application_settings.STORAGE_CELERY |
| 22 | + assert celery_settings # nosec |
| 23 | + celery_app = create_app(celery_settings) |
| 24 | + redis_client_sdk = RedisClientSDK( |
| 25 | + celery_settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn( |
| 26 | + RedisDatabase.CELERY_TASKS |
| 27 | + ), |
| 28 | + client_name=f"{APP_NAME}.celery_tasks", |
| 29 | + ) |
| 30 | + |
| 31 | + app.state.celery_client = CeleryTaskClient( |
| 32 | + celery_app, |
| 33 | + celery_settings, |
| 34 | + RedisTaskInfoStore(redis_client_sdk), |
| 35 | + ) |
| 36 | + |
| 37 | + register_celery_types() |
| 38 | + |
| 39 | + app.add_event_handler("startup", on_startup) |
| 40 | + |
| 41 | + |
| 42 | +def get_celery_client(app: FastAPI) -> CeleryTaskClient: |
| 43 | + assert hasattr(app.state, "celery_client") # nosec |
| 44 | + celery_client = app.state.celery_client |
| 45 | + assert isinstance(celery_client, CeleryTaskClient) |
| 46 | + return celery_client |
| 47 | + |
| 48 | + |
| 49 | +def get_event_loop(app: FastAPI) -> AbstractEventLoop: |
| 50 | + event_loop = app.state.event_loop |
| 51 | + assert isinstance(event_loop, AbstractEventLoop) |
| 52 | + return event_loop |
| 53 | + |
| 54 | + |
| 55 | +def set_event_loop(app: FastAPI, event_loop: AbstractEventLoop) -> None: |
| 56 | + app.state.event_loop = event_loop |
0 commit comments