Skip to content

Commit 19b1ed8

Browse files
add async jobs stream
1 parent 2b2f64b commit 19b1ed8

File tree

7 files changed

+259
-18
lines changed

7 files changed

+259
-18
lines changed

packages/celery-library/src/celery_library/backends/redis.py

Lines changed: 78 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,62 @@
11
import contextlib
22
import logging
3+
from collections.abc import AsyncIterator
4+
from dataclasses import dataclass
35
from datetime import timedelta
46
from typing import TYPE_CHECKING, Final
57

68
from models_library.progress_bar import ProgressReport
7-
from pydantic import ValidationError
9+
from pydantic import TypeAdapter, ValidationError
810
from servicelib.celery.models import (
911
WILDCARD,
1012
ExecutionMetadata,
1113
OwnerMetadata,
1214
Task,
15+
TaskEvent,
16+
TaskEventID,
1317
TaskID,
1418
TaskInfoStore,
19+
TaskStatusEvent,
20+
TaskStatusValue,
1521
)
1622
from servicelib.redis import RedisClientSDK, handle_redis_returns_union_types
1723

1824
_CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-"
25+
_CELERY_TASK_STREAM_PREFIX: Final[str] = "celery-task-stream-"
1926
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
2027
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 1000
2128
_CELERY_TASK_METADATA_KEY: Final[str] = "metadata"
2229
_CELERY_TASK_PROGRESS_KEY: Final[str] = "progress"
2330

31+
_CELERY_TASK_STREAM_DEFAULT_ID: Final[str] = "0-0"
32+
_CELERY_TASK_STREAM_BLOCK_TIMEOUT: Final[int] = 3 * 1000 # milliseconds
33+
_CELERY_TASK_STREAM_COUNT: Final[int] = 10
34+
_CELERY_TASK_STREAM_EXPIRE_DEFAULT: Final[timedelta] = timedelta(minutes=5)
35+
_CELERY_TASK_STREAM_MAXLEN: Final[int] = 100_000
36+
37+
2438
_logger = logging.getLogger(__name__)
2539

2640

27-
def _build_key(task_id: TaskID) -> str:
41+
def _build_info_key(task_id: TaskID) -> str:
    """Return the Redis key under which the info hash of ``task_id`` is stored."""
    return "".join((_CELERY_TASK_INFO_PREFIX, task_id))
2943

3044

45+
def _build_stream_key(task_id: TaskID) -> str:
    """Return the Redis key of the event stream that belongs to ``task_id``."""
    return "".join((_CELERY_TASK_STREAM_PREFIX, task_id))
47+
48+
49+
@dataclass
3150
class RedisTaskInfoStore:
32-
def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
33-
self._redis_client_sdk = redis_client_sdk
51+
_redis_client_sdk: RedisClientSDK
3452

