|
1 | | -from celery_library.common import create_app, create_task_manager |
| 1 | +import logging |
| 2 | + |
| 3 | +from celery_library.backends._redis import RedisTaskInfoStore |
| 4 | +from celery_library.common import create_app |
| 5 | +from celery_library.task_manager import CeleryTaskManager |
2 | 6 | from celery_library.types import register_celery_types, register_pydantic_types |
3 | 7 | from fastapi import FastAPI |
| 8 | +from servicelib.logging_utils import log_context |
| 9 | +from servicelib.redis import RedisClientSDK |
4 | 10 | from settings_library.celery import CelerySettings |
| 11 | +from settings_library.redis import RedisDatabase |
5 | 12 |
|
6 | 13 | from ..celery_worker.worker_tasks.tasks import pydantic_types_to_register |
7 | 14 |
|
| 15 | +_logger = logging.getLogger(__name__) |
| 16 | + |
8 | 17 |
|
9 | | -def setup_task_manager(app: FastAPI, celery_settings: CelerySettings) -> None: |
| 18 | +def setup_task_manager(app: FastAPI, settings: CelerySettings) -> None: |
10 | 19 | async def on_startup() -> None: |
11 | | - app.state.task_manager = await create_task_manager( |
12 | | - create_app(celery_settings), celery_settings |
13 | | - ) |
| 20 | + with log_context(_logger, logging.INFO, "Setting up Celery"): |
| 21 | + redis_client_sdk = RedisClientSDK( |
| 22 | + settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn( |
| 23 | + RedisDatabase.CELERY_TASKS |
| 24 | + ), |
| 25 | + client_name="api_server_celery_tasks", |
| 26 | + ) |
| 27 | + app.state.celery_tasks_redis_client_sdk = redis_client_sdk |
| 28 | + await redis_client_sdk.setup() |
| 29 | + |
| 30 | + app.state.task_manager = CeleryTaskManager( |
| 31 | + create_app(settings), |
| 32 | + settings, |
| 33 | + RedisTaskInfoStore(redis_client_sdk), |
| 34 | + ) |
| 35 | + |
| 36 | + register_celery_types() |
| 37 | + register_pydantic_types(*pydantic_types_to_register) |
14 | 38 |
|
15 | | - register_celery_types() |
16 | | - register_pydantic_types(*pydantic_types_to_register) |
| 39 | + async def on_shutdown() -> None: |
| 40 | + with log_context(_logger, logging.INFO, "Shutting down Celery"): |
| 41 | + redis_client_sdk: RedisClientSDK | None = ( |
| 42 | + app.state.celery_tasks_redis_client_sdk |
| 43 | + ) |
| 44 | + if redis_client_sdk: |
| 45 | + await redis_client_sdk.shutdown() |
17 | 46 |
|
18 | 47 | app.add_event_handler("startup", on_startup) |
| 48 | + app.add_event_handler("shutdown", on_shutdown) |
0 commit comments