Skip to content

Commit 69d24ca

Browse files
fix tests
1 parent 5ccbded commit 69d24ca

File tree

4 files changed

+34
-7
lines changed

4 files changed

+34
-7
lines changed

packages/celery-library/src/celery_library/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ async def submit_task(
5454
self._celery_app.send_task(
5555
task_metadata.name,
5656
task_id=task_id,
57-
kwargs={"task_id": task_id} | task_params,
57+
kwargs=task_params,
5858
queue=task_metadata.queue.value,
5959
)
6060

packages/celery-library/src/celery_library/task.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ class TaskAbortedError(Exception): ...
4040
def _async_task_wrapper(
4141
app: Celery,
4242
) -> Callable[
43-
[Callable[Concatenate[AbortableTask, P], Coroutine[Any, Any, R]]],
43+
[Callable[Concatenate[AbortableTask, TaskId, P], Coroutine[Any, Any, R]]],
4444
Callable[Concatenate[AbortableTask, P], R],
4545
]:
4646
def decorator(
47-
coro: Callable[Concatenate[AbortableTask, P], Coroutine[Any, Any, R]],
47+
coro: Callable[Concatenate[AbortableTask, TaskId, P], Coroutine[Any, Any, R]],
4848
) -> Callable[Concatenate[AbortableTask, P], R]:
4949
@wraps(coro)
5050
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
@@ -56,7 +56,7 @@ async def run_task(task_id: TaskID) -> R:
5656
try:
5757
async with asyncio.TaskGroup() as tg:
5858
main_task = tg.create_task(
59-
coro(task, *args, **kwargs),
59+
coro(task, task_id, *args, **kwargs),
6060
)
6161

6262
async def abort_monitor():
@@ -205,5 +205,4 @@ def register_task( # type: ignore[misc]
205205
bind=True,
206206
base=AbortableTask,
207207
time_limit=None if timeout is None else timeout.total_seconds(),
208-
pydantic=True,
209208
)(wrapped_fn)

packages/celery-library/src/celery_library/types.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from functools import partial
12
from pathlib import Path
3+
from typing import Any
24

3-
from kombu.utils.json import register_type # type: ignore[import-untyped]
5+
from kombu.utils.json import register_type
6+
from pydantic import BaseModel # type: ignore[import-untyped]
47

58

69
def _path_encoder(obj):
@@ -20,6 +23,14 @@ def _class_full_name(clz: type) -> str:
2023
return ".".join([clz.__module__, clz.__qualname__])
2124

2225

26+
def _pydantic_model_encoder(obj: BaseModel, *args, **kwargs) -> dict[str, Any]:
27+
return obj.model_dump(*args, **kwargs, mode="json")
28+
29+
30+
def _pydantic_model_decoder(clz: type[BaseModel], data: dict[str, Any]) -> BaseModel:
31+
return clz(**data)
32+
33+
2334
def register_celery_types() -> None:
2435
register_type(
2536
Path,
@@ -28,3 +39,13 @@ def register_celery_types() -> None:
2839
_path_decoder,
2940
)
3041
register_type(set, _class_full_name(set), encoder=list, decoder=set)
42+
43+
44+
def register_pydantic_types(*models: type[BaseModel]) -> None:
45+
for model in models:
46+
register_type(
47+
model,
48+
_class_full_name(model),
49+
encoder=_pydantic_model_encoder,
50+
decoder=partial(_pydantic_model_decoder, model),
51+
)

services/storage/src/simcore_service_storage/api/_worker_tasks/tasks.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,15 @@
22

33
from celery import Celery # type: ignore[import-untyped]
44
from celery_library.task import register_task
5-
from celery_library.types import register_celery_types
5+
from celery_library.types import register_celery_types, register_pydantic_types
66
from models_library.api_schemas_storage.export_data_async_jobs import AccessRightError
7+
from models_library.api_schemas_storage.storage_schemas import (
8+
FileUploadCompletionBody,
9+
FoldersBody,
10+
)
711
from servicelib.logging_utils import log_context
812

13+
from ...models import FileMetaData
914
from ._files import complete_upload_file
1015
from ._paths import compute_path_size, delete_paths
1116
from ._simcore_s3 import deep_copy_files_from_project, export_data
@@ -16,6 +21,8 @@
1621
def setup_worker_tasks(app: Celery) -> None:
1722
register_celery_types()
1823

24+
register_pydantic_types(FileUploadCompletionBody, FileMetaData, FoldersBody)
25+
1926
with log_context(_logger, logging.INFO, msg="worker task registration"):
2027
register_task(app, export_data, dont_autoretry_for=(AccessRightError,))
2128
register_task(app, compute_path_size)

0 commit comments

Comments
 (0)