Skip to content

Commit 10d5952

Browse files
refactor: yield task event IDs alongside task events and extract a reusable SSEEvent model for server-sent-event serialization
1 parent 67cce39 commit 10d5952

File tree

5 files changed

+45
-20
lines changed

5 files changed

+45
-20
lines changed

packages/celery-library/src/celery_library/backends/redis.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from servicelib.celery.models import (
1010
Task,
1111
TaskEvent,
12+
TaskEventID,
1213
TaskFilter,
1314
TaskID,
1415
TaskInfoStore,
@@ -161,7 +162,7 @@ async def publish_task_event(self, task_id: TaskID, event: TaskEvent) -> None:
161162

162163
async def consume_task_events(
163164
self, task_id: TaskID, last_id: str | None = None
164-
) -> AsyncIterator[TaskEvent]:
165+
) -> AsyncIterator[tuple[TaskEventID, TaskEvent]]:
165166
stream_key = _build_stream_key(task_id)
166167
while True:
167168
messages = await self._redis_client_sdk.redis.xread(
@@ -183,8 +184,7 @@ async def consume_task_events(
183184
event: TaskEvent = TypeAdapter(TaskEvent).validate_json(
184185
raw_event
185186
)
186-
event.event_id = msg_id
187-
yield event
187+
yield msg_id, event
188188
except ValidationError:
189189
continue
190190

packages/service-library/src/servicelib/celery/models.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,16 +127,15 @@ class TaskMetadata(BaseModel):
127127
queue: TasksQueue = TasksQueue.DEFAULT
128128

129129

130-
class BaseTaskEvent(BaseModel):
131-
event_id: str | None = None
130+
TaskEventID: TypeAlias = str
132131

133132

134-
class TaskDataEvent(BaseTaskEvent):
133+
class TaskDataEvent(BaseModel):
135134
type: Literal["data"] = "data"
136135
data: Any
137136

138137

139-
class TaskStatusEvent(BaseTaskEvent):
138+
class TaskStatusEvent(BaseModel):
140139
type: Literal["status"] = "status"
141140
data: Literal["done", "error"]
142141

@@ -217,7 +216,7 @@ def consume_task_events(
217216
self,
218217
task_id: TaskID,
219218
last_id: str | None = None,
220-
) -> AsyncIterator[TaskEvent]: ...
219+
) -> AsyncIterator[tuple[TaskEventID, TaskEvent]]: ...
221220

222221

223222
class TaskStatus(BaseModel):

packages/service-library/src/servicelib/celery/task_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ..celery.models import (
77
Task,
88
TaskEvent,
9+
TaskEventID,
910
TaskFilter,
1011
TaskID,
1112
TaskMetadata,
@@ -53,4 +54,4 @@ def consume_task_events(
5354
task_filter: TaskFilter,
5455
task_uuid: TaskUUID,
5556
last_id: str | None = None,
56-
) -> AsyncIterator[TaskEvent]: ...
57+
) -> AsyncIterator[tuple[TaskEventID, TaskEvent]]: ...
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import Annotated
2+
3+
from pydantic import BaseModel, BeforeValidator
4+
5+
6+
def _normalize_data(v: str | list[str]) -> list[str]:
7+
if isinstance(v, str):
8+
lines = v.splitlines()
9+
return lines if lines else [""]
10+
return v
11+
12+
13+
class SSEEvent(BaseModel):
    """A single server-sent event, serializable to the SSE wire format.

    Fields mirror the SSE protocol fields: optional ``id``, optional
    ``event`` name, the ``data`` payload (normalized to a list of lines
    via a before-validator), and an optional ``retry`` interval.
    """

    id: str | None = None
    event: str | None = None
    data: Annotated[str | list[str], BeforeValidator(_normalize_data)]
    retry: int | None = None

    def serialize(self) -> bytes:
        """Render the event as UTF-8 bytes.

        Emits one ``field: value`` line per populated field (one
        ``data:`` line per payload line) and terminates the event with
        a blank line, as required by the SSE stream format.
        """
        out: list[str] = []
        if self.id is not None:
            out.append(f"id: {self.id}")
        if self.event is not None:
            out.append(f"event: {self.event}")
        for line in self.data:
            out.append(f"data: {line}")
        if self.retry is not None:
            out.append(f"retry: {self.retry}")
        return ("\n".join(out) + "\n\n").encode("utf-8")

services/web/server/src/simcore_service_webserver/storage/_rest.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
create_data_response,
5353
create_event_stream_response,
5454
)
55-
from servicelib.celery.models import TaskEvent, TaskFilter
55+
from servicelib.celery.models import TaskFilter
5656
from servicelib.common_headers import X_FORWARDED_PROTO
5757
from servicelib.rabbitmq.rpc_interfaces.storage.paths import (
5858
compute_path_size as remote_compute_path_size,
@@ -66,6 +66,7 @@
6666
)
6767
from servicelib.request_keys import RQT_USERID_KEY
6868
from servicelib.rest_responses import unwrap_envelope
69+
from servicelib.sse import SSEEvent
6970
from simcore_service_webserver.celery import get_task_manager
7071
from simcore_service_webserver.storage._rest_schemas import StreamHeaders
7172
from yarl import URL
@@ -583,14 +584,6 @@ class _PathParams(BaseModel):
583584
)
584585

585586

586-
def _format_sse(event: TaskEvent, event_name) -> bytes:
587-
sse = ""
588-
if event_name:
589-
sse += f"event: {event_name}\n"
590-
sse += f"data: {event.model_dump_json()}\n\n"
591-
return sse.encode("utf-8")
592-
593-
594587
@routes.get(
595588
_storage_locations_prefix + "/{location_id}/search/{job_id}/stream",
596589
name="stream_search",
@@ -614,12 +607,14 @@ class _PathParams(BaseModel):
614607
)
615608

616609
async def event_generator():
617-
async for event in task_manager.consume_task_events(
610+
async for event_id, event in task_manager.consume_task_events(
618611
task_filter=TaskFilter.model_validate(task_filter.model_dump()),
619612
task_uuid=path_params.job_id,
620613
last_id=header_params.last_event_id,
621614
):
622-
yield _format_sse(event, event_name=event.type)
615+
yield SSEEvent(
616+
id=event_id, event=event.type, data=event.model_dump_json()
617+
).serialize()
623618
if event.type == "status" and getattr(event, "data", None) in (
624619
"done",
625620
"error",

0 commit comments

Comments
 (0)