|
11 | 11 | from models_library.progress_bar import ProgressReport |
12 | 12 | from pydantic import ValidationError |
13 | 13 | from servicelib.logging_utils import log_context |
| 14 | +from servicelib.redis._client import RedisClientSDK |
14 | 15 |
|
| 16 | +from ..redis import get_redis_client |
15 | 17 | from .models import TaskContext, TaskID, TaskState, TaskStatus, TaskUUID |
| 18 | +from .utils import get_fastapi_app |
16 | 19 |
|
17 | 20 | _logger = logging.getLogger(__name__) |
18 | 21 |
|
@@ -116,21 +119,27 @@ def get_task_status( |
116 | 119 | progress_report=self._get_progress_report(task_context, task_uuid), |
117 | 120 | ) |
118 | 121 |
|
119 | | - def _get_completed_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]: |
| 122 | + async def _get_completed_task_uuids( |
| 123 | + self, task_context: TaskContext |
| 124 | + ) -> set[TaskUUID]: |
120 | 125 | search_key = _CELERY_TASK_META_PREFIX + _build_task_id_prefix(task_context) |
121 | | - redis = self._celery_app.backend.client |
122 | | - if hasattr(redis, "keys") and (keys := redis.keys(search_key + "*")): |
123 | | - return { |
124 | | - TaskUUID( |
125 | | - f"{key.decode(_CELERY_TASK_ID_KEY_ENCODING).removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}" |
126 | | - ) |
127 | | - for key in keys |
128 | | - } |
129 | | - return set() |
| 126 | + redis_client_sdk = get_redis_client(get_fastapi_app(self._celery_app)) |
| 127 | + assert isinstance(redis_client_sdk, RedisClientSDK) # nosec |
| 128 | + redis = redis_client_sdk.redis |
| 129 | + all_keys = set() |
| 130 | + async for keys in redis.scan_iter(search_key + "*"): |
| 131 | + all_keys.add( |
| 132 | + { |
| 133 | + TaskUUID( |
| 134 | + f"{key.decode(_CELERY_TASK_ID_KEY_ENCODING).removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}" |
| 135 | + ) |
| 136 | + for key in keys |
| 137 | + } |
| 138 | + ) |
| 139 | + return all_keys |
130 | 140 |
|
131 | | - @make_async() |
132 | | - def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]: |
133 | | - task_uuids = self._get_completed_task_uuids(task_context) |
| 141 | + async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]: |
| 142 | + task_uuids = await self._get_completed_task_uuids(task_context) |
134 | 143 |
|
135 | 144 | task_id_prefix = _build_task_id_prefix(task_context) |
136 | 145 | inspect = self._celery_app.control.inspect() |
|
0 commit comments