Skip to content

Commit aa22029

Browse files
committed
sse tests update
1 parent 0a474f2 commit aa22029

File tree

1 file changed

+219
-119
lines changed

1 file changed

+219
-119
lines changed
Lines changed: 219 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,169 +1,269 @@
1-
import asyncio
1+
"""SSE integration tests - precise verification of Redis pub/sub and stream behavior."""
2+
23
import json
34
from contextlib import aclosing
45
from datetime import datetime, timezone
56
from typing import Any
67
from uuid import uuid4
78

89
import pytest
10+
from app.domain.enums.events import EventType
911
from app.domain.enums.notification import NotificationSeverity, NotificationStatus
12+
from app.domain.enums.sse import SSEControlEvent, SSENotificationEvent
1013
from app.infrastructure.kafka.events.metadata import AvroEventMetadata
1114
from app.infrastructure.kafka.events.pod import PodCreatedEvent
12-
from app.schemas_pydantic.sse import RedisNotificationMessage, SSEHealthResponse
15+
from app.schemas_pydantic.sse import (
16+
RedisNotificationMessage,
17+
RedisSSEMessage,
18+
SSEHealthResponse,
19+
)
1320
from app.services.sse.redis_bus import SSERedisBus
1421
from app.services.sse.sse_service import SSEService
1522
from dishka import AsyncContainer
1623
from httpx import AsyncClient
1724

1825

1926
@pytest.mark.integration
20-
class TestSSERoutes:
21-
"""SSE routes tested with deterministic event-driven reads (no polling)."""
27+
class TestSSEAuth:
28+
"""SSE endpoints require authentication."""
29+
30+
@pytest.mark.asyncio
31+
async def test_notification_stream_requires_auth(self, client: AsyncClient) -> None:
32+
assert (await client.get("/api/v1/events/notifications/stream")).status_code == 401
33+
34+
@pytest.mark.asyncio
35+
async def test_execution_stream_requires_auth(self, client: AsyncClient) -> None:
36+
assert (await client.get(f"/api/v1/events/executions/{uuid4()}")).status_code == 401
2237

2338
@pytest.mark.asyncio
24-
async def test_sse_requires_authentication(self, client: AsyncClient) -> None:
25-
r = await client.get("/api/v1/events/notifications/stream")
26-
assert r.status_code == 401
27-
detail = r.json().get("detail", "").lower()
28-
assert any(x in detail for x in ("not authenticated", "unauthorized", "login"))
39+
async def test_health_requires_auth(self, client: AsyncClient) -> None:
40+
assert (await client.get("/api/v1/events/health")).status_code == 401
2941

30-
exec_id = str(uuid4())
31-
r = await client.get(f"/api/v1/events/executions/{exec_id}")
32-
assert r.status_code == 401
3342

34-
r = await client.get("/api/v1/events/health")
35-
assert r.status_code == 401
43+
@pytest.mark.integration
44+
class TestSSEHealth:
45+
"""SSE health endpoint."""
3646

3747
@pytest.mark.asyncio
38-
async def test_sse_health_status(self, test_user: AsyncClient) -> None:
48+
async def test_returns_valid_health_status(self, test_user: AsyncClient) -> None:
3949
r = await test_user.get("/api/v1/events/health")
4050
assert r.status_code == 200
41-
health = SSEHealthResponse(**r.json())
51+
health = SSEHealthResponse.model_validate(r.json())
4252
assert health.status in ("healthy", "degraded", "unhealthy", "draining")
43-
assert isinstance(health.active_connections, int) and health.active_connections >= 0
53+
assert health.active_connections >= 0
54+
55+
56+
@pytest.mark.integration
57+
class TestRedisPubSubExecution:
58+
"""Redis pub/sub for execution events - verifies message structure and delivery."""
4459

4560
@pytest.mark.asyncio
46-
async def test_notification_stream_service(self, scope: AsyncContainer, test_user: AsyncClient) -> None:
47-
"""Test SSE notification stream directly through service (httpx doesn't support SSE streaming)."""
48-
sse_service: SSEService = await scope.get(SSEService)
61+
async def test_publish_wraps_event_in_redis_message(self, scope: AsyncContainer) -> None:
62+
"""publish_event wraps BaseEvent in RedisSSEMessage with correct structure."""
4963
bus: SSERedisBus = await scope.get(SSERedisBus)
50-
user_id = f"user-{uuid4().hex[:8]}"
64+
exec_id = f"exec-{uuid4().hex[:8]}"
5165

