Skip to content

Commit 8e58365

Browse files
feat: add streaming
1 parent 834c250 commit 8e58365

File tree

17 files changed

+335
-45
lines changed

17 files changed

+335
-45
lines changed

packages/celery-library/src/celery_library/backends/redis.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import contextlib
22
import logging
3+
from collections.abc import AsyncIterator
34
from datetime import timedelta
45
from typing import TYPE_CHECKING, Final
56

67
from models_library.progress_bar import ProgressReport
7-
from pydantic import ValidationError
8+
from pydantic import TypeAdapter, ValidationError
89
from servicelib.celery.models import (
910
Task,
11+
TaskEvent,
1012
TaskFilter,
1113
TaskID,
1214
TaskInfoStore,
@@ -16,6 +18,7 @@
1618
from servicelib.redis import RedisClientSDK, handle_redis_returns_union_types
1719

1820
_CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-"
21+
_CELERY_TASK_STREAM_PREFIX: Final[str] = "celery-task-stream-"
1922
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
2023
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 1000
2124
_CELERY_TASK_METADATA_KEY: Final[str] = "metadata"
@@ -24,10 +27,14 @@
2427
_logger = logging.getLogger(__name__)
2528

2629

27-
def _build_key(task_id: TaskID) -> str:
30+
def _build_info_key(task_id: TaskID) -> str:
2831
return _CELERY_TASK_INFO_PREFIX + task_id
2932

3033

34+
def _build_stream_key(task_id: TaskID) -> str:
    """Return the Redis key of the event stream that belongs to ``task_id``."""
    return f"{_CELERY_TASK_STREAM_PREFIX}{task_id}"
36+
37+
3138
class RedisTaskInfoStore:
3239
def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
3340
self._redis_client_sdk = redis_client_sdk
@@ -38,7 +45,7 @@ async def create_task(
3845
task_metadata: TaskMetadata,
3946
expiry: timedelta,
4047
) -> None:
41-
task_key = _build_key(task_id)
48+
task_key = _build_info_key(task_id)
4249
await handle_redis_returns_union_types(
4350
self._redis_client_sdk.redis.hset(
4451
name=task_key,
@@ -54,7 +61,7 @@ async def create_task(
5461
async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None:
5562
raw_result = await handle_redis_returns_union_types(
5663
self._redis_client_sdk.redis.hget(
57-
_build_key(task_id), _CELERY_TASK_METADATA_KEY
64+
_build_info_key(task_id), _CELERY_TASK_METADATA_KEY
5865
)
5966
)
6067
if not raw_result:
@@ -71,7 +78,7 @@ async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None:
7178
async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
7279
raw_result = await handle_redis_returns_union_types(
7380
self._redis_client_sdk.redis.hget(
74-
_build_key(task_id), _CELERY_TASK_PROGRESS_KEY
81+
_build_info_key(task_id), _CELERY_TASK_PROGRESS_KEY
7582
)
7683
)
7784
if not raw_result:
@@ -123,22 +130,54 @@ async def list_tasks(self, task_filter: TaskFilter) -> list[Task]:
123130
return tasks
124131

125132
async def remove_task(self, task_id: TaskID) -> None:
126-
await self._redis_client_sdk.redis.delete(_build_key(task_id))
133+
await self._redis_client_sdk.redis.delete(_build_info_key(task_id))
127134

128135
async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None:
129136
await handle_redis_returns_union_types(
130137
self._redis_client_sdk.redis.hset(
131-
name=_build_key(task_id),
138+
name=_build_info_key(task_id),
132139
key=_CELERY_TASK_PROGRESS_KEY,
133140
value=report.model_dump_json(),
134141
)
135142
)
136143

137144
async def task_exists(self, task_id: TaskID) -> bool:
138-
n = await self._redis_client_sdk.redis.exists(_build_key(task_id))
145+
n = await self._redis_client_sdk.redis.exists(_build_info_key(task_id))
139146
assert isinstance(n, int) # nosec
140147
return n > 0
141148

149+
async def publish_task_event(self, task_id: TaskID, event: TaskEvent) -> None:
    """Append ``event`` to the Redis stream of ``task_id``.

    The event is stored JSON-serialized under the single field ``"event"``.
    """
    stream_name = _build_stream_key(task_id)
    fields = {"event": event.model_dump_json()}
    await self._redis_client_sdk.redis.xadd(stream_name, fields)
154+
155+
async def consume_task_events(
    self, task_id: TaskID, last_id: str | None = None
) -> AsyncIterator[TaskEvent]:
    """Yield the events stored in the Redis stream of ``task_id``.

    Args:
        task_id: task whose stream is read.
        last_id: stream entry id after which reading starts; ``None`` reads
            from the beginning of the stream ("0-0").

    Yields:
        Each successfully parsed event, with ``event_id`` set to the id of
        the stream entry it was read from.

    Note:
        This generator never finishes on its own: it keeps polling the stream
        until the consumer stops iterating.
    """
    stream_key = _build_stream_key(task_id)
    # Built once: the adapter does not depend on the message being parsed.
    event_adapter = TypeAdapter(TaskEvent)
    # Plain debug trace; `exception()` is only meaningful inside an `except`.
    _logger.debug("Last id: %s", last_id)
    while True:
        messages = await self._redis_client_sdk.redis.xread(
            {stream_key: last_id or "0-0"}, block=5000, count=10
        )
        if not messages:
            # Blocking read timed out without new entries: poll again.
            continue
        for _, events in messages:
            for msg_id, data in events:
                # Advance the cursor even for entries that are skipped below,
                # so they are not read again on the next poll.
                last_id = msg_id

                # NOTE(review): assumes the client returns str field names
                # (decoded responses) -- confirm against RedisClientSDK.
                raw_event = data.get("event")
                if raw_event is None:
                    continue

                # Keep the `try` minimal so that only parsing errors are
                # handled here, not exceptions raised at the `yield`.
                try:
                    event = event_adapter.validate_json(raw_event)
                except ValidationError:
                    _logger.warning(
                        "Skipping invalid event %s in stream %s",
                        msg_id,
                        stream_key,
                        exc_info=True,
                    )
                    continue

                event.event_id = msg_id
                yield event
180+
142181

143182
if TYPE_CHECKING:
144183
_: type[TaskInfoStore] = RedisTaskInfoStore

packages/celery-library/src/celery_library/task_manager.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from collections.abc import AsyncIterator
23
from dataclasses import dataclass
34
from typing import TYPE_CHECKING, Any
45
from uuid import uuid4
@@ -10,6 +11,7 @@
1011
from servicelib.celery.models import (
1112
TASK_DONE_STATES,
1213
Task,
14+
TaskEvent,
1315
TaskFilter,
1416
TaskID,
1517
TaskInfoStore,
@@ -183,6 +185,21 @@ async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> No
183185
report=report,
184186
)
185187

188+
async def publish_task_event(self, task_id: TaskID, event: TaskEvent) -> None:
    """Forward ``event`` for ``task_id`` to the underlying task info store."""
    store = self._task_info_store
    await store.publish_task_event(task_id, event)
190+
191+
async def consume_task_events(
    self,
    task_filter: TaskFilter,
    task_uuid: TaskUUID,
    last_id: str | None = None,
) -> AsyncIterator[TaskEvent]:
    """Yield the events of the task identified by ``task_filter``/``task_uuid``.

    Args:
        task_filter: filter used to build the task id.
        task_uuid: uuid of the task whose events are consumed.
        last_id: id of the last event already seen by the caller; ``None``
            (the default) is passed through unchanged to the store.

    Yields:
        The events produced by the task info store, in the order it emits them.
    """
    task_id = task_filter.create_task_id(task_uuid=task_uuid)
    async for event in self._task_info_store.consume_task_events(
        task_id=task_id, last_id=last_id
    ):
        yield event
202+
186203

187204
if TYPE_CHECKING:
188205
_: type[TaskManager] = CeleryTaskManager

packages/models-library/src/models_library/api_schemas_long_running_tasks/tasks.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import urllib.parse
22
from datetime import datetime
3-
from typing import Any
3+
from typing import Any, Self
44

55
from common_library.exclude import Unset
66
from pydantic import BaseModel, ConfigDict, model_validator
@@ -46,5 +46,17 @@ def try_populate_task_name_from_task_id(self) -> "TaskBase":
4646

4747
class TaskGet(TaskBase):
4848
status_href: str
49-
result_href: str
5049
abort_href: str
50+
result_href: str | None = (
51+
None # Path to get the result of the task in content-type application/json
52+
)
53+
result_stream_href: str | None = (
54+
None # Path to get the result of the task in content-type text/event-stream
55+
)
56+
57+
@model_validator(mode="after")
def _validate_result_hrefs(self) -> Self:
    """Require at least one non-empty result link on the model."""
    if self.result_href or self.result_stream_href:
        return self
    msg = "Either result_href or result_stream_href must be set"
    raise ValueError(msg)

packages/service-library/src/servicelib/aiohttp/rest_responses.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,16 @@ def create_data_response(data: Any, *, status: int = HTTP_200_OK) -> web.Respons
3636
return web.json_response(enveloped_payload, dumps=json_dumps, status=status)
3737

3838

39+
def create_event_stream_response(event_generator: Any) -> web.Response:
    """Build a 200 ``text/event-stream`` response.

    ``event_generator`` is called with no arguments and its return value is
    used as the response body.
    """
    sse_headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
    return web.Response(
        body=event_generator(),
        status=HTTP_200_OK,
        reason=get_code_description(HTTP_200_OK),
        content_type="text/event-stream",
        headers=sse_headers,
    )
47+
48+
3949
MAX_STATUS_MESSAGE_LENGTH: Final[int] = 100
4050

4151

packages/service-library/src/servicelib/celery/models.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import datetime
2+
from collections.abc import AsyncIterator
23
from enum import StrEnum
3-
from typing import Annotated, Any, Final, Protocol, Self, TypeAlias, TypeVar
4+
from typing import Annotated, Any, Final, Literal, Protocol, Self, TypeAlias, TypeVar
45
from uuid import UUID
56

67
from models_library.progress_bar import ProgressReport
7-
from pydantic import BaseModel, ConfigDict, StringConstraints, model_validator
8+
from pydantic import BaseModel, ConfigDict, Field, StringConstraints, model_validator
89
from pydantic.config import JsonDict
910

1011
ModelType = TypeVar("ModelType", bound=BaseModel)
@@ -126,6 +127,23 @@ class TaskMetadata(BaseModel):
126127
queue: TasksQueue = TasksQueue.DEFAULT
127128

128129

130+
class TaskDataEvent(BaseModel):
    """Task event that carries an arbitrary data payload."""

    # Discriminator value that selects this model within the TaskEvent union.
    type: Literal["data"] = "data"
    # Not set by the publisher; the Redis store fills it in with the stream
    # message id when the event is consumed.
    event_id: str | None = None
    # Payload of the event; no shape is enforced by this model.
    data: Any
134+
135+
136+
class TaskStatusEvent(BaseModel):
    """Task event that reports a status, either "done" or "error"."""

    # Discriminator value that selects this model within the TaskEvent union.
    type: Literal["status"] = "status"
    # Not set by the publisher; the Redis store fills it in with the stream
    # message id when the event is consumed.
    event_id: str | None = None
    # Reported status; restricted to the two literal values below.
    data: Literal["done", "error"]
140+
141+
142+
TaskEvent: TypeAlias = Annotated[
143+
TaskDataEvent | TaskStatusEvent, Field(discriminator="type")
144+
]
145+
146+
129147
class Task(BaseModel):
130148
uuid: TaskUUID
131149
metadata: TaskMetadata
@@ -185,9 +203,19 @@ async def list_tasks(self, task_filter: TaskFilter) -> list[Task]: ...
185203
async def remove_task(self, task_id: TaskID) -> None: ...
186204

187205
async def set_task_progress(
188-
self, task_id: TaskID, report: ProgressReport
206+
self,
207+
task_id: TaskID,
208+
report: ProgressReport,
189209
) -> None: ...
190210

211+
async def publish_task_event(self, task_id: TaskID, event: TaskEvent) -> None: ...
212+
213+
def consume_task_events(
    self,
    task_id: TaskID,
    last_id: str | None = None,
) -> AsyncIterator[TaskEvent]:
    """Return an async iterator over the events of ``task_id``.

    ``last_id`` is the id of the last event already seen by the caller;
    it is optional so that the protocol matches the signature of the
    Redis-backed store, which accepts ``None``.
    """
    ...
218+
191219

192220
class TaskStatus(BaseModel):
193221
task_uuid: TaskUUID

packages/service-library/src/servicelib/celery/task_manager.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from collections.abc import AsyncIterator
12
from typing import Any, Protocol, runtime_checkable
23

34
from models_library.progress_bar import ProgressReport
45

56
from ..celery.models import (
67
Task,
8+
TaskEvent,
79
TaskFilter,
810
TaskID,
911
TaskMetadata,
@@ -22,8 +24,6 @@ async def cancel_task(
2224
self, task_filter: TaskFilter, task_uuid: TaskUUID
2325
) -> None: ...
2426

25-
async def task_exists(self, task_id: TaskID) -> bool: ...
26-
2727
async def get_task_result(
2828
self, task_filter: TaskFilter, task_uuid: TaskUUID
2929
) -> Any: ...
@@ -37,3 +37,20 @@ async def list_tasks(self, task_filter: TaskFilter) -> list[Task]: ...
3737
async def set_task_progress(
3838
self, task_id: TaskID, report: ProgressReport
3939
) -> None: ...
40+
41+
async def task_exists(self, task_id: TaskID) -> bool: ...
42+
43+
# Events
44+
45+
async def publish_task_event(
46+
self,
47+
task_id: TaskID,
48+
event: TaskEvent,
49+
) -> None: ...
50+
51+
def consume_task_events(
52+
self,
53+
task_filter: TaskFilter,
54+
task_uuid: TaskUUID,
55+
last_id: str | None = None,
56+
) -> AsyncIterator[TaskEvent]: ...

services/docker-compose.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,6 +1158,7 @@ services:
11581158
WEBSERVER_ACTIVITY: "null"
11591159
WEBSERVER_ANNOUNCEMENTS: 0
11601160
WEBSERVER_CATALOG: "null"
1161+
WEBSERVER_CELERY: "null"
11611162
WEBSERVER_DB_LISTENER: 0
11621163
WEBSERVER_DIRECTOR_V2: "null"
11631164
WEBSERVER_EMAIL: "null"

services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from models_library.projects_nodes_io import StorageFileID
1818
from models_library.users import UserID
1919
from pydantic import TypeAdapter
20-
from servicelib.celery.models import TaskID
20+
from servicelib.celery.models import TaskDataEvent, TaskID, TaskStatusEvent
2121
from servicelib.logging_utils import log_context
2222
from servicelib.progress_bar import ProgressBarData
2323

@@ -139,30 +139,42 @@ async def search(
139139
user_id: UserID,
140140
project_id: ProjectID | None,
141141
name_pattern: str,
142-
) -> list[SearchResult]:
142+
) -> None:
143143
with log_context(
144144
_logger,
145145
logging.INFO,
146146
f"'{task_id}' search file {name_pattern=}",
147147
):
148-
dsm = get_dsm_provider(get_app_server(task.app).app).get(
148+
app_server = get_app_server(task.app)
149+
dsm = get_dsm_provider(app_server.app).get(
149150
SimcoreS3DataManager.get_location_id()
150151
)
151152

152153
assert isinstance(dsm, SimcoreS3DataManager) # nosec
153154

154-
return [
155-
SearchResult(
156-
name=item.file_name,
157-
project_id=item.project_id,
158-
created_at=item.created_at,
159-
modified_at=item.last_modified,
160-
is_directory=item.is_directory,
161-
)
162-
async for page in dsm.search(
163-
user_id=user_id,
164-
project_id=project_id,
165-
name_pattern=name_pattern,
155+
async for page in dsm.search(
156+
user_id=user_id,
157+
project_id=project_id,
158+
name_pattern=name_pattern,
159+
):
160+
data = [
161+
SearchResult(
162+
name=item.file_name,
163+
project_id=item.project_id,
164+
created_at=item.created_at,
165+
modified_at=item.last_modified,
166+
is_directory=item.is_directory,
167+
)
168+
for item in page
169+
]
170+
171+
await app_server.task_manager.publish_task_event(
172+
task_id,
173+
TaskDataEvent(
174+
data=TypeAdapter(list[SearchResult]).validate_python(data)
175+
),
166176
)
167-
for item in page
168-
]
177+
178+
await app_server.task_manager.publish_task_event(
179+
task_id, TaskStatusEvent(data="done")
180+
)

services/web/server/requirements/ci.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
--requirement _tools.txt
1515

1616
# installs this repo's packages
17+
simcore-celery-library @ ../../../packages/celery-library
1718
simcore-common-library @ ../../../packages/common-library
1819
simcore-models-library @ ../../../packages/models-library
1920
simcore-notifications-library @ ../../../packages/notifications-library/

services/web/server/requirements/dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
--requirement _tools.txt
1313

1414
# installs this repo's packages
15+
--editable ../../../packages/celery-library/
1516
--editable ../../../packages/common-library/
1617
--editable ../../../packages/models-library/
1718
--editable ../../../packages/notifications-library/

0 commit comments

Comments
 (0)