Skip to content

Commit 48c8d5c

Browse files
committed
fixes
1 parent 8a9ab47 commit 48c8d5c

File tree

14 files changed

+141
-107
lines changed

14 files changed

+141
-107
lines changed

backend/app/db/repositories/sse_repository.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ def __init__(self, database: Database) -> None:
1515
self.mapper = SSEMapper()
1616

1717
async def get_execution_status(self, execution_id: str) -> SSEExecutionStatusDomain | None:
18-
doc = await self.executions_collection.find_one(
19-
{"execution_id": execution_id}, {"status": 1, "_id": 0}
20-
)
18+
doc = await self.executions_collection.find_one({"execution_id": execution_id}, {"status": 1, "_id": 0})
2119
if not doc:
2220
return None
2321
return SSEExecutionStatusDomain(

backend/app/schemas_pydantic/execution.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
from pydantic import BaseModel, ConfigDict, Field, model_validator
88

9-
from app.domain.enums.storage import ExecutionErrorType
109
from app.domain.enums.events import EventType
1110
from app.domain.enums.execution import ExecutionStatus
11+
from app.domain.enums.storage import ExecutionErrorType
1212
from app.settings import get_settings
1313

1414

@@ -80,8 +80,9 @@ def validate_runtime_supported(self) -> "ExecutionRequest": # noqa: D401
8080
if not (lang_info := runtimes.get(self.lang)):
8181
raise ValueError(f"Language '{self.lang}' not supported. Supported: {list(runtimes.keys())}")
8282
if self.lang_version not in lang_info.versions:
83-
raise ValueError(f"Version '{self.lang_version}' not supported for {self.lang}. "
84-
f"Supported: {lang_info.versions}")
83+
raise ValueError(
84+
f"Version '{self.lang_version}' not supported for {self.lang}. Supported: {lang_info.versions}"
85+
)
8586
return self
8687

8788

backend/app/schemas_pydantic/sse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
from app.schemas_pydantic.execution import ExecutionResult, ResourceUsage
1010

1111
# Control event types sent by SSE (not from Kafka)
12-
SSEControlEventType = Literal['connected', 'heartbeat', 'shutdown', 'status', 'error']
12+
SSEControlEventType = Literal["connected", "heartbeat", "shutdown", "status", "error"]
1313

1414
# Type variable for generic Redis message parsing
15-
T = TypeVar('T', bound=BaseModel)
15+
T = TypeVar("T", bound=BaseModel)
1616

1717

1818
class SSEExecutionEventData(BaseModel):
@@ -69,7 +69,7 @@ class RedisSSEMessage(BaseModel):
6969

7070

7171
# Control event types for notification SSE stream
72-
SSENotificationControlEventType = Literal['connected', 'heartbeat', 'notification']
72+
SSENotificationControlEventType = Literal["connected", "heartbeat", "notification"]
7373

7474

7575
class SSENotificationEventData(BaseModel):

backend/app/services/notification_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
ExecutionTimeoutEvent,
3535
)
3636
from app.infrastructure.kafka.mappings import get_topic_for_event
37+
from app.schemas_pydantic.sse import RedisNotificationMessage
3738
from app.services.event_bus import EventBusManager
3839
from app.services.kafka_event_service import KafkaEventService
39-
from app.schemas_pydantic.sse import RedisNotificationMessage
4040
from app.services.sse.redis_bus import SSERedisBus
4141
from app.settings import Settings, get_settings
4242

backend/app/services/sse/redis_bus.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ def __init__(self, pubsub: redis.client.PubSub, channel: str) -> None:
1818
self._pubsub = pubsub
1919
self._channel = channel
2020

21-
async def get(self, model: Type[T], timeout: float = 0.5) -> T | None:
21+
async def get(self, model: Type[T]) -> T | None:
2222
"""Get next typed message from the subscription."""
23-
msg = await self._pubsub.get_message(ignore_subscribe_messages=True, timeout=timeout)
23+
msg = await self._pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5)
2424
if not msg or msg.get("type") != "message":
2525
return None
2626
try:
@@ -39,7 +39,7 @@ class SSERedisBus:
3939
"""Redis-backed pub/sub bus for SSE event fan-out across workers."""
4040

