Skip to content

Commit e1071e8

Browse files
committed
improve
1 parent adbfc1d commit e1071e8

File tree

1 file changed

+40
-6
lines changed
  • services/storage/src/simcore_service_storage/modules/celery

1 file changed

+40
-6
lines changed

services/storage/src/simcore_service_storage/modules/celery/_task.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import traceback
44
from collections.abc import Callable, Coroutine
55
from functools import wraps
6-
from typing import Any, ParamSpec, TypeVar
6+
from typing import Any, Concatenate, ParamSpec, TypeAlias, TypeVar, overload
77

88
from celery import ( # type: ignore[import-untyped]
99
Celery,
@@ -52,24 +52,58 @@ def wrapper(task: Task, *args: Any, **kwargs: Any) -> Any:
5252
P = ParamSpec("P")
5353
R = TypeVar("R")
5454

55+
TaskId: TypeAlias = str
56+
5557

5658
def _async_task_wrapper(
5759
app: Celery,
58-
) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, R]]:
59-
def decorator(coro: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, R]:
60+
) -> Callable[
61+
[Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]]],
62+
Callable[Concatenate[Task, P], R],
63+
]:
64+
def decorator(
65+
coro: Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]],
66+
) -> Callable[Concatenate[Task, P], R]:
6067
@wraps(coro)
61-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
68+
def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
6269
fastapi_app = get_fastapi_app(app)
70+
_logger.debug("BEFORE task id: %s", task.request.id)
71+
# NOTE: task.request is a thread local object, so we need to pass the id explicitly
72+
assert task.request.id is not None # nosec
6373
return asyncio.run_coroutine_threadsafe(
64-
coro(*args, **kwargs), get_event_loop(fastapi_app)
74+
coro(task, task.request.id, *args, **kwargs),
75+
get_event_loop(fastapi_app),
6576
).result()
6677

6778
return wrapper
6879

6980
return decorator
7081

7182

72-
def define_task(app: Celery, fn: Callable, task_name: str | None = None):
83+
@overload
84+
def define_task(
85+
app: Celery,
86+
fn: Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]],
87+
task_name: str | None = None,
88+
) -> None: ...
89+
90+
91+
@overload
92+
def define_task(
93+
app: Celery,
94+
fn: Callable[Concatenate[Task, P], R],
95+
task_name: str | None = None,
96+
) -> None: ...
97+
98+
99+
def define_task(
100+
app: Celery,
101+
fn: (
102+
Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]]
103+
| Callable[Concatenate[Task, P], R]
104+
),
105+
task_name: str | None = None,
106+
):
73107
wrapped_fn = error_handling(fn)
74108
if asyncio.iscoroutinefunction(fn):
75109
wrapped_fn = _async_task_wrapper(app)(fn)

0 commit comments

Comments
 (0)