52-
events: list[dict[str, Any]] = []
53-
notification_received = False
66+
subscription = await bus.open_subscription(exec_id)
5467

55-
async with aclosing(sse_service.create_notification_stream(user_id)) as stream:
56-
try:
57-
async with asyncio.timeout(3.0):
58-
async for event in stream:
59-
if "data" not in event:
60-
continue
61-
data = json.loads(event["data"])
62-
events.append(data)
63-
64-
# Wait for "subscribed" event - Redis subscription is now ready
65-
if data.get("event_type") == "subscribed":
66-
notification = RedisNotificationMessage(
67-
notification_id=f"notif-{uuid4().hex[:8]}",
68-
severity=NotificationSeverity.MEDIUM,
69-
status=NotificationStatus.PENDING,
70-
tags=[],
71-
subject="Hello",
72-
body="World",
73-
action_url="",
74-
created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
75-
)
76-
await bus.publish_notification(user_id, notification)
77-
78-
# Stop when we receive the notification
79-
if data.get("event_type") == "notification" and data.get("subject") == "Hello":
80-
notification_received = True
81-
break
82-
except TimeoutError:
83-
pass
84-
85-
assert notification_received, f"Expected notification, got events: {events}"
68+
event = PodCreatedEvent(
69+
execution_id=exec_id,
70+
pod_name="test-pod",
71+
namespace="test-ns",
72+
metadata=AvroEventMetadata(service_name="test", service_version="1.0"),
73+
)
74+
await bus.publish_event(exec_id, event)
75+
76+
# Verify the wrapper structure
77+
received: RedisSSEMessage | None = await subscription.get(RedisSSEMessage)
78+
await subscription.close()
79+
80+
assert received is not None
81+
assert received.event_type == EventType.POD_CREATED
82+
assert received.execution_id == exec_id
83+
assert received.data["pod_name"] == "test-pod"
84+
assert received.data["namespace"] == "test-ns"
8685

8786
@pytest.mark.asyncio
88-
async def test_execution_event_stream_service(self, scope: AsyncContainer, test_user: AsyncClient) -> None:
89-
"""Test SSE execution stream directly through service (httpx doesn't support SSE streaming)."""
90-
exec_id = f"e-{uuid4().hex[:8]}"
91-
user_id = f"user-{uuid4().hex[:8]}"
87+
async def test_channel_isolation(self, scope: AsyncContainer) -> None:
88+
"""Different execution_ids use isolated channels."""
89+
bus: SSERedisBus = await scope.get(SSERedisBus)
90+
exec_a, exec_b = f"exec-a-{uuid4().hex[:6]}", f"exec-b-{uuid4().hex[:6]}"
9291

93-
sse_service: SSEService = await scope.get(SSEService)
92+
sub_a = await bus.open_subscription(exec_a)
93+
sub_b = await bus.open_subscription(exec_b)
94+
95+
event = PodCreatedEvent(
96+
execution_id=exec_a,
97+
pod_name="pod-a",
98+
namespace="default",
99+
metadata=AvroEventMetadata(service_name="test", service_version="1"),
100+
)
101+
await bus.publish_event(exec_a, event)
102+
103+
received_a = await sub_a.get(RedisSSEMessage)
104+
received_b = await sub_b.get(RedisSSEMessage)
105+
106+
await sub_a.close()
107+
await sub_b.close()
108+
109+
assert received_a is not None
110+
assert received_a.data["pod_name"] == "pod-a"
111+
assert received_b is None # B should not receive A's message
112+
113+
114+
@pytest.mark.integration
115+
class TestRedisPubSubNotification:
116+
"""Redis pub/sub for notifications - verifies message structure and delivery."""
117+
118+
@pytest.mark.asyncio
119+
async def test_publish_sends_notification_directly(self, scope: AsyncContainer) -> None:
120+
"""publish_notification sends RedisNotificationMessage JSON directly."""
94121
bus: SSERedisBus = await scope.get(SSERedisBus)
122+
user_id = f"user-{uuid4().hex[:8]}"
95123

96-
events: list[dict[str, Any]] = []
97-
pod_event_received = False
124+
subscription = await bus.open_notification_subscription(user_id)
98125

