Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ._celery_types import register_celery_types
from ._common import create_app
from .client import CeleryTaskQueueClient
from .utils import set_fastapi_app

_logger = logging.getLogger(__name__)

Expand All @@ -17,6 +18,7 @@ async def on_startup() -> None:
assert celery_settings # nosec
celery_app = create_app(celery_settings)
app.state.celery_client = CeleryTaskQueueClient(celery_app)
set_fastapi_app(celery_app, app)

register_celery_types()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
from models_library.progress_bar import ProgressReport
from pydantic import ValidationError
from servicelib.logging_utils import log_context
from servicelib.redis._client import RedisClientSDK

from ..redis import get_redis_client
from .models import TaskContext, TaskID, TaskState, TaskStatus, TaskUUID
from .utils import get_fastapi_app

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -116,21 +119,27 @@ def get_task_status(
progress_report=self._get_progress_report(task_context, task_uuid),
)

def _get_completed_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
async def _get_completed_task_uuids(
self, task_context: TaskContext
) -> set[TaskUUID]:
search_key = _CELERY_TASK_META_PREFIX + _build_task_id_prefix(task_context)
redis = self._celery_app.backend.client
if hasattr(redis, "keys") and (keys := redis.keys(search_key + "*")):
return {
TaskUUID(
f"{key.decode(_CELERY_TASK_ID_KEY_ENCODING).removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
)
for key in keys
}
return set()
redis_client_sdk = get_redis_client(get_fastapi_app(self._celery_app))
assert isinstance(redis_client_sdk, RedisClientSDK) # nosec
redis = redis_client_sdk.redis
all_keys = set()
async for keys in redis.scan_iter(search_key + "*"):
all_keys.add(
{
TaskUUID(
f"{key.decode(_CELERY_TASK_ID_KEY_ENCODING).removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
)
for key in keys
}
)
return all_keys

@make_async()
def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
task_uuids = self._get_completed_task_uuids(task_context)
async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
task_uuids = await self._get_completed_task_uuids(task_context)

task_id_prefix = _build_task_id_prefix(task_context)
inspect = self._celery_app.control.inspect()
Expand Down
Loading