4141
def __init__(
42-
self, redis_client: redis.Redis, exec_prefix: str = "sse:exec:", notif_prefix: str = "sse:notif:"
42+
self, redis_client: redis.Redis, exec_prefix: str = "sse:exec:", notif_prefix: str = "sse:notif:"
4343
) -> None:
4444
self._redis = redis_client
4545
self._exec_prefix = exec_prefix

backend/app/services/sse/sse_service.py

Lines changed: 82 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -51,24 +51,28 @@ async def create_execution_stream(self, execution_id: str, user_id: str) -> Asyn
5151

5252
shutdown_event = await self.shutdown_manager.register_connection(execution_id, connection_id)
5353
if shutdown_event is None:
54-
yield self._format_sse_event(SSEExecutionEventData(
55-
event_type="error",
56-
execution_id=execution_id,
57-
timestamp=datetime.now(timezone.utc).isoformat(),
58-
error="Server is shutting down",
59-
))
54+
yield self._format_sse_event(
55+
SSEExecutionEventData(
56+
event_type="error",
57+
execution_id=execution_id,
58+
timestamp=datetime.now(timezone.utc).isoformat(),
59+
error="Server is shutting down",
60+
)
61+
)
6062
return
6163

6264
subscription = None
6365
try:
6466
# Start opening subscription concurrently, then yield handshake
6567
sub_task = asyncio.create_task(self.sse_bus.open_subscription(execution_id))
66-
yield self._format_sse_event(SSEExecutionEventData(
67-
event_type="connected",
68-
execution_id=execution_id,
69-
timestamp=datetime.now(timezone.utc).isoformat(),
70-
connection_id=connection_id,
71-
))
68+
yield self._format_sse_event(
69+
SSEExecutionEventData(
70+
event_type="connected",
71+
execution_id=execution_id,
72+
timestamp=datetime.now(timezone.utc).isoformat(),
73+
connection_id=connection_id,
74+
)
75+
)
7276

7377
# Complete Redis subscription after handshake
7478
logger.info(f"Opening Redis subscription for execution {execution_id}")
@@ -77,12 +81,14 @@ async def create_execution_stream(self, execution_id: str, user_id: str) -> Asyn
7781

7882
initial_status = await self.repository.get_execution_status(execution_id)
7983
if initial_status:
80-
yield self._format_sse_event(SSEExecutionEventData(
81-
event_type="status",
82-
execution_id=initial_status.execution_id,
83-
timestamp=initial_status.timestamp,
84-
status=initial_status.status,
85-
))
84+
yield self._format_sse_event(
85+
SSEExecutionEventData(
86+
event_type="status",
87+
execution_id=initial_status.execution_id,
88+
timestamp=initial_status.timestamp,
89+
status=initial_status.status,
90+
)
91+
)
8692
self.metrics.record_sse_message_sent("executions", "status")
8793

