Skip to content

Commit 5a86ddb

Browse files
🎨 TaskID -> TaskKey (#8409)
1 parent 3d80890 commit 5a86ddb

File tree

22 files changed

+260
-219
lines changed

22 files changed

+260
-219
lines changed

packages/celery-library/src/celery_library/backends/redis.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
ExecutionMetadata,
1111
OwnerMetadata,
1212
Task,
13-
TaskID,
1413
TaskInfoStore,
14+
TaskKey,
1515
)
1616
from servicelib.redis import RedisClientSDK, handle_redis_returns_union_types
1717

@@ -24,8 +24,8 @@
2424
_logger = logging.getLogger(__name__)
2525

2626

27-
def _build_key(task_id: TaskID) -> str:
28-
return _CELERY_TASK_INFO_PREFIX + task_id
27+
def _build_key(task_key: TaskKey) -> str:
28+
return _CELERY_TASK_INFO_PREFIX + task_key
2929

3030

3131
class RedisTaskInfoStore:
@@ -34,27 +34,27 @@ def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
3434

3535
async def create_task(
3636
self,
37-
task_id: TaskID,
37+
task_key: TaskKey,
3838
execution_metadata: ExecutionMetadata,
3939
expiry: timedelta,
4040
) -> None:
41-
task_key = _build_key(task_id)
41+
redis_key = _build_key(task_key)
4242
await handle_redis_returns_union_types(
4343
self._redis_client_sdk.redis.hset(
44-
name=task_key,
44+
name=redis_key,
4545
key=_CELERY_TASK_METADATA_KEY,
4646
value=execution_metadata.model_dump_json(),
4747
)
4848
)
4949
await self._redis_client_sdk.redis.expire(
50-
task_key,
50+
redis_key,
5151
expiry,
5252
)
5353

54-
async def get_task_metadata(self, task_id: TaskID) -> ExecutionMetadata | None:
54+
async def get_task_metadata(self, task_key: TaskKey) -> ExecutionMetadata | None:
5555
raw_result = await handle_redis_returns_union_types(
5656
self._redis_client_sdk.redis.hget(
57-
_build_key(task_id), _CELERY_TASK_METADATA_KEY
57+
_build_key(task_key), _CELERY_TASK_METADATA_KEY
5858
)
5959
)
6060
if not raw_result:
@@ -64,14 +64,16 @@ async def get_task_metadata(self, task_id: TaskID) -> ExecutionMetadata | None:
6464
return ExecutionMetadata.model_validate_json(raw_result)
6565
except ValidationError as exc:
6666
_logger.debug(
67-
"Failed to deserialize task metadata for task %s: %s", task_id, f"{exc}"
67+
"Failed to deserialize task metadata for task %s: %s",
68+
task_key,
69+
f"{exc}",
6870
)
6971
return None
7072

71-
async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
73+
async def get_task_progress(self, task_key: TaskKey) -> ProgressReport | None:
7274
raw_result = await handle_redis_returns_union_types(
7375
self._redis_client_sdk.redis.hget(
74-
_build_key(task_id), _CELERY_TASK_PROGRESS_KEY
76+
_build_key(task_key), _CELERY_TASK_PROGRESS_KEY
7577
)
7678
)
7779
if not raw_result:
@@ -81,12 +83,14 @@ async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
8183
return ProgressReport.model_validate_json(raw_result)
8284
except ValidationError as exc:
8385
_logger.debug(
84-
"Failed to deserialize task progress for task %s: %s", task_id, f"{exc}"
86+
"Failed to deserialize task progress for task %s: %s",
87+
task_key,
88+
f"{exc}",
8589
)
8690
return None
8791

8892
async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]:
89-
search_key = _CELERY_TASK_INFO_PREFIX + owner_metadata.model_dump_task_id(
93+
search_key = _CELERY_TASK_INFO_PREFIX + owner_metadata.model_dump_task_key(
9094
task_uuid=WILDCARD
9195
)
9296

@@ -122,20 +126,22 @@ async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]:
122126

123127
return tasks
124128