99-
async with aclosing(sse_service.create_execution_stream(exec_id, user_id)) as stream:
100-
try:
101-
async with asyncio.timeout(3.0):
102-
async for event in stream:
103-
if "data" not in event:
104-
continue
105-
data = json.loads(event["data"])
106-
events.append(data)
107-
108-
# Wait for "subscribed" event - Redis subscription is now ready
109-
if data.get("event_type") == "subscribed":
110-
ev = PodCreatedEvent(
111-
execution_id=exec_id,
112-
pod_name=f"executor-{exec_id}",
113-
namespace="default",
114-
metadata=AvroEventMetadata(service_name="tests", service_version="1"),
115-
)
116-
await bus.publish_event(exec_id, ev)
117-
118-
# Stop when we receive the pod event
119-
if data.get("event_type") == "pod_created":
120-
pod_event_received = True
121-
break
122-
except TimeoutError:
123-
pass
124-
125-
assert pod_event_received, f"Expected pod_created event, got events: {events}"
126+
notification = RedisNotificationMessage(
127+
notification_id="notif-123",
128+
severity=NotificationSeverity.HIGH,
129+
status=NotificationStatus.PENDING,
130+
tags=["urgent", "system"],
131+
subject="Test Alert",
132+
body="This is a test notification",
133+
action_url="https://example.com/action",
134+
created_at=datetime(2024, 6, 15, 12, 0, 0, tzinfo=timezone.utc),
135+
)
136+
await bus.publish_notification(user_id, notification)
137+
138+
received: RedisNotificationMessage | None = await subscription.get(RedisNotificationMessage)
139+
await subscription.close()
140+
141+
assert received is not None
142+
assert received.notification_id == "notif-123"
143+
assert received.severity == NotificationSeverity.HIGH
144+
assert received.status == NotificationStatus.PENDING
145+
assert received.tags == ["urgent", "system"]
146+
assert received.subject == "Test Alert"
147+
assert received.body == "This is a test notification"
148+
assert received.action_url == "https://example.com/action"
126149

127150
@pytest.mark.asyncio
128-
async def test_sse_route_requires_auth(self, client: AsyncClient) -> None:
129-
"""Test that SSE routes require authentication (HTTP-level test only)."""
130-
r = await client.get("/api/v1/events/notifications/stream")
131-
assert r.status_code == 401
151+
async def test_user_channel_isolation(self, scope: AsyncContainer) -> None:
152+
"""Different user_ids use isolated channels."""
153+
bus: SSERedisBus = await scope.get(SSERedisBus)
154+
user_a, user_b = f"user-a-{uuid4().hex[:6]}", f"user-b-{uuid4().hex[:6]}"
155+
156+
sub_a = await bus.open_notification_subscription(user_a)
157+
sub_b = await bus.open_notification_subscription(user_b)
158+
159+
notification = RedisNotificationMessage(
160+
notification_id="for-user-a",
161+
severity=NotificationSeverity.LOW,
162+
status=NotificationStatus.PENDING,
163+
tags=[],
164+
subject="Private",
165+
body="For user A only",
166+
action_url="",
167+
created_at=datetime.now(timezone.utc),
168+
)
169+
await bus.publish_notification(user_a, notification)
170+
171+
received_a = await sub_a.get(RedisNotificationMessage)
172+
received_b = await sub_b.get(RedisNotificationMessage)
132173

133-
exec_id = str(uuid4())
134-
r = await client.get(f"/api/v1/events/executions/{exec_id}")
135-
assert r.status_code == 401
174+
await sub_a.close()
175+
await sub_b.close()
176+
177+
assert received_a is not None
178+
assert received_a.notification_id == "for-user-a"
179+
assert received_b is None # B should not receive A's notification
180+
181+
182+
@pytest.mark.integration
183+
class TestSSEStreamEvents:
184+
"""SSE stream control events - verifies event structure without pub/sub."""
136185

