|
3 | 3 | import traceback |
4 | 4 | from collections.abc import Callable, Coroutine |
5 | 5 | from functools import wraps |
6 | | -from typing import Any, ParamSpec, TypeVar |
| 6 | +from typing import Any, Concatenate, ParamSpec, TypeAlias, TypeVar, overload |
7 | 7 |
|
8 | 8 | from celery import ( # type: ignore[import-untyped] |
9 | 9 | Celery, |
@@ -52,24 +52,58 @@ def wrapper(task: Task, *args: Any, **kwargs: Any) -> Any: |
52 | 52 | P = ParamSpec("P") |
53 | 53 | R = TypeVar("R") |
54 | 54 |
|
| 55 | +TaskId: TypeAlias = str |
| 56 | + |
55 | 57 |
|
56 | 58 | def _async_task_wrapper( |
57 | 59 | app: Celery, |
58 | | -) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, R]]: |
59 | | - def decorator(coro: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, R]: |
| 60 | +) -> Callable[ |
| 61 | + [Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]]], |
| 62 | + Callable[Concatenate[Task, P], R], |
| 63 | +]: |
| 64 | + def decorator( |
| 65 | + coro: Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]], |
| 66 | + ) -> Callable[Concatenate[Task, P], R]: |
60 | 67 | @wraps(coro) |
61 | | - def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: |
| 68 | + def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R: |
62 | 69 | fastapi_app = get_fastapi_app(app) |
| 70 | + _logger.debug("BEFORE task id: %s", task.request.id) |
| 71 | + # NOTE: task.request is a thread local object, so we need to pass the id explicitly |
| 72 | + assert task.request.id is not None # nosec |
63 | 73 | return asyncio.run_coroutine_threadsafe( |
64 | | - coro(*args, **kwargs), get_event_loop(fastapi_app) |
| 74 | + coro(task, task.request.id, *args, **kwargs), |
| 75 | + get_event_loop(fastapi_app), |
65 | 76 | ).result() |
66 | 77 |
|
67 | 78 | return wrapper |
68 | 79 |
|
69 | 80 | return decorator |
70 | 81 |
|
71 | 82 |
|
72 | | -def define_task(app: Celery, fn: Callable, task_name: str | None = None): |
| 83 | +@overload |
| 84 | +def define_task( |
| 85 | + app: Celery, |
| 86 | + fn: Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]], |
| 87 | + task_name: str | None = None, |
| 88 | +) -> None: ... |
| 89 | + |
| 90 | + |
| 91 | +@overload |
| 92 | +def define_task( |
| 93 | + app: Celery, |
| 94 | + fn: Callable[Concatenate[Task, P], R], |
| 95 | + task_name: str | None = None, |
| 96 | +) -> None: ... |
| 97 | + |
| 98 | + |
| 99 | +def define_task( |
| 100 | + app: Celery, |
| 101 | + fn: ( |
| 102 | + Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]] |
| 103 | + | Callable[Concatenate[Task, P], R] |
| 104 | + ), |
| 105 | + task_name: str | None = None, |
| 106 | +): |
73 | 107 | wrapped_fn = error_handling(fn) |
74 | 108 | if asyncio.iscoroutinefunction(fn): |
75 | 109 | wrapped_fn = _async_task_wrapper(app)(fn) |
|
0 commit comments