125-
async def remove_task(self, task_id: TaskID) -> None:
126-
await self._redis_client_sdk.redis.delete(_build_key(task_id))
129+
async def remove_task(self, task_key: TaskKey) -> None:
130+
await self._redis_client_sdk.redis.delete(_build_key(task_key))
127131

128-
async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None:
132+
async def set_task_progress(
133+
self, task_key: TaskKey, report: ProgressReport
134+
) -> None:
129135
await handle_redis_returns_union_types(
130136
self._redis_client_sdk.redis.hset(
131-
name=_build_key(task_id),
137+
name=_build_key(task_key),
132138
key=_CELERY_TASK_PROGRESS_KEY,
133139
value=report.model_dump_json(),
134140
)
135141
)
136142

137-
async def task_exists(self, task_id: TaskID) -> bool:
138-
n = await self._redis_client_sdk.redis.exists(_build_key(task_id))
143+
async def task_exists(self, task_key: TaskKey) -> bool:
144+
n = await self._redis_client_sdk.redis.exists(_build_key(task_key))
139145
assert isinstance(n, int) # nosec
140146
return n > 0
141147

packages/celery-library/src/celery_library/errors.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@ def decode_celery_transferrable_error(error: TransferrableCeleryError) -> Except
2727

2828

2929
class TaskSubmissionError(OsparcErrorMixin, Exception):
30-
msg_template = (
31-
"Unable to submit task {task_name} with id '{task_id}' and params {task_params}"
32-
)
30+
msg_template = "Unable to submit task {task_name} with key '{task_key}' and params {task_params}"
3331

3432

3533
class TaskNotFoundError(OsparcErrorMixin, Exception):
36-
msg_template = "Task with id '{task_id}' was not found"
34+
msg_template = "Task with uuid '{task_uuid}' and owner_metadata '{owner_metadata}' was not found"

packages/celery-library/src/celery_library/task.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from celery.exceptions import Ignore # type: ignore[import-untyped]
1111
from common_library.async_tools import cancel_wait_task
1212
from pydantic import NonNegativeInt
13-
from servicelib.celery.models import TaskID
13+
from servicelib.celery.models import TaskKey
1414

1515
from .errors import encode_celery_transferrable_error
1616
from .utils import get_app_server
@@ -47,7 +47,7 @@ def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
4747
# NOTE: task.request is a thread local object, so we need to pass the id explicitly
4848
assert task.request.id is not None # nosec
4949

