|
1 | | -import asyncio |
| 1 | +"""SSE integration tests - precise verification of Redis pub/sub and stream behavior.""" |
| 2 | + |
2 | 3 | import json |
3 | 4 | from contextlib import aclosing |
4 | 5 | from datetime import datetime, timezone |
5 | 6 | from typing import Any |
6 | 7 | from uuid import uuid4 |
7 | 8 |
|
8 | 9 | import pytest |
| 10 | +from app.domain.enums.events import EventType |
9 | 11 | from app.domain.enums.notification import NotificationSeverity, NotificationStatus |
| 12 | +from app.domain.enums.sse import SSEControlEvent, SSENotificationEvent |
10 | 13 | from app.infrastructure.kafka.events.metadata import AvroEventMetadata |
11 | 14 | from app.infrastructure.kafka.events.pod import PodCreatedEvent |
12 | | -from app.schemas_pydantic.sse import RedisNotificationMessage, SSEHealthResponse |
| 15 | +from app.schemas_pydantic.sse import ( |
| 16 | + RedisNotificationMessage, |
| 17 | + RedisSSEMessage, |
| 18 | + SSEHealthResponse, |
| 19 | +) |
13 | 20 | from app.services.sse.redis_bus import SSERedisBus |
14 | 21 | from app.services.sse.sse_service import SSEService |
15 | 22 | from dishka import AsyncContainer |
16 | 23 | from httpx import AsyncClient |
17 | 24 |
|
18 | 25 |
|
19 | 26 | @pytest.mark.integration |
20 | | -class TestSSERoutes: |
21 | | - """SSE routes tested with deterministic event-driven reads (no polling).""" |
| 27 | +class TestSSEAuth: |
| 28 | + """SSE endpoints require authentication.""" |
| 29 | + |
| 30 | + @pytest.mark.asyncio |
| 31 | + async def test_notification_stream_requires_auth(self, client: AsyncClient) -> None: |
| 32 | + assert (await client.get("/api/v1/events/notifications/stream")).status_code == 401 |
| 33 | + |
| 34 | + @pytest.mark.asyncio |
| 35 | + async def test_execution_stream_requires_auth(self, client: AsyncClient) -> None: |
| 36 | + assert (await client.get(f"/api/v1/events/executions/{uuid4()}")).status_code == 401 |
22 | 37 |
|
23 | 38 | @pytest.mark.asyncio |
24 | | - async def test_sse_requires_authentication(self, client: AsyncClient) -> None: |
25 | | - r = await client.get("/api/v1/events/notifications/stream") |
26 | | - assert r.status_code == 401 |
27 | | - detail = r.json().get("detail", "").lower() |
28 | | - assert any(x in detail for x in ("not authenticated", "unauthorized", "login")) |
| 39 | + async def test_health_requires_auth(self, client: AsyncClient) -> None: |
| 40 | + assert (await client.get("/api/v1/events/health")).status_code == 401 |
29 | 41 |
|
30 | | - exec_id = str(uuid4()) |
31 | | - r = await client.get(f"/api/v1/events/executions/{exec_id}") |
32 | | - assert r.status_code == 401 |
33 | 42 |
|
34 | | - r = await client.get("/api/v1/events/health") |
35 | | - assert r.status_code == 401 |
| 43 | +@pytest.mark.integration |
| 44 | +class TestSSEHealth: |
| 45 | + """SSE health endpoint.""" |
36 | 46 |
|
37 | 47 | @pytest.mark.asyncio |
38 | | - async def test_sse_health_status(self, test_user: AsyncClient) -> None: |
| 48 | + async def test_returns_valid_health_status(self, test_user: AsyncClient) -> None: |
39 | 49 | r = await test_user.get("/api/v1/events/health") |
40 | 50 | assert r.status_code == 200 |
41 | | - health = SSEHealthResponse(**r.json()) |
| 51 | + health = SSEHealthResponse.model_validate(r.json()) |
42 | 52 | assert health.status in ("healthy", "degraded", "unhealthy", "draining") |
43 | | - assert isinstance(health.active_connections, int) and health.active_connections >= 0 |
| 53 | + assert health.active_connections >= 0 |
| 54 | + |
| 55 | + |
| 56 | +@pytest.mark.integration |
| 57 | +class TestRedisPubSubExecution: |
| 58 | + """Redis pub/sub for execution events - verifies message structure and delivery.""" |
44 | 59 |
|
45 | 60 | @pytest.mark.asyncio |
46 | | - async def test_notification_stream_service(self, scope: AsyncContainer, test_user: AsyncClient) -> None: |
47 | | - """Test SSE notification stream directly through service (httpx doesn't support SSE streaming).""" |
48 | | - sse_service: SSEService = await scope.get(SSEService) |
| 61 | + async def test_publish_wraps_event_in_redis_message(self, scope: AsyncContainer) -> None: |
| 62 | + """publish_event wraps BaseEvent in RedisSSEMessage with correct structure.""" |
49 | 63 | bus: SSERedisBus = await scope.get(SSERedisBus) |
50 | | - user_id = f"user-{uuid4().hex[:8]}" |
| 64 | + exec_id = f"exec-{uuid4().hex[:8]}" |
51 | 65 |
|
52 | | - events: list[dict[str, Any]] = [] |
53 | | - notification_received = False |
| 66 | + subscription = await bus.open_subscription(exec_id) |
54 | 67 |
|
55 | | - async with aclosing(sse_service.create_notification_stream(user_id)) as stream: |
56 | | - try: |
57 | | - async with asyncio.timeout(3.0): |
58 | | - async for event in stream: |
59 | | - if "data" not in event: |
60 | | - continue |
61 | | - data = json.loads(event["data"]) |
62 | | - events.append(data) |
63 | | - |
64 | | - # Wait for "subscribed" event - Redis subscription is now ready |
65 | | - if data.get("event_type") == "subscribed": |
66 | | - notification = RedisNotificationMessage( |
67 | | - notification_id=f"notif-{uuid4().hex[:8]}", |
68 | | - severity=NotificationSeverity.MEDIUM, |
69 | | - status=NotificationStatus.PENDING, |
70 | | - tags=[], |
71 | | - subject="Hello", |
72 | | - body="World", |
73 | | - action_url="", |
74 | | - created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), |
75 | | - ) |
76 | | - await bus.publish_notification(user_id, notification) |
77 | | - |
78 | | - # Stop when we receive the notification |
79 | | - if data.get("event_type") == "notification" and data.get("subject") == "Hello": |
80 | | - notification_received = True |
81 | | - break |
82 | | - except TimeoutError: |
83 | | - pass |
84 | | - |
85 | | - assert notification_received, f"Expected notification, got events: {events}" |
| 68 | + event = PodCreatedEvent( |
| 69 | + execution_id=exec_id, |
| 70 | + pod_name="test-pod", |
| 71 | + namespace="test-ns", |
| 72 | + metadata=AvroEventMetadata(service_name="test", service_version="1.0"), |
| 73 | + ) |
| 74 | + await bus.publish_event(exec_id, event) |
| 75 | + |
| 76 | + # Verify the wrapper structure |
| 77 | + received: RedisSSEMessage | None = await subscription.get(RedisSSEMessage) |
| 78 | + await subscription.close() |
| 79 | + |
| 80 | + assert received is not None |
| 81 | + assert received.event_type == EventType.POD_CREATED |
| 82 | + assert received.execution_id == exec_id |
| 83 | + assert received.data["pod_name"] == "test-pod" |
| 84 | + assert received.data["namespace"] == "test-ns" |
86 | 85 |
|
87 | 86 | @pytest.mark.asyncio |
88 | | - async def test_execution_event_stream_service(self, scope: AsyncContainer, test_user: AsyncClient) -> None: |
89 | | - """Test SSE execution stream directly through service (httpx doesn't support SSE streaming).""" |
90 | | - exec_id = f"e-{uuid4().hex[:8]}" |
91 | | - user_id = f"user-{uuid4().hex[:8]}" |
| 87 | + async def test_channel_isolation(self, scope: AsyncContainer) -> None: |
| 88 | + """Different execution_ids use isolated channels.""" |
| 89 | + bus: SSERedisBus = await scope.get(SSERedisBus) |
| 90 | + exec_a, exec_b = f"exec-a-{uuid4().hex[:6]}", f"exec-b-{uuid4().hex[:6]}" |
92 | 91 |
|
93 | | - sse_service: SSEService = await scope.get(SSEService) |
| 92 | + sub_a = await bus.open_subscription(exec_a) |
| 93 | + sub_b = await bus.open_subscription(exec_b) |
| 94 | + |
| 95 | + event = PodCreatedEvent( |
| 96 | + execution_id=exec_a, |
| 97 | + pod_name="pod-a", |
| 98 | + namespace="default", |
| 99 | + metadata=AvroEventMetadata(service_name="test", service_version="1"), |
| 100 | + ) |
| 101 | + await bus.publish_event(exec_a, event) |
| 102 | + |
| 103 | + received_a = await sub_a.get(RedisSSEMessage) |
| 104 | + received_b = await sub_b.get(RedisSSEMessage) |
| 105 | + |
| 106 | + await sub_a.close() |
| 107 | + await sub_b.close() |
| 108 | + |
| 109 | + assert received_a is not None |
| 110 | + assert received_a.data["pod_name"] == "pod-a" |
| 111 | + assert received_b is None # B should not receive A's message |
| 112 | + |
| 113 | + |
| 114 | +@pytest.mark.integration |
| 115 | +class TestRedisPubSubNotification: |
| 116 | + """Redis pub/sub for notifications - verifies message structure and delivery.""" |
| 117 | + |
| 118 | + @pytest.mark.asyncio |
| 119 | + async def test_publish_sends_notification_directly(self, scope: AsyncContainer) -> None: |
| 120 | + """publish_notification sends RedisNotificationMessage JSON directly.""" |
94 | 121 | bus: SSERedisBus = await scope.get(SSERedisBus) |
| 122 | + user_id = f"user-{uuid4().hex[:8]}" |
95 | 123 |
|
96 | | - events: list[dict[str, Any]] = [] |
97 | | - pod_event_received = False |
| 124 | + subscription = await bus.open_notification_subscription(user_id) |
98 | 125 |
|
99 | | - async with aclosing(sse_service.create_execution_stream(exec_id, user_id)) as stream: |
100 | | - try: |
101 | | - async with asyncio.timeout(3.0): |
102 | | - async for event in stream: |
103 | | - if "data" not in event: |
104 | | - continue |
105 | | - data = json.loads(event["data"]) |
106 | | - events.append(data) |
107 | | - |
108 | | - # Wait for "subscribed" event - Redis subscription is now ready |
109 | | - if data.get("event_type") == "subscribed": |
110 | | - ev = PodCreatedEvent( |
111 | | - execution_id=exec_id, |
112 | | - pod_name=f"executor-{exec_id}", |
113 | | - namespace="default", |
114 | | - metadata=AvroEventMetadata(service_name="tests", service_version="1"), |
115 | | - ) |
116 | | - await bus.publish_event(exec_id, ev) |
117 | | - |
118 | | - # Stop when we receive the pod event |
119 | | - if data.get("event_type") == "pod_created": |
120 | | - pod_event_received = True |
121 | | - break |
122 | | - except TimeoutError: |
123 | | - pass |
124 | | - |
125 | | - assert pod_event_received, f"Expected pod_created event, got events: {events}" |
| 126 | + notification = RedisNotificationMessage( |
| 127 | + notification_id="notif-123", |
| 128 | + severity=NotificationSeverity.HIGH, |
| 129 | + status=NotificationStatus.PENDING, |
| 130 | + tags=["urgent", "system"], |
| 131 | + subject="Test Alert", |
| 132 | + body="This is a test notification", |
| 133 | + action_url="https://example.com/action", |
| 134 | + created_at=datetime(2024, 6, 15, 12, 0, 0, tzinfo=timezone.utc), |
| 135 | + ) |
| 136 | + await bus.publish_notification(user_id, notification) |
| 137 | + |
| 138 | + received: RedisNotificationMessage | None = await subscription.get(RedisNotificationMessage) |
| 139 | + await subscription.close() |
| 140 | + |
| 141 | + assert received is not None |
| 142 | + assert received.notification_id == "notif-123" |
| 143 | + assert received.severity == NotificationSeverity.HIGH |
| 144 | + assert received.status == NotificationStatus.PENDING |
| 145 | + assert received.tags == ["urgent", "system"] |
| 146 | + assert received.subject == "Test Alert" |
| 147 | + assert received.body == "This is a test notification" |
| 148 | + assert received.action_url == "https://example.com/action" |
126 | 149 |
|
127 | 150 | @pytest.mark.asyncio |
128 | | - async def test_sse_route_requires_auth(self, client: AsyncClient) -> None: |
129 | | - """Test that SSE routes require authentication (HTTP-level test only).""" |
130 | | - r = await client.get("/api/v1/events/notifications/stream") |
131 | | - assert r.status_code == 401 |
| 151 | + async def test_user_channel_isolation(self, scope: AsyncContainer) -> None: |
| 152 | + """Different user_ids use isolated channels.""" |
| 153 | + bus: SSERedisBus = await scope.get(SSERedisBus) |
| 154 | + user_a, user_b = f"user-a-{uuid4().hex[:6]}", f"user-b-{uuid4().hex[:6]}" |
| 155 | + |
| 156 | + sub_a = await bus.open_notification_subscription(user_a) |
| 157 | + sub_b = await bus.open_notification_subscription(user_b) |
| 158 | + |
| 159 | + notification = RedisNotificationMessage( |
| 160 | + notification_id="for-user-a", |
| 161 | + severity=NotificationSeverity.LOW, |
| 162 | + status=NotificationStatus.PENDING, |
| 163 | + tags=[], |
| 164 | + subject="Private", |
| 165 | + body="For user A only", |
| 166 | + action_url="", |
| 167 | + created_at=datetime.now(timezone.utc), |
| 168 | + ) |
| 169 | + await bus.publish_notification(user_a, notification) |
| 170 | + |
| 171 | + received_a = await sub_a.get(RedisNotificationMessage) |
| 172 | + received_b = await sub_b.get(RedisNotificationMessage) |
132 | 173 |
|
133 | | - exec_id = str(uuid4()) |
134 | | - r = await client.get(f"/api/v1/events/executions/{exec_id}") |
135 | | - assert r.status_code == 401 |
| 174 | + await sub_a.close() |
| 175 | + await sub_b.close() |
| 176 | + |
| 177 | + assert received_a is not None |
| 178 | + assert received_a.notification_id == "for-user-a" |
| 179 | + assert received_b is None # B should not receive A's notification |
| 180 | + |
| 181 | + |
| 182 | +@pytest.mark.integration |
| 183 | +class TestSSEStreamEvents: |
| 184 | + """SSE stream control events - verifies event structure without pub/sub.""" |
136 | 185 |
|
137 | 186 | @pytest.mark.asyncio |
138 | | - async def test_sse_endpoint_returns_correct_headers(self, test_user: AsyncClient) -> None: |
139 | | - """Test SSE health endpoint works (streaming tests done via service).""" |
140 | | - r = await test_user.get("/api/v1/events/health") |
141 | | - assert r.status_code == 200 |
142 | | - assert isinstance(r.json(), dict) |
| 187 | + async def test_notification_stream_yields_connected_then_subscribed(self, scope: AsyncContainer) -> None: |
| 188 | + """Notification stream yields CONNECTED and SUBSCRIBED with correct fields.""" |
| 189 | + sse_service: SSEService = await scope.get(SSEService) |
| 190 | + user_id = f"user-{uuid4().hex[:8]}" |
| 191 | + |
| 192 | + events: list[dict[str, Any]] = [] |
| 193 | + async with aclosing(sse_service.create_notification_stream(user_id)) as stream: |
| 194 | + async for raw in stream: |
| 195 | + if "data" in raw: |
| 196 | + events.append(json.loads(raw["data"])) |
| 197 | + if len(events) >= 2: |
| 198 | + break |
| 199 | + |
| 200 | + # Verify CONNECTED event structure |
| 201 | + connected = events[0] |
| 202 | + assert connected["event_type"] == SSENotificationEvent.CONNECTED |
| 203 | + assert connected["user_id"] == user_id |
| 204 | + assert "timestamp" in connected |
| 205 | + assert connected["message"] == "Connected to notification stream" |
| 206 | + |
| 207 | + # Verify SUBSCRIBED event structure |
| 208 | + subscribed = events[1] |
| 209 | + assert subscribed["event_type"] == SSENotificationEvent.SUBSCRIBED |
| 210 | + assert subscribed["user_id"] == user_id |
| 211 | + assert "timestamp" in subscribed |
| 212 | + assert subscribed["message"] == "Redis subscription established" |
143 | 213 |
|
144 | 214 | @pytest.mark.asyncio |
145 | | - async def test_multiple_concurrent_sse_service_connections( |
146 | | - self, scope: AsyncContainer, test_user: AsyncClient |
147 | | - ) -> None: |
148 | | - """Test multiple concurrent SSE connections through the service.""" |
| 215 | + async def test_execution_stream_yields_connected_then_subscribed(self, scope: AsyncContainer) -> None: |
| 216 | + """Execution stream yields CONNECTED and SUBSCRIBED with correct fields.""" |
149 | 217 | sse_service: SSEService = await scope.get(SSEService) |
| 218 | + exec_id = f"exec-{uuid4().hex[:8]}" |
| 219 | + user_id = f"user-{uuid4().hex[:8]}" |
150 | 220 |
|
151 | | - async def create_and_verify_stream(user_id: str) -> bool: |
152 | | - async with aclosing(sse_service.create_notification_stream(user_id)) as stream: |
153 | | - async for event in stream: |
154 | | - if "data" in event: |
155 | | - data = json.loads(event["data"]) |
156 | | - if data.get("event_type") == "connected": |
157 | | - return True |
| 221 | + events: list[dict[str, Any]] = [] |
| 222 | + async with aclosing(sse_service.create_execution_stream(exec_id, user_id)) as stream: |
| 223 | + async for raw in stream: |
| 224 | + if "data" in raw: |
| 225 | + events.append(json.loads(raw["data"])) |
| 226 | + if len(events) >= 2: |
158 | 227 | break |
159 | | - return False |
160 | 228 |
|
161 | | - results = await asyncio.gather( |
162 | | - create_and_verify_stream("user1"), |
163 | | - create_and_verify_stream("user2"), |
164 | | - create_and_verify_stream("user3"), |
165 | | - return_exceptions=True, |
| 229 | + # Verify CONNECTED event structure |
| 230 | + connected = events[0] |
| 231 | + assert connected["event_type"] == SSEControlEvent.CONNECTED |
| 232 | + assert connected["execution_id"] == exec_id |
| 233 | + assert "connection_id" in connected |
| 234 | + assert connected["connection_id"].startswith(f"sse_{exec_id}_") |
| 235 | + assert "timestamp" in connected |
| 236 | + |
| 237 | + # Verify SUBSCRIBED event structure |
| 238 | + subscribed = events[1] |
| 239 | + assert subscribed["event_type"] == SSEControlEvent.SUBSCRIBED |
| 240 | + assert subscribed["execution_id"] == exec_id |
| 241 | + assert "timestamp" in subscribed |
| 242 | + assert subscribed["message"] == "Redis subscription established" |
| 243 | + |
| 244 | + @pytest.mark.asyncio |
| 245 | + async def test_concurrent_streams_get_unique_connection_ids(self, scope: AsyncContainer) -> None: |
| 246 | + """Each stream connection gets a unique connection_id.""" |
| 247 | + import asyncio |
| 248 | + |
| 249 | + sse_service: SSEService = await scope.get(SSEService) |
| 250 | + exec_id = f"exec-{uuid4().hex[:8]}" |
| 251 | + |
| 252 | + async def get_connection_id(user_id: str) -> str: |
| 253 | + async with aclosing(sse_service.create_execution_stream(exec_id, user_id)) as stream: |
| 254 | + async for raw in stream: |
| 255 | + if "data" in raw: |
| 256 | + data = json.loads(raw["data"]) |
| 257 | + if data.get("event_type") == SSEControlEvent.CONNECTED: |
| 258 | + return data["connection_id"] |
| 259 | + return "" |
| 260 | + |
| 261 | + conn_ids = await asyncio.gather( |
| 262 | + get_connection_id("user-1"), |
| 263 | + get_connection_id("user-2"), |
| 264 | + get_connection_id("user-3"), |
166 | 265 | ) |
167 | 266 |
|
168 | | - successful = sum(1 for r in results if r is True) |
169 | | - assert successful >= 2 |
| 267 | + # All connection IDs should be unique |
| 268 | + assert len(set(conn_ids)) == 3 |
| 269 | + assert all(cid.startswith(f"sse_{exec_id}_") for cid in conn_ids) |
0 commit comments