Skip to content

Commit 9db3d5b

Browse files
move set_progress
1 parent 59a41cd commit 9db3d5b

File tree

12 files changed

+56
-50
lines changed

12 files changed

+56
-50
lines changed

services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ async def _task_progress_cb(
2626
) -> None:
2727
worker = get_celery_worker(task.app)
2828
assert task.name # nosec
29-
worker.set_task_progress(
30-
task_name=task.name,
29+
await worker.set_progress(
3130
task_id=task_id,
3231
report=report,
3332
)
@@ -88,7 +87,7 @@ async def export_data(
8887

8988
async def _progress_cb(report: ProgressReport) -> None:
9089
assert task.name # nosec
91-
get_celery_worker(task.app).set_task_progress(task.name, task_id, report)
90+
await get_celery_worker(task.app).set_progress(task_id, report)
9291
_logger.debug("'%s' progress %s", task_id, report.percent_value)
9392

9493
async with ProgressBarData(

services/storage/src/simcore_service_storage/modules/celery/backends/_redis.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Final
44

55
from celery.result import AsyncResult # type: ignore[import-untyped]
6+
from models_library.progress_bar import ProgressReport
67
from servicelib.redis._client import RedisClientSDK
78

89
from ..models import TaskContext, TaskID, TaskMetadata, TaskUUID, build_task_id_prefix
@@ -12,6 +13,7 @@
1213
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
1314
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000
1415
_CELERY_TASK_METADATA_KEY: Final[str] = "metadata"
16+
_CELERY_TASK_PROGRESS_KEY: Final[str] = "progress"
1517

1618
_logger = logging.getLogger(__name__)
1719

@@ -33,6 +35,10 @@ async def get_metadata(self, task_id: TaskID) -> TaskMetadata | None:
3335
result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_METADATA_KEY) # type: ignore
3436
return TaskMetadata.model_validate_json(result) if result else None
3537

38+
async def get_progress(self, task_id: TaskID) -> ProgressReport | None:
39+
result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_PROGRESS_KEY) # type: ignore
40+
return ProgressReport.model_validate_json(result) if result else None
41+
3642
async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
3743
search_key = (
3844
_CELERY_TASK_INFO_PREFIX
@@ -68,3 +74,10 @@ async def set_metadata(
6874
_build_key(task_id),
6975
expiry,
7076
)
77+
78+
async def set_progress(self, task_id: TaskID, report: ProgressReport) -> None:
79+
await self._redis_client_sdk.redis.hset(
80+
name=_build_key(task_id),
81+
key=_CELERY_TASK_PROGRESS_KEY,
82+
value=report.model_dump_json(),
83+
) # type: ignore

services/storage/src/simcore_service_storage/modules/celery/models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ class TaskMetadata(BaseModel):
4949
class TaskInfoStore(Protocol):
5050
async def exists(self, task_id: TaskID) -> bool: ...
5151

52+
async def get_progress(self, task_id: TaskID) -> ProgressReport | None: ...
53+
5254
async def get_metadata(self, task_id: TaskID) -> TaskMetadata | None: ...
5355

5456
async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]: ...
@@ -59,6 +61,8 @@ async def set_metadata(
5961
self, task_id: TaskID, task_data: TaskMetadata, expiry: timedelta
6062
) -> None: ...
6163

64+
async def set_progress(self, task_id: TaskID, report: ProgressReport) -> None: ...
65+
6266

6367
class TaskStatus(BaseModel):
6468
task_uuid: TaskUUID

services/storage/src/simcore_service_storage/modules/celery/signals.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
set_celery_worker,
1818
set_fastapi_app,
1919
)
20-
from ...modules.celery.worker import CeleryTaskQueueWorker
20+
from ...modules.celery.worker import CeleryTaskWorker
2121

2222
_logger = logging.getLogger(__name__)
2323