50-
async def _run_task(task_id: TaskID) -> R:
50+
async def _run_task(task_key: TaskKey) -> R:
5151
try:
5252
async with asyncio.TaskGroup() as tg:
5353
async_io_task = tg.create_task(
@@ -57,7 +57,7 @@ async def _run_task(task_id: TaskID) -> R:
5757
async def _abort_monitor():
5858
while not async_io_task.done():
5959
if not await app_server.task_manager.task_exists(
60-
task_id
60+
task_key
6161
):
6262
await cancel_wait_task(
6363
async_io_task,
@@ -140,7 +140,7 @@ def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
140140
@overload
141141
def register_task(
142142
app: Celery,
143-
fn: Callable[Concatenate[Task, TaskID, P], Coroutine[Any, Any, R]],
143+
fn: Callable[Concatenate[Task, TaskKey, P], Coroutine[Any, Any, R]],
144144
task_name: str | None = None,
145145
timeout: timedelta | None = _DEFAULT_TASK_TIMEOUT,
146146
max_retries: NonNegativeInt = _DEFAULT_MAX_RETRIES,
@@ -164,7 +164,7 @@ def register_task(
164164
def register_task( # type: ignore[misc]
165165
app: Celery,
166166
fn: (
167-
Callable[Concatenate[Task, TaskID, P], Coroutine[Any, Any, R]]
167+
Callable[Concatenate[Task, TaskKey, P], Coroutine[Any, Any, R]]
168168
| Callable[Concatenate[Task, P], R]
169169
),
170170
task_name: str | None = None,

packages/celery-library/src/celery_library/task_manager.py

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
ExecutionMetadata,
1313
OwnerMetadata,
1414
Task,
15-
TaskID,
1615
TaskInfoStore,
16+
TaskKey,
1717
TaskState,
1818
TaskStatus,
1919
TaskUUID,
@@ -50,7 +50,7 @@ async def submit_task(
5050
msg=f"Submit {execution_metadata.name=}: {owner_metadata=} {task_params=}",
5151
):
5252
task_uuid = uuid4()
53-
task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid)
53+
task_key = owner_metadata.model_dump_task_key(task_uuid=task_uuid)
5454

5555
expiry = (
5656
self._celery_settings.CELERY_EPHEMERAL_RESULT_EXPIRES
@@ -60,26 +60,26 @@ async def submit_task(
6060

6161
try:
6262
await self._task_info_store.create_task(
63-
task_id, execution_metadata, expiry=expiry
63+
task_key, execution_metadata, expiry=expiry
6464
)
6565
self._celery_app.send_task(
6666
execution_metadata.name,
67-
task_id=task_id,
68-
kwargs={"task_id": task_id} | task_params,
67+
task_id=task_key,
68+
kwargs={"task_key": task_key} | task_params,
6969
queue=execution_metadata.queue.value,
7070
)
7171
except CeleryError as exc:
7272
try:
73-
await self._task_info_store.remove_task(task_id)
73+
await self._task_info_store.remove_task(task_key)
7474
except CeleryError:
7575
_logger.warning(
7676
"Unable to cleanup task '%s' during error handling",
77-
task_id,
77+
task_key,
7878
exc_info=True,
7979
)
8080
raise TaskSubmissionError(
8181
task_name=execution_metadata.name,
82-
task_id=task_id,
82+
task_key=task_key,
8383
task_params=task_params,
8484
) from exc
8585

@@ -93,19 +93,21 @@ async def cancel_task(
9393
logging.DEBUG,
9494
msg=f"task cancellation: {owner_metadata=} {task_uuid=}",
9595
):
96-
task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid)
97-
if not await self.task_exists(task_id):
98-
raise TaskNotFoundError(task_id=task_id)
96+
task_key = owner_metadata.model_dump_task_key(task_uuid=task_uuid)
97+
if not await self.task_exists(task_key):
98+
raise TaskNotFoundError(
99+
task_uuid=task_uuid, owner_metadata=owner_metadata
100+
)
99101

100-
await self._task_info_store.remove_task(task_id)
101-
await self._forget_task(task_id)
102+
await self._task_info_store.remove_task(task_key)
103+
await self._forget_task(task_key)
102104

103-
async def task_exists(self, task_id: TaskID) -> bool:
104-
return await self._task_info_store.task_exists(task_id)
105+
async def task_exists(self, task_key: TaskKey) -> bool:
106+
return await self._task_info_store.task_exists(task_key)
105107

106108
@make_async()
107-
def _forget_task(self, task_id: TaskID) -> None:
108-
self._celery_app.AsyncResult(task_id).forget()
109+
def _forget_task(self, task_key: TaskKey) -> None:
110+
self._celery_app.AsyncResult(task_key).forget()
109111

110112
async def get_task_result(
111113
self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID
@@ -115,24 +117,26 @@ async def get_task_result(
115117
logging.DEBUG,
116118
msg=f"Get task result: {owner_metadata=} {task_uuid=}",
117119
):
118-
task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid)
119-
if not await self.task_exists(task_id):
120-
raise TaskNotFoundError(task_id=task_id)
120+
task_key = owner_metadata.model_dump_task_key(task_uuid=task_uuid)
121+
if not await self.task_exists(task_key):
122+
raise TaskNotFoundError(
123+
task_uuid=task_uuid, owner_metadata=owner_metadata
124+
)
121125

122-
async_result = self._celery_app.AsyncResult(task_id)
126+
async_result = self._celery_app.AsyncResult(task_key)
123127
result = async_result.result
124128
if async_result.ready():
125-
task_metadata = await self._task_info_store.get_task_metadata(task_id)
129+
task_metadata = await self._task_info_store.get_task_metadata(task_key)
126130
if task_metadata is not None and task_metadata.ephemeral:
127-
await self._task_info_store.remove_task(task_id)
128-
await self._forget_task(task_id)
131+
await self._task_info_store.remove_task(task_key)
132+
await self._forget_task(task_key)
129133
return result
130134

131135
async def _get_task_progress_report(
132-
self, task_id: TaskID, task_state: TaskState
136+
self, task_key: TaskKey, task_state: TaskState
133137
) -> ProgressReport:
134138
if task_state in (TaskState.STARTED, TaskState.RETRY):
135-
progress = await self._task_info_store.get_task_progress(task_id)
139+
progress = await self._task_info_store.get_task_progress(task_key)
136140
if progress is not None:
137141
return progress
138142

@@ -147,8 +151,8 @@ async def _get_task_progress_report(
147151
)
148152

149153
@make_async()
150-
def _get_task_celery_state(self, task_id: TaskID) -> TaskState:
151-
return TaskState(self._celery_app.AsyncResult(task_id).state)
154+
def _get_task_celery_state(self, task_key: TaskKey) -> TaskState:
155+
return TaskState(self._celery_app.AsyncResult(task_key).state)
152156

153157
async def get_task_status(
154158
self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID
@@ -158,16 +162,18 @@ async def get_task_status(
158162
logging.DEBUG,
159163
msg=f"Getting task status: {owner_metadata=} {task_uuid=}",
160164
):
161-
task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid)
162-
if not await self.task_exists(task_id):
163-
raise TaskNotFoundError(task_id=task_id)
165+
task_key = owner_metadata.model_dump_task_key(task_uuid=task_uuid)
166+
if not await self.task_exists(task_key):
167+
raise TaskNotFoundError(
168+
task_uuid=task_uuid, owner_metadata=owner_metadata
169+
)
164170

165-
task_state = await self._get_task_celery_state(task_id)
171+
task_state = await self._get_task_celery_state(task_key)
166172
return TaskStatus(
167173
task_uuid=task_uuid,
168174
task_state=task_state,
169175
progress_report=await self._get_task_progress_report(
170-
task_id, task_state
176+
task_key, task_state
171177
),
172178
)
173179

@@ -179,9 +185,11 @@ async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]:
179185
):
180186
return await self._task_info_store.list_tasks(owner_metadata)
181187

182-
async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None:
188+
async def set_task_progress(
189+
self, task_key: TaskKey, report: ProgressReport
190+
) -> None:
183191
await self._task_info_store.set_task_progress(
184-
task_id=task_id,
192+
task_key=task_key,
185193
report=report,
186194
)
187195

packages/celery-library/tests/unit/test_async_jobs.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from models_library.rabbitmq_basic_types import RPCNamespace
2727
from models_library.users import UserID
2828
from pydantic import TypeAdapter
29-
from servicelib.celery.models import ExecutionMetadata, OwnerMetadata, TaskID
29+
from servicelib.celery.models import ExecutionMetadata, OwnerMetadata, TaskKey
3030
from servicelib.celery.task_manager import TaskManager
3131
from servicelib.rabbitmq import RabbitMQRPCClient, RPCRouter
3232
from servicelib.rabbitmq.rpc_interfaces.async_jobs import async_jobs
@@ -121,15 +121,15 @@ async def _process_action(action: str, payload: Any) -> Any:
121121
return None
122122

123123

124-
def sync_job(task: Task, task_id: TaskID, action: Action, payload: Any) -> Any:
124+
def sync_job(task: Task, task_key: TaskKey, action: Action, payload: Any) -> Any:
125125
_ = task
126-
_ = task_id
126+
_ = task_key
127127
return asyncio.run(_process_action(action, payload))
128128

129129

130-
async def async_job(task: Task, task_id: TaskID, action: Action, payload: Any) -> Any:
130+
async def async_job(task: Task, task_key: TaskKey, action: Action, payload: Any) -> Any:
131131
_ = task
132-
_ = task_id
132+
_ = task_key
133133
return await _process_action(action, payload)
134134

135135

0 commit comments

Comments
 (0)