3553
async def create_task(
3654
self,
3755
task_id: TaskID,
3856
execution_metadata: ExecutionMetadata,
3957
expiry: timedelta,
4058
) -> None:
41-
task_key = _build_key(task_id)
59+
task_key = _build_info_key(task_id)
4260
await handle_redis_returns_union_types(
4361
self._redis_client_sdk.redis.hset(
4462
name=task_key,
@@ -51,10 +69,26 @@ async def create_task(
5169
expiry,
5270
)
5371

72+
if execution_metadata.streamed_result:
73+
stream_key = _build_stream_key(task_id)
74+
await self._redis_client_sdk.redis.xadd(
75+
stream_key,
76+
{
77+
"event": TaskStatusEvent(
78+
data=TaskStatusValue.CREATED
79+
).model_dump_json()
80+
},
81+
maxlen=_CELERY_TASK_STREAM_MAXLEN,
82+
approximate=True,
83+
)
84+
await self._redis_client_sdk.redis.expire(
85+
stream_key, _CELERY_TASK_STREAM_EXPIRE_DEFAULT
86+
)
87+
5488
async def get_task_metadata(self, task_id: TaskID) -> ExecutionMetadata | None:
5589
raw_result = await handle_redis_returns_union_types(
5690
self._redis_client_sdk.redis.hget(
57-
_build_key(task_id), _CELERY_TASK_METADATA_KEY
91+
_build_info_key(task_id), _CELERY_TASK_METADATA_KEY
5892
)
5993
)
6094
if not raw_result:
@@ -71,7 +105,7 @@ async def get_task_metadata(self, task_id: TaskID) -> ExecutionMetadata | None:
71105
async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
72106
raw_result = await handle_redis_returns_union_types(
73107
self._redis_client_sdk.redis.hget(
74-
_build_key(task_id), _CELERY_TASK_PROGRESS_KEY
108+
_build_info_key(task_id), _CELERY_TASK_PROGRESS_KEY
75109
)
76110
)
77111
if not raw_result:
@@ -123,22 +157,56 @@ async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]:
123157
return tasks
124158

125159
async def remove_task(self, task_id: TaskID) -> None:
    """Delete the info hash of ``task_id`` and then its event stream."""
    redis = self._redis_client_sdk.redis
    # one DEL per key, info key first
    for key in (_build_info_key(task_id), _build_stream_key(task_id)):
        await redis.delete(key)
127162

128163
async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None:
    """Write ``report`` (as JSON) into the progress field of the task info hash."""
    pending_write = self._redis_client_sdk.redis.hset(
        name=_build_info_key(task_id),
        key=_CELERY_TASK_PROGRESS_KEY,
        value=report.model_dump_json(),
    )
    await handle_redis_returns_union_types(pending_write)
136171

137172
async def task_exists(self, task_id: TaskID) -> bool:
    """Tell whether an info entry exists in Redis for ``task_id``."""
    info_key = _build_info_key(task_id)
    found = await self._redis_client_sdk.redis.exists(info_key)
    assert isinstance(found, int)  # nosec
    return found > 0
141176

177+
async def publish_task_event(self, task_id: TaskID, event: TaskEvent) -> None:
    """Append ``event`` (serialized as JSON) to the event stream of ``task_id``.

    The stream is capped the same way as in ``create_task`` so that a task
    publishing many events cannot grow its stream without bound.
    """
    # NOTE(review): the stream TTL is only set in ``create_task``; XADD creates
    # the stream when it is missing, so publishing after the TTL elapsed would
    # presumably leave a stream without expiry -- confirm whether the TTL
    # should be refreshed here.
    await self._redis_client_sdk.redis.xadd(
        _build_stream_key(task_id),
        {"event": event.model_dump_json()},
        maxlen=_CELERY_TASK_STREAM_MAXLEN,
        approximate=True,
    )
183+
184+
async def consume_task_events(
    self, task_id: TaskID, last_id: str | None = None
) -> AsyncIterator[tuple[TaskEventID, TaskEvent]]:
    """Yield ``(event_id, event)`` pairs from the event stream of ``task_id``.

    Args:
        task_id: task whose stream is read.
        last_id: id of the last stream entry already seen by the caller; only
            later entries are returned. When omitted (or empty) the stream is
            read from its beginning.

    Yields:
        The stream entry id together with the parsed event.

    The generator polls forever and never returns on its own: the caller is
    responsible for stopping the iteration (e.g. once a final status event
    has been received).
    """
    stream_key = _build_stream_key(task_id)
    # Built once: there is no need to construct a new adapter per message.
    event_adapter = TypeAdapter(TaskEvent)
    cursor = last_id or _CELERY_TASK_STREAM_DEFAULT_ID

    while True:
        messages = await self._redis_client_sdk.redis.xread(
            {stream_key: cursor},
            block=_CELERY_TASK_STREAM_BLOCK_TIMEOUT,
            count=_CELERY_TASK_STREAM_COUNT,
        )
        if not messages:
            # blocking read timed out without new entries: poll again
            continue

        for _, events in messages:
            for msg_id, data in events:
                # Always advance the cursor, also for entries skipped below;
                # otherwise the next XREAD returns the very same entries again.
                cursor = msg_id

                raw_event = data.get("event")
                if raw_event is None:
                    continue

                # Only the parsing is guarded, so that a ValidationError raised
                # by the consumer while we are suspended is not swallowed here.
                try:
                    event: TaskEvent = event_adapter.validate_json(raw_event)
                except ValidationError:
                    _logger.warning(
                        "Skipping malformed event %s in stream %s",
                        msg_id,
                        stream_key,
                    )
                    continue

                yield msg_id, event
209+
142210

143211
if TYPE_CHECKING:
144212
_: type[TaskInfoStore] = RedisTaskInfoStore

packages/celery-library/src/celery_library/task_manager.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from collections.abc import AsyncIterator
23
from dataclasses import dataclass
34
from typing import TYPE_CHECKING, Any
45
from uuid import uuid4
@@ -12,6 +13,8 @@
1213
ExecutionMetadata,
1314
OwnerMetadata,
1415
Task,
16+
TaskEvent,
17+
TaskEventID,
1518
TaskID,
1619
TaskInfoStore,
1720
TaskState,
@@ -22,7 +25,7 @@
2225
from servicelib.logging_utils import log_context
2326
from settings_library.celery import CelerySettings
2427

25-
from .errors import TaskNotFoundError, TaskSubmissionError
28+
from .errors import TaskNotFoundError, TaskSubmissionError, handle_celery_errors
2629

2730
_logger = logging.getLogger(__name__)
2831

@@ -37,6 +40,7 @@ class CeleryTaskManager:
3740
_celery_settings: CelerySettings
3841
_task_info_store: TaskInfoStore
3942

43+
@handle_celery_errors
4044
async def submit_task(
4145
self,
4246
execution_metadata: ExecutionMetadata,
@@ -85,6 +89,7 @@ async def submit_task(
8589

8690
return task_uuid
8791

92+
@handle_celery_errors
8893
async def cancel_task(
8994
self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID
9095
) -> None:
@@ -107,6 +112,7 @@ async def task_exists(self, task_id: TaskID) -> bool:
107112
def _forget_task(self, task_id: TaskID) -> None:
108113
self._celery_app.AsyncResult(task_id).forget()
109114

115+
@handle_celery_errors
110116
async def get_task_result(
111117
self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID
112118
) -> Any:
@@ -150,6 +156,7 @@ async def _get_task_progress_report(
150156
def _get_task_celery_state(self, task_id: TaskID) -> TaskState:
151157
return TaskState(self._celery_app.AsyncResult(task_id).state)
152158

159+
@handle_celery_errors
153160
async def get_task_status(
154161
self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID
155162
) -> TaskStatus:
@@ -171,6 +178,7 @@ async def get_task_status(
171178
),
172179
)
173180

181+
@handle_celery_errors
174182
async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]:
175183
with log_context(
176184
_logger,
@@ -179,12 +187,30 @@ async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]:
179187
):
180188
return await self._task_info_store.list_tasks(owner_metadata)
181189

190+
@handle_celery_errors
async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None:
    """Forward the progress ``report`` of ``task_id`` to the task info store."""
    store = self._task_info_store
    await store.set_task_progress(task_id=task_id, report=report)
)
187196