@@ -54,7 +54,7 @@ async def lifespan(
5454
set_event_loop(fastapi_app, loop)
5555

5656
set_fastapi_app(sender.app, fastapi_app)
57-
set_celery_worker(sender.app, CeleryTaskQueueWorker(sender.app))
57+
set_celery_worker(sender.app, CeleryTaskWorker(sender.app))
5858
loop.run_until_complete(lifespan(startup_complete_event, shutdown_event))
5959

6060
thread = threading.Thread(

services/storage/src/simcore_service_storage/modules/celery/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
from celery import Celery # type: ignore[import-untyped]
22
from fastapi import FastAPI
33

4-
from .worker import CeleryTaskQueueWorker
4+
from .worker import CeleryTaskWorker
55

66
_WORKER_KEY = "celery_worker"
77
_FASTAPI_APP_KEY = "fastapi_app"
88

99

10-
def set_celery_worker(celery_app: Celery, worker: CeleryTaskQueueWorker) -> None:
10+
def set_celery_worker(celery_app: Celery, worker: CeleryTaskWorker) -> None:
1111
celery_app.conf[_WORKER_KEY] = worker
1212

1313

14-
def get_celery_worker(celery_app: Celery) -> CeleryTaskQueueWorker:
14+
def get_celery_worker(celery_app: Celery) -> CeleryTaskWorker:
1515
worker = celery_app.conf[_WORKER_KEY]
16-
assert isinstance(worker, CeleryTaskQueueWorker)
16+
assert isinstance(worker, CeleryTaskWorker)
1717
return worker
1818

1919

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,19 @@
11
import logging
2+
from dataclasses import dataclass
23

3-
from celery import Celery # type: ignore[import-untyped]
44
from models_library.progress_bar import ProgressReport
5-
from servicelib.logging_utils import log_context
65

7-
from ..celery.models import TaskID
6+
from ..celery.models import TaskID, TaskInfoStore
87

98
_logger = logging.getLogger(__name__)
109

1110

12-
class CeleryTaskQueueWorker:
13-
def __init__(self, celery_app: Celery) -> None:
14-
self.celery_app = celery_app
11+
@dataclass
12+
class CeleryTaskWorker:
13+
_task_info_store: TaskInfoStore
1514

16-
def set_task_progress(
17-
self, task_name: str, task_id: TaskID, report: ProgressReport
18-
) -> None:
19-
with log_context(
20-
_logger,
21-
logging.DEBUG,
22-
msg=f"Setting progress for {task_name}: {report.model_dump_json()}",
23-
):
24-
self.celery_app.tasks[task_name].update_state(
25-
task_id=task_id,
26-
state="RUNNING",
27-
meta=report.model_dump(mode="json"),
28-
)
15+
async def set_progress(self, task_id: TaskID, report: ProgressReport) -> None:
16+
await self._task_info_store.set_progress(
17+
task_id=task_id,
18+
report=report,
19+
)

services/storage/tests/conftest.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
on_worker_shutdown,
7777
)
7878
from simcore_service_storage.modules.celery.utils import get_celery_worker
79-
from simcore_service_storage.modules.celery.worker import CeleryTaskQueueWorker
79+
from simcore_service_storage.modules.celery.worker import CeleryTaskWorker
8080
from simcore_service_storage.modules.s3 import get_s3_client
8181
from simcore_service_storage.simcore_s3_dsm import SimcoreS3DataManager
8282
from sqlalchemy import literal_column
@@ -365,7 +365,7 @@ def upload_file(
365365
create_upload_file_link_v2: Callable[..., Awaitable[FileUploadSchema]],
366366
create_file_of_size: Callable[[ByteSize, str | None], Path],
367367
create_simcore_file_id: Callable[[ProjectID, NodeID, str], SimcoreS3FileID],
368-
with_storage_celery_worker: CeleryTaskQueueWorker,
368+
with_storage_celery_worker: CeleryTaskWorker,
369369
) -> Callable[
370370
[ByteSize, str, SimcoreS3FileID | None], Awaitable[tuple[Path, SimcoreS3FileID]]
371371
]:
@@ -480,7 +480,7 @@ async def create_empty_directory(
480480
create_simcore_file_id: Callable[[ProjectID, NodeID, str], SimcoreS3FileID],
481481
create_upload_file_link_v2: Callable[..., Awaitable[FileUploadSchema]],
482482
client: httpx.AsyncClient,
483-
with_storage_celery_worker: CeleryTaskQueueWorker,
483+
with_storage_celery_worker: CeleryTaskWorker,
484484
) -> Callable[[str, ProjectID, NodeID], Awaitable[SimcoreS3FileID]]:
485485
async def _directory_creator(
486486
dir_name: str, project_id: ProjectID, node_id: NodeID
@@ -1024,7 +1024,7 @@ async def with_storage_celery_worker_controller(
10241024
@pytest.fixture
10251025
def with_storage_celery_worker(
10261026
with_storage_celery_worker_controller: TestWorkController,
1027-
) -> CeleryTaskQueueWorker:
1027+
) -> CeleryTaskWorker:
10281028
assert isinstance(with_storage_celery_worker_controller.app, Celery)
10291029
return get_celery_worker(with_storage_celery_worker_controller.app)
10301030

services/storage/tests/unit/test_async_jobs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from simcore_service_storage.modules.celery import get_celery_client
3131
from simcore_service_storage.modules.celery._task import register_task
3232
from simcore_service_storage.modules.celery.models import TaskID
33-
from simcore_service_storage.modules.celery.worker import CeleryTaskQueueWorker
33+
from simcore_service_storage.modules.celery.worker import CeleryTaskWorker
3434
from tenacity import (
3535
AsyncRetrying,
3636
retry_if_exception_type,
@@ -199,7 +199,7 @@ async def test_async_jobs_workflow(
199199
initialized_app: FastAPI,
200200
register_rpc_routes: None,
201201
storage_rabbitmq_rpc_client: RabbitMQRPCClient,
202-
with_storage_celery_worker: CeleryTaskQueueWorker,
202+
with_storage_celery_worker: CeleryTaskWorker,
203203
user_id: UserID,
204204
product_name: ProductName,
205205
exposed_rpc_start: str,
@@ -247,7 +247,7 @@ async def test_async_jobs_cancel(
247247
initialized_app: FastAPI,
248248
register_rpc_routes: None,
249249
storage_rabbitmq_rpc_client: RabbitMQRPCClient,
250-
with_storage_celery_worker: CeleryTaskQueueWorker,
250+
with_storage_celery_worker: CeleryTaskWorker,
251251
user_id: UserID,
252252
product_name: ProductName,
253253
exposed_rpc_start: str,
@@ -328,7 +328,7 @@ async def test_async_jobs_raises(
328328
initialized_app: FastAPI,
329329
register_rpc_routes: None,
330330
storage_rabbitmq_rpc_client: RabbitMQRPCClient,
331-
with_storage_celery_worker: CeleryTaskQueueWorker,
331+
with_storage_celery_worker: CeleryTaskWorker,
332332
user_id: UserID,
333333
product_name: ProductName,
334334
exposed_rpc_start: str,

services/storage/tests/unit/test_handlers_files.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
from servicelib.aiohttp import status
5454
from simcore_service_storage.constants import S3_UNDEFINED_OR_EXTERNAL_MULTIPART_ID
5555
from simcore_service_storage.models import FileDownloadResponse, S3BucketName, UploadID
56-
from simcore_service_storage.modules.celery.worker import CeleryTaskQueueWorker
56+
from simcore_service_storage.modules.celery.worker import CeleryTaskWorker
5757
from simcore_service_storage.simcore_s3_dsm import SimcoreS3DataManager
5858
from sqlalchemy.ext.asyncio import AsyncEngine
5959
from tenacity.asyncio import AsyncRetrying
@@ -683,7 +683,7 @@ async def test_upload_real_file_with_s3_client(
683683
node_id: NodeID,
684684
faker: Faker,
685685
s3_client: S3Client,
686-
with_storage_celery_worker: CeleryTaskQueueWorker,
686+
with_storage_celery_worker: CeleryTaskWorker,
687687
):
688688
file_size = TypeAdapter(ByteSize).validate_python("500Mib")
689689
file_name = faker.file_name()

services/storage/tests/unit/test_modules_celery.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
get_celery_worker,
3030
get_fastapi_app,
3131
)
32-
from simcore_service_storage.modules.celery.worker import CeleryTaskQueueWorker
32+
from simcore_service_storage.modules.celery.worker import CeleryTaskWorker
3333
from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed
3434

3535
_logger = logging.getLogger(__name__)
@@ -41,7 +41,7 @@
4141
@pytest.fixture
4242
def celery_client(
4343
initialized_app: FastAPI,
44-
with_storage_celery_worker: CeleryTaskQueueWorker,
44+
with_storage_celery_worker: CeleryTaskWorker,
4545
) -> CeleryTaskQueueClient:
4646
return get_celery_client(initialized_app)
4747

@@ -56,8 +56,7 @@ def sleep_for(seconds: float) -> None:
5656

5757
for n, file in enumerate(files, start=1):
5858
with log_context(_logger, logging.INFO, msg=f"Processing file {file}"):
59-
worker.set_task_progress(
60-
task_name=task_name,
59+
await worker.set_progress(
6160
task_id=task_id,
6261
report=ProgressReport(actual_value=n / len(files)),
6362
)

0 commit comments

Comments
 (0)