137186
@pytest.mark.asyncio
138-
async def test_sse_endpoint_returns_correct_headers(self, test_user: AsyncClient) -> None:
139-
"""Test SSE health endpoint works (streaming tests done via service)."""
140-
r = await test_user.get("/api/v1/events/health")
141-
assert r.status_code == 200
142-
assert isinstance(r.json(), dict)
187+
async def test_notification_stream_yields_connected_then_subscribed(self, scope: AsyncContainer) -> None:
188+
"""Notification stream yields CONNECTED and SUBSCRIBED with correct fields."""
189+
sse_service: SSEService = await scope.get(SSEService)
190+
user_id = f"user-{uuid4().hex[:8]}"
191+
192+
events: list[dict[str, Any]] = []
193+
async with aclosing(sse_service.create_notification_stream(user_id)) as stream:
194+
async for raw in stream:
195+
if "data" in raw:
196+
events.append(json.loads(raw["data"]))
197+
if len(events) >= 2:
198+
break
199+
200+
# Verify CONNECTED event structure
201+
connected = events[0]
202+
assert connected["event_type"] == SSENotificationEvent.CONNECTED
203+
assert connected["user_id"] == user_id
204+
assert "timestamp" in connected
205+
assert connected["message"] == "Connected to notification stream"
206+
207+
# Verify SUBSCRIBED event structure
208+
subscribed = events[1]
209+
assert subscribed["event_type"] == SSENotificationEvent.SUBSCRIBED
210+
assert subscribed["user_id"] == user_id
211+
assert "timestamp" in subscribed
212+
assert subscribed["message"] == "Redis subscription established"
143213

144214
@pytest.mark.asyncio
145-
async def test_multiple_concurrent_sse_service_connections(
146-
self, scope: AsyncContainer, test_user: AsyncClient
147-
) -> None:
148-
"""Test multiple concurrent SSE connections through the service."""
215+
async def test_execution_stream_yields_connected_then_subscribed(self, scope: AsyncContainer) -> None:
216+
"""Execution stream yields CONNECTED and SUBSCRIBED with correct fields."""
149217
sse_service: SSEService = await scope.get(SSEService)
218+
exec_id = f"exec-{uuid4().hex[:8]}"
219+
user_id = f"user-{uuid4().hex[:8]}"
150220

151-
async def create_and_verify_stream(user_id: str) -> bool:
152-
async with aclosing(sse_service.create_notification_stream(user_id)) as stream:
153-
async for event in stream:
154-
if "data" in event:
155-
data = json.loads(event["data"])
156-
if data.get("event_type") == "connected":
157-
return True
221+
events: list[dict[str, Any]] = []
222+
async with aclosing(sse_service.create_execution_stream(exec_id, user_id)) as stream:
223+
async for raw in stream:
224+
if "data" in raw:
225+
events.append(json.loads(raw["data"]))
226+
if len(events) >= 2:
158227
break
159-
return False
160228

161-
results = await asyncio.gather(
162-
create_and_verify_stream("user1"),
163-
create_and_verify_stream("user2"),
164-
create_and_verify_stream("user3"),
165-
return_exceptions=True,
229+
# Verify CONNECTED event structure
230+
connected = events[0]
231+
assert connected["event_type"] == SSEControlEvent.CONNECTED
232+
assert connected["execution_id"] == exec_id
233+
assert "connection_id" in connected
234+
assert connected["connection_id"].startswith(f"sse_{exec_id}_")
235+
assert "timestamp" in connected
236+
237+
# Verify SUBSCRIBED event structure
238+
subscribed = events[1]
239+
assert subscribed["event_type"] == SSEControlEvent.SUBSCRIBED
240+
assert subscribed["execution_id"] == exec_id
241+
assert "timestamp" in subscribed
242+
assert subscribed["message"] == "Redis subscription established"
243+
244+
@pytest.mark.asyncio
245+
async def test_concurrent_streams_get_unique_connection_ids(self, scope: AsyncContainer) -> None:
246+
"""Each stream connection gets a unique connection_id."""
247+
import asyncio
248+
249+
sse_service: SSEService = await scope.get(SSEService)
250+
exec_id = f"exec-{uuid4().hex[:8]}"
251+
252+
async def get_connection_id(user_id: str) -> str:
253+
async with aclosing(sse_service.create_execution_stream(exec_id, user_id)) as stream:
254+
async for raw in stream:
255+
if "data" in raw:
256+
data = json.loads(raw["data"])
257+
if data.get("event_type") == SSEControlEvent.CONNECTED:
258+
return data["connection_id"]
259+
return ""
260+
261+
conn_ids = await asyncio.gather(
262+
get_connection_id("user-1"),
263+
get_connection_id("user-2"),
264+
get_connection_id("user-3"),
166265
)
167266

168-
successful = sum(1 for r in results if r is True)
169-
assert successful >= 2
267+
# All connection IDs should be unique
268+
assert len(set(conn_ids)) == 3
269+
assert all(cid.startswith(f"sse_{exec_id}_") for cid in conn_ids)

0 commit comments

Comments
 (0)