197+
@handle_celery_errors
async def publish_task_event(self, task_id: TaskID, event: TaskEvent) -> None:
    """Forward ``event`` of ``task_id`` to the task info store for publication."""
    store = self._task_info_store
    await store.publish_task_event(task_id, event)
200+
201+
@handle_celery_errors
async def consume_task_events(
    self,
    owner_metadata: OwnerMetadata,
    task_uuid: TaskUUID,
    last_id: str | None = None,
) -> AsyncIterator[tuple[TaskEventID, TaskEvent]]:
    """Yield the events of the task identified by ``owner_metadata``/``task_uuid``.

    The task id is derived from the owner metadata and the events are relayed
    unchanged from the task info store, starting after ``last_id`` when given.
    """
    # NOTE(review): ``handle_celery_errors`` is defined elsewhere; confirm it
    # supports async generator functions (a wrapper that awaits the result of
    # the call would not work with this method).
    event_stream = self._task_info_store.consume_task_events(
        task_id=owner_metadata.model_dump_task_id(task_uuid=task_uuid),
        last_id=last_id,
    )
    async for id_and_event in event_stream:
        yield id_and_event
213+
188214

189215
if TYPE_CHECKING:
190216
_: type[TaskManager] = CeleryTaskManager

packages/service-library/src/servicelib/celery/models.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import datetime
2+
from collections.abc import AsyncIterator
23
from enum import StrEnum
3-
from typing import Annotated, Final, Literal, Protocol, Self, TypeAlias, TypeVar
4+
from typing import Annotated, Any, Final, Literal, Protocol, Self, TypeAlias, TypeVar
45
from uuid import UUID
56