8894
async for event_data in self._stream_events_redis(
@@ -109,26 +115,30 @@ async def _stream_events_redis(
109115
last_heartbeat = datetime.now(timezone.utc)
110116
while True:
111117
if shutdown_event.is_set():
112-
yield self._format_sse_event(SSEExecutionEventData(
113-
event_type="shutdown",
114-
execution_id=execution_id,
115-
timestamp=datetime.now(timezone.utc).isoformat(),
116-
message="Server is shutting down",
117-
grace_period=30,
118-
))
118+
yield self._format_sse_event(
119+
SSEExecutionEventData(
120+
event_type="shutdown",
121+
execution_id=execution_id,
122+
timestamp=datetime.now(timezone.utc).isoformat(),
123+
message="Server is shutting down",
124+
grace_period=30,
125+
)
126+
)
119127
break
120128

121129
now = datetime.now(timezone.utc)
122130
if include_heartbeat and (now - last_heartbeat).total_seconds() >= self.heartbeat_interval:
123-
yield self._format_sse_event(SSEExecutionEventData(
124-
event_type="heartbeat",
125-
execution_id=execution_id,
126-
timestamp=now.isoformat(),
127-
message="SSE connection active",
128-
))
131+
yield self._format_sse_event(
132+
SSEExecutionEventData(
133+
event_type="heartbeat",
134+
execution_id=execution_id,
135+
timestamp=now.isoformat(),
136+
message="SSE connection active",
137+
)
138+
)
129139
last_heartbeat = now
130140

131-
msg: RedisSSEMessage | None = await subscription.get(RedisSSEMessage, timeout=0.5)
141+
msg: RedisSSEMessage | None = await subscription.get(RedisSSEMessage)
132142
if not msg:
133143
continue
134144

@@ -145,36 +155,38 @@ async def _stream_events_redis(
145155
# Ignore malformed messages
146156
continue
147157

148-
async def _build_sse_event_from_redis(
149-
self, execution_id: str, msg: RedisSSEMessage
150-
) -> SSEExecutionEventData:
158+
async def _build_sse_event_from_redis(self, execution_id: str, msg: RedisSSEMessage) -> SSEExecutionEventData:
151159
"""Build typed SSE event from Redis message."""
152160
result: ExecutionResult | None = None
153161
if msg.event_type == EventType.RESULT_STORED:
154162
exec_domain = await self.repository.get_execution(execution_id)
155163
if exec_domain:
156164
result = ExecutionResult.model_validate(exec_domain)
157165

158-
return SSEExecutionEventData.model_validate({
159-
**msg.data,
160-
"event_type": msg.event_type,
161-
"execution_id": execution_id,
162-
"type": msg.event_type,
163-
"result": result,
164-
})
166+
return SSEExecutionEventData.model_validate(
167+
{
168+
**msg.data,
169+
"event_type": msg.event_type,
170+
"execution_id": execution_id,
171+
"type": msg.event_type,
172+
"result": result,
173+
}
174+
)
165175

166176
async def create_notification_stream(self, user_id: str) -> AsyncGenerator[Dict[str, Any], None]:
167177
subscription = None
168178

169179
try:
170180
# Start opening subscription concurrently, then yield handshake
171181
sub_task = asyncio.create_task(self.sse_bus.open_notification_subscription(user_id))
172-
yield self._format_notification_event(SSENotificationEventData(
173-
event_type="connected",
174-
user_id=user_id,
175-
timestamp=datetime.now(timezone.utc).isoformat(),
176-
message="Connected to notification stream",
177-
))
182+
yield self._format_notification_event(
183+
SSENotificationEventData(
184+
event_type="connected",
185+
user_id=user_id,
186+
timestamp=datetime.now(timezone.utc).isoformat(),
187+
message="Connected to notification stream",
188+
)
189+
)
178190

179191
# Complete Redis subscription after handshake
180192
subscription = await sub_task
@@ -184,28 +196,32 @@ async def create_notification_stream(self, user_id: str) -> AsyncGenerator[Dict[
184196
# Heartbeat
185197
now = datetime.now(timezone.utc)
186198
if (now - last_heartbeat).total_seconds() >= self.heartbeat_interval:
187-
yield self._format_notification_event(SSENotificationEventData(
188-
event_type="heartbeat",
189-
user_id=user_id,
190-
timestamp=now.isoformat(),
191-
message="Notification stream active",
192-
))
199+
yield self._format_notification_event(
200+
SSENotificationEventData(
201+
event_type="heartbeat",
202+
user_id=user_id,
203+
timestamp=now.isoformat(),
204+
message="Notification stream active",
205+
)
206+
)
193207
last_heartbeat = now
194208

195209
# Forward notification messages as SSE data
196-
redis_msg = await subscription.get(RedisNotificationMessage, timeout=0.5)
210+
redis_msg = await subscription.get(RedisNotificationMessage)
197211
if redis_msg:
198-
yield self._format_notification_event(SSENotificationEventData(
199-
event_type="notification",
200-
notification_id=redis_msg.notification_id,
201-
severity=redis_msg.severity,
202-
status=redis_msg.status,
203-
tags=redis_msg.tags,
204-
subject=redis_msg.subject,
205-
body=redis_msg.body,
206-
action_url=redis_msg.action_url,
207-
created_at=redis_msg.created_at,
208-
))
212+
yield self._format_notification_event(
213+
SSENotificationEventData(
214+
event_type="notification",
215+
notification_id=redis_msg.notification_id,
216+
severity=redis_msg.severity,
217+
status=redis_msg.status,
218+
tags=redis_msg.tags,
219+
subject=redis_msg.subject,
220+
body=redis_msg.body,
221+
action_url=redis_msg.action_url,
222+
created_at=redis_msg.created_at,
223+
)
224+
)
209225
finally:
210226
try:
211227
if subscription is not None:

backend/tests/integration/notifications/test_notification_sse.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
from app.domain.enums.notification import NotificationChannel, NotificationSeverity
7+
from app.schemas_pydantic.sse import RedisNotificationMessage
78
from app.services.notification_service import NotificationService
89
from app.services.sse.redis_bus import SSERedisBus
910
from tests.helpers.eventually import eventually
@@ -34,13 +35,13 @@ async def test_in_app_notification_published_to_sse(scope) -> None: # type: ign
3435
)
3536

3637
# Receive published SSE payload
37-
async def _recv():
38-
m = await sub.get(timeout=0.5)
38+
async def _recv() -> RedisNotificationMessage:
39+
m = await sub.get(RedisNotificationMessage)
3940
assert m is not None
4041
return m
4142

4243
msg = await eventually(_recv, timeout=5.0, interval=0.1)
4344
# Basic shape assertions
44-
assert msg.get("subject") == "Hello"
45-
assert msg.get("body") == "World"
46-
assert "notification_id" in msg
45+
assert msg.subject == "Hello"
46+
assert msg.body == "World"
47+
assert msg.notification_id

backend/tests/integration/services/sse/test_partitioned_event_router_integration.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from app.events.schema.schema_registry import SchemaRegistryManager
99
from tests.helpers import make_execution_requested_event
1010
from app.infrastructure.kafka.events.pod import PodCreatedEvent
11+
from app.schemas_pydantic.sse import RedisSSEMessage
1112
from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge
1213
from app.services.sse.redis_bus import SSERedisBus
1314
from app.settings import Settings
@@ -42,12 +43,12 @@ async def test_router_bridges_to_redis(redis_client) -> None: # type: ignore[va
4243
await handler(ev)
4344

4445
async def _recv():
45-
m = await subscription.get(timeout=0.2)
46+
m = await subscription.get(RedisSSEMessage)
4647
assert m is not None
4748
return m
4849

4950
msg = await eventually(_recv, timeout=2.0, interval=0.05)
50-
assert msg.get("event_type") == str(ev.event_type)
51+
assert str(msg.event_type) == str(ev.event_type)
5152

5253

5354
@pytest.mark.asyncio

backend/tests/integration/test_sse_routes.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import pytest
77
from httpx import AsyncClient
88

9-
from app.schemas_pydantic.sse import SSEHealthResponse
9+
from app.domain.enums.notification import NotificationSeverity, NotificationStatus
10+
from app.schemas_pydantic.sse import RedisNotificationMessage, SSEHealthResponse
1011
from app.infrastructure.kafka.events.pod import PodCreatedEvent
1112
from app.infrastructure.kafka.events.metadata import EventMetadata
1213
from app.services.sse.redis_bus import SSERedisBus
@@ -73,7 +74,17 @@ async def _connected() -> None:
7374
await eventually(_connected, timeout=2.0, interval=0.05)
7475

7576
# Publish a notification
76-
await bus.publish_notification(user_id, {"subject": "Hello", "body": "World", "event_type": "notification"})
77+
notification = RedisNotificationMessage(
78+
notification_id=f"notif-{uuid4().hex[:8]}",
79+
severity=NotificationSeverity.MEDIUM,
80+
status=NotificationStatus.PENDING,
81+
tags=[],
82+
subject="Hello",
83+
body="World",
84+
action_url="",
85+
created_at="2024-01-01T00:00:00Z",
86+
)
87+
await bus.publish_notification(user_id, notification)
7788

7889
# Wait for collection to complete
7990
try:

0 commit comments

Comments
 (0)