Skip to content

Commit 32be5a5

Browse files
use async redis client
1 parent 96791e7 commit 32be5a5

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

services/storage/src/simcore_service_storage/modules/celery/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ._celery_types import register_celery_types
88
from ._common import create_app
99
from .client import CeleryTaskQueueClient
10+
from .utils import set_fastapi_app
1011

1112
_logger = logging.getLogger(__name__)
1213

@@ -17,6 +18,7 @@ async def on_startup() -> None:
1718
assert celery_settings # nosec
1819
celery_app = create_app(celery_settings)
1920
app.state.celery_client = CeleryTaskQueueClient(celery_app)
21+
set_fastapi_app(celery_app, app)
2022

2123
register_celery_types()
2224

services/storage/src/simcore_service_storage/modules/celery/client.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
from models_library.progress_bar import ProgressReport
1212
from pydantic import ValidationError
1313
from servicelib.logging_utils import log_context
14+
from servicelib.redis._client import RedisClientSDK
1415

16+
from ..redis import get_redis_client
1517
from .models import TaskContext, TaskID, TaskState, TaskStatus, TaskUUID
18+
from .utils import get_fastapi_app
1619

1720
_logger = logging.getLogger(__name__)
1821

@@ -116,21 +119,27 @@ def get_task_status(
116119
progress_report=self._get_progress_report(task_context, task_uuid),
117120
)
118121

119-
def _get_completed_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
122+
async def _get_completed_task_uuids(
123+
self, task_context: TaskContext
124+
) -> set[TaskUUID]:
120125
search_key = _CELERY_TASK_META_PREFIX + _build_task_id_prefix(task_context)
121-
redis = self._celery_app.backend.client
122-
if hasattr(redis, "keys") and (keys := redis.keys(search_key + "*")):
123-
return {
124-
TaskUUID(
125-
f"{key.decode(_CELERY_TASK_ID_KEY_ENCODING).removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
126-
)
127-
for key in keys
128-
}
129-
return set()
126+
redis_client_sdk = get_redis_client(get_fastapi_app(self._celery_app))
127+
assert isinstance(redis_client_sdk, RedisClientSDK) # nosec
128+
redis = redis_client_sdk.redis
129+
all_keys = set()
130+
async for keys in redis.scan_iter(search_key + "*"):
131+
all_keys.add(
132+
{
133+
TaskUUID(
134+
f"{key.decode(_CELERY_TASK_ID_KEY_ENCODING).removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
135+
)
136+
for key in keys
137+
}
138+
)
139+
return all_keys
130140

131-
@make_async()
132-
def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
133-
task_uuids = self._get_completed_task_uuids(task_context)
141+
async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
142+
task_uuids = await self._get_completed_task_uuids(task_context)
134143

135144
task_id_prefix = _build_task_id_prefix(task_context)
136145
inspect = self._celery_app.control.inspect()

0 commit comments

Comments
 (0)