67
import orjson
@@ -133,9 +134,43 @@ class TasksQueue(StrEnum):
133134
class ExecutionMetadata(BaseModel):
134135
name: TaskName
135136
ephemeral: bool = True
137+
streamed_result: bool = False
136138
queue: TasksQueue = TasksQueue.DEFAULT
137139

138140

141+
TaskEventID: TypeAlias = str
142+
143+
144+
class TaskEventType(StrEnum):
145+
DATA = "data"
146+
STATUS = "status"
147+
148+
149+
class TaskStatusValue(StrEnum):
150+
CREATED = "created"
151+
SUCCESS = "success"
152+
ERROR = "error"
153+
154+
155+
class TaskDataEvent(BaseModel):
156+
type: Literal[TaskEventType.DATA] = TaskEventType.DATA
157+
data: Any
158+
159+
160+
class TaskStatusEvent(BaseModel):
    """Task event that carries a status change; discriminated by ``type``."""

    type: Literal[TaskEventType.STATUS] = TaskEventType.STATUS
    data: TaskStatusValue

    def is_done(self) -> bool:
        """Return True when the status is final, i.e. SUCCESS or ERROR."""
        return self.data in (TaskStatusValue.SUCCESS, TaskStatusValue.ERROR)
166+
167+
168+
TaskEvent = Annotated[
169+
TaskDataEvent | TaskStatusEvent,
170+
Field(discriminator="type"),
171+
]
172+
173+
139174
class Task(BaseModel):
140175
uuid: TaskUUID
141176
metadata: ExecutionMetadata
@@ -195,9 +230,19 @@ async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: ...
195230
async def remove_task(self, task_id: TaskID) -> None: ...
196231

197232
async def set_task_progress(
198-
self, task_id: TaskID, report: ProgressReport
233+
self,
234+
task_id: TaskID,
235+
report: ProgressReport,
199236
) -> None: ...
200237

238+
async def publish_task_event(self, task_id: TaskID, event: TaskEvent) -> None: ...
239+
240+
def consume_task_events(
241+
self,
242+
task_id: TaskID,
243+
last_id: str | None = None,
244+
) -> AsyncIterator[tuple[TaskEventID, TaskEvent]]: ...
245+
201246

202247
class TaskStatus(BaseModel):
203248
task_uuid: TaskUUID

packages/service-library/src/servicelib/celery/task_manager.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
from collections.abc import AsyncIterator
12
from typing import Any, Protocol, runtime_checkable
23

34
from models_library.progress_bar import ProgressReport
45

5-
from ..celery.models import (
6+
from .models import (
67
ExecutionMetadata,
78
OwnerMetadata,
89
Task,
10+
TaskEvent,
11+
TaskEventID,
912
TaskID,
1013
TaskStatus,
1114
TaskUUID,
@@ -26,8 +29,6 @@ async def cancel_task(
2629
self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID
2730
) -> None: ...
2831

29-
async def task_exists(self, task_id: TaskID) -> bool: ...
30-
3132
async def get_task_result(
3233
self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID
3334
) -> Any: ...
@@ -41,3 +42,20 @@ async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: ...
4142
async def set_task_progress(
4243
self, task_id: TaskID, report: ProgressReport
4344
) -> None: ...
45+
46+
async def task_exists(self, task_id: TaskID) -> bool: ...
47+
48+
# Events
49+
50+
async def publish_task_event(
51+
self,
52+
task_id: TaskID,
53+
event: TaskEvent,
54+
) -> None: ...
55+
56+
def consume_task_events(
57+
self,
58+
owner_metadata: OwnerMetadata,
59+
task_uuid: TaskUUID,
60+
last_id: str | None = None,
61+
) -> AsyncIterator[tuple[TaskEventID, TaskEvent]]: ...

packages/service-library/src/servicelib/sse/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)