|
| 1 | +import contextlib |
1 | 2 | import json |
2 | 3 | import logging |
| 4 | +import re |
| 5 | +from collections import defaultdict |
3 | 6 | from contextlib import ExitStack |
4 | | -from dataclasses import dataclass |
| 7 | +from dataclasses import dataclass, field |
5 | 8 | from enum import Enum, unique |
6 | 9 | from typing import Any, Final |
7 | 10 |
|
8 | | -from playwright.sync_api import WebSocket |
| 11 | +from playwright.sync_api import FrameLocator, Page, Request, WebSocket, expect |
9 | 12 | from pytest_simcore.logging_utils import log_context |
10 | 13 |
|
11 | 14 | SECOND: Final[int] = 1000 |
12 | 15 | MINUTE: Final[int] = 60 * SECOND |
| 16 | +NODE_START_REQUEST_PATTERN: Final[re.Pattern[str]] = re.compile( |
| 17 | + r"/projects/[^/]+/nodes/[^:]+:start" |
| 18 | +) |
13 | 19 |
|
14 | 20 |
|
15 | 21 | @unique |
@@ -42,6 +48,28 @@ def is_running(self) -> bool: |
42 | 48 | ) |
43 | 49 |
|
44 | 50 |
|
| 51 | +@unique |
| 52 | +class NodeProgressType(str, Enum): |
| 53 | + # NOTE: this is a partial duplicate of models_library/rabbitmq_messages.py |
| 54 | + # It must remain as such until that module is pydantic V2 compatible |
| 55 | + CLUSTER_UP_SCALING = "CLUSTER_UP_SCALING" |
| 56 | + SERVICE_INPUTS_PULLING = "SERVICE_INPUTS_PULLING" |
| 57 | + SIDECARS_PULLING = "SIDECARS_PULLING" |
| 58 | + SERVICE_OUTPUTS_PULLING = "SERVICE_OUTPUTS_PULLING" |
| 59 | + SERVICE_STATE_PULLING = "SERVICE_STATE_PULLING" |
| 60 | + SERVICE_IMAGES_PULLING = "SERVICE_IMAGES_PULLING" |
| 61 | + |
| 62 | + @classmethod |
| 63 | + def required_types_for_started_service(cls) -> set["NodeProgressType"]: |
| 64 | + return { |
| 65 | + NodeProgressType.SERVICE_INPUTS_PULLING, |
| 66 | + NodeProgressType.SIDECARS_PULLING, |
| 67 | + NodeProgressType.SERVICE_OUTPUTS_PULLING, |
| 68 | + NodeProgressType.SERVICE_STATE_PULLING, |
| 69 | + NodeProgressType.SERVICE_IMAGES_PULLING, |
| 70 | + } |
| 71 | + |
| 72 | + |
45 | 73 | class ServiceType(str, Enum): |
46 | 74 | DYNAMIC = "DYNAMIC" |
47 | 75 | COMPUTATIONAL = "COMPUTATIONAL" |
@@ -84,6 +112,28 @@ def retrieve_project_state_from_decoded_message(event: SocketIOEvent) -> Running |
84 | 112 | return RunningState(event.obj["data"]["state"]["value"]) |
85 | 113 |
|
86 | 114 |
|
| 115 | +@dataclass(frozen=True, slots=True, kw_only=True) |
| 116 | +class NodeProgressEvent: |
| 117 | + node_id: str |
| 118 | + progress_type: NodeProgressType |
| 119 | + current_progress: float |
| 120 | + total_progress: float |
| 121 | + |
| 122 | + |
| 123 | +def retrieve_node_progress_from_decoded_message( |
| 124 | + event: SocketIOEvent, |
| 125 | +) -> NodeProgressEvent: |
| 126 | + assert event.name == _OSparcMessages.NODE_PROGRESS.value |
| 127 | + assert "progress_type" in event.obj |
| 128 | + assert "progress_report" in event.obj |
| 129 | + return NodeProgressEvent( |
| 130 | + node_id=event.obj["node_id"], |
| 131 | + progress_type=NodeProgressType(event.obj["progress_type"]), |
| 132 | + current_progress=float(event.obj["progress_report"]["actual_value"]), |
| 133 | + total_progress=float(event.obj["progress_report"]["total"]), |
| 134 | + ) |
| 135 | + |
| 136 | + |
87 | 137 | @dataclass |
88 | 138 | class SocketIOProjectClosedWaiter: |
89 | 139 | def __call__(self, message: str) -> bool: |
@@ -139,6 +189,44 @@ def __call__(self, message: str) -> None: |
139 | 189 | print("WS Message:", decoded_message.name, decoded_message.obj) |
140 | 190 |
|
141 | 191 |
|
| 192 | +@dataclass |
| 193 | +class SocketIONodeProgressCompleteWaiter: |
| 194 | + node_id: str |
| 195 | + _current_progress: dict[NodeProgressType, float] = field( |
| 196 | + default_factory=defaultdict |
| 197 | + ) |
| 198 | + |
| 199 | + def __call__(self, message: str) -> bool: |
| 200 | + with log_context(logging.DEBUG, msg=f"handling websocket {message=}") as ctx: |
| 201 | + # socket.io encodes messages like so |
| 202 | + # https://stackoverflow.com/questions/24564877/what-do-these-numbers-mean-in-socket-io-payload |
| 203 | + if message.startswith(_SOCKETIO_MESSAGE_PREFIX): |
| 204 | + decoded_message = decode_socketio_42_message(message) |
| 205 | + if decoded_message.name == _OSparcMessages.NODE_PROGRESS.value: |
| 206 | + node_progress_event = retrieve_node_progress_from_decoded_message( |
| 207 | + decoded_message |
| 208 | + ) |
| 209 | + if node_progress_event.node_id == self.node_id: |
| 210 | + self._current_progress[node_progress_event.progress_type] = ( |
| 211 | + node_progress_event.current_progress |
| 212 | + / node_progress_event.total_progress |
| 213 | + ) |
| 214 | + ctx.logger.info( |
| 215 | + "current startup progress: %s", |
| 216 | + f"{json.dumps({k:round(v,1) for k,v in self._current_progress.items()})}", |
| 217 | + ) |
| 218 | + |
| 219 | + return all( |
| 220 | + progress_type in self._current_progress |
| 221 | + for progress_type in NodeProgressType.required_types_for_started_service() |
| 222 | + ) and all( |
| 223 | + round(progress, 1) == 1.0 |
| 224 | + for progress in self._current_progress.values() |
| 225 | + ) |
| 226 | + |
| 227 | + return False |
| 228 | + |
| 229 | + |
142 | 230 | def wait_for_pipeline_state( |
143 | 231 | current_state: RunningState, |
144 | 232 | *, |
@@ -187,3 +275,52 @@ def on_web_socket_default_handler(ws) -> None: |
187 | 275 | ws.on("framesent", lambda payload: ctx.logger.info("⬇️ %s", payload)) |
188 | 276 | ws.on("framereceived", lambda payload: ctx.logger.info("⬆️ %s", payload)) |
189 | 277 | ws.on("close", lambda payload: stack.close()) # noqa: ARG005 |
| 278 | + |
| 279 | + |
| 280 | +def _node_started_predicate(request: Request) -> bool: |
| 281 | + return bool( |
| 282 | + re.search(NODE_START_REQUEST_PATTERN, request.url) |
| 283 | + and request.method.upper() == "POST" |
| 284 | + ) |
| 285 | + |
| 286 | + |
| 287 | +def _trigger_service_start_if_button_available(page: Page, node_id: str) -> None: |
| 288 | + # wait for the start button to auto-disappear if it is still around after the timeout, then we click it |
| 289 | + with log_context(logging.INFO, msg="trigger start button if needed") as ctx: |
| 290 | + start_button_locator = page.get_by_test_id(f"Start_{node_id}") |
| 291 | + with contextlib.suppress(AssertionError, TimeoutError): |
| 292 | + expect(start_button_locator).to_be_visible(timeout=5000) |
| 293 | + expect(start_button_locator).to_be_enabled(timeout=5000) |
| 294 | + with page.expect_request(_node_started_predicate): |
| 295 | + start_button_locator.click() |
| 296 | + ctx.logger.info("triggered start button") |
| 297 | + |
| 298 | + |
| 299 | +def wait_for_service_running( |
| 300 | + *, |
| 301 | + page: Page, |
| 302 | + node_id: str, |
| 303 | + websocket: WebSocket, |
| 304 | + timeout: int, |
| 305 | +) -> FrameLocator: |
| 306 | + """NOTE: if the service was already started this will not work as some of the required websocket events will not be emitted again |
| 307 | + In which case this will need further adjutment""" |
| 308 | + |
| 309 | + waiter = SocketIONodeProgressCompleteWaiter(node_id=node_id) |
| 310 | + with ( |
| 311 | + log_context(logging.INFO, msg="Waiting for node to run"), |
| 312 | + websocket.expect_event("framereceived", waiter, timeout=timeout), |
| 313 | + ): |
| 314 | + _trigger_service_start_if_button_available(page, node_id) |
| 315 | + return page.frame_locator(f'[osparc-test-id="iframe_{node_id}"]') |
| 316 | + |
| 317 | + |
| 318 | +def app_mode_trigger_next_app(page: Page) -> None: |
| 319 | + with ( |
| 320 | + log_context(logging.INFO, msg="triggering next app"), |
| 321 | + page.expect_request(_node_started_predicate), |
| 322 | + ): |
| 323 | + # Move to next step (this auto starts the next service) |
| 324 | + next_button_locator = page.get_by_test_id("AppMode_NextBtn") |
| 325 | + if next_button_locator.is_visible() and next_button_locator.is_enabled(): |
| 326 | + page.get_by_test_id("AppMode_NextBtn").click() |
0 commit comments