|
5 | 5 | # pylint:disable=too-many-arguments |
6 | 6 | # pylint:disable=protected-access |
7 | 7 |
|
| 8 | +import asyncio |
8 | 9 | import json |
9 | 10 | import logging |
10 | 11 | import secrets |
11 | 12 | from collections.abc import AsyncIterator, Awaitable, Callable |
12 | 13 | from dataclasses import dataclass |
| 14 | +from datetime import timedelta |
| 15 | +from pathlib import Path |
13 | 16 | from typing import Any |
14 | 17 | from unittest import mock |
15 | 18 |
|
|
18 | 21 | import simcore_service_webserver.db_listener |
19 | 22 | import simcore_service_webserver.db_listener._db_comp_tasks_listening_task |
20 | 23 | from aiohttp.test_utils import TestClient |
| 24 | +from common_library.async_tools import delayed_start |
21 | 25 | from faker import Faker |
22 | 26 | from models_library.projects import ProjectAtDB |
| 27 | +from models_library.projects_nodes import InputsDict |
23 | 28 | from pytest_mock import MockType |
24 | 29 | from pytest_mock.plugin import MockerFixture |
| 30 | +from pytest_simcore.helpers.logging_tools import log_context |
25 | 31 | from pytest_simcore.helpers.webserver_users import UserInfoDict |
26 | 32 | from simcore_postgres_database.models.comp_pipeline import StateType |
27 | 33 | from simcore_postgres_database.models.comp_tasks import NodeClass, comp_tasks |
|
30 | 36 | create_comp_tasks_listening_task, |
31 | 37 | ) |
32 | 38 | from sqlalchemy.ext.asyncio import AsyncEngine |
| 39 | +from tenacity import stop_after_attempt |
33 | 40 | from tenacity.asyncio import AsyncRetrying |
34 | 41 | from tenacity.before_sleep import before_sleep_log |
35 | 42 | from tenacity.retry import retry_if_exception_type |
@@ -208,3 +215,164 @@ async def test_db_listener_triggers_on_event_with_multiple_tasks( |
208 | 215 | ), f"_get_changed_comp_task_row was not called with task_id={updated_task_id}. Calls: {spied_get_changed_comp_task_row.call_args_list}" |
209 | 216 | else: |
210 | 217 | spied_get_changed_comp_task_row.assert_not_called() |
| 218 | + |
| 219 | + |
| 220 | +@pytest.fixture |
| 221 | +def fake_2connected_jupyterlabs_workbench(tests_data_dir: Path) -> dict[str, Any]: |
| 222 | + fpath = tests_data_dir / "workbench_2connected_jupyterlabs.json" |
| 223 | + assert fpath.exists() |
| 224 | + return json.loads(fpath.read_text()) |
| 225 | + |
| 226 | + |
| 227 | +@pytest.fixture |
| 228 | +async def mock_dynamic_service_rpc( |
| 229 | + mocker: MockerFixture, |
| 230 | +) -> mock.AsyncMock: |
| 231 | + """ |
| 232 | + Mocks the dynamic service RPC calls to avoid actual service calls during tests. |
| 233 | + """ |
| 234 | + return mocker.patch( |
| 235 | + "servicelib.rabbitmq.rpc_interfaces.dynamic_scheduler.services.retrieve_inputs", |
| 236 | + autospec=True, |
| 237 | + ) |
| 238 | + |
| 239 | + |
| 240 | +async def _check_for_stability( |
| 241 | + function: Callable[..., Awaitable[None]], *args, **kwargs |
| 242 | +) -> None: |
| 243 | + async for attempt in AsyncRetrying( |
| 244 | + stop=stop_after_attempt(5), |
| 245 | + wait=wait_fixed(1), |
| 246 | +        retry=retry_if_exception_type(),  # no arguments: retry on any Exception
| 247 | + reraise=True, |
| 248 | + ): |
| 249 | + with attempt: # noqa: SIM117 |
| 250 | + with log_context( |
| 251 | + logging.INFO, |
| 252 | + msg=f"check stability of {function.__name__} {attempt.retry_state.retry_object.statistics}", |
| 253 | + ) as log_ctx: |
| 254 | + await function(*args, **kwargs) |
| 255 | + log_ctx.logger.info( |
| 256 | + "stable for %s...", attempt.retry_state.seconds_since_start |
| 257 | + ) |
| 258 | + |
| 259 | + |
| 260 | +@pytest.mark.parametrize("user_role", [UserRole.USER]) |
| 261 | +async def test_db_listener_upgrades_projects_row_correctly( |
| 262 | + with_started_listening_task: None, |
| 263 | + mock_dynamic_service_rpc: mock.AsyncMock, |
| 264 | + sqlalchemy_async_engine: AsyncEngine, |
| 265 | + logged_user: UserInfoDict, |
| 266 | + project: Callable[..., Awaitable[ProjectAtDB]], |
| 267 | + fake_2connected_jupyterlabs_workbench: dict[str, Any], |
| 268 | + pipeline: Callable[..., dict[str, Any]], |
| 269 | + comp_task: Callable[..., dict[str, Any]], |
| 270 | + spied_get_changed_comp_task_row: MockType, |
| 271 | + faker: Faker, |
| 272 | +): |
| 273 | + some_project = await project( |
| 274 | + logged_user, workbench=fake_2connected_jupyterlabs_workbench |
| 275 | + ) |
| 276 | + |
| 277 | + # create the corresponding comp_task entries for the project workbench |
| 278 | + pipeline(project_id=f"{some_project.uuid}") |
| 279 | + tasks = [ |
| 280 | + comp_task( |
| 281 | + project_id=f"{some_project.uuid}", |
| 282 | + node_id=node_id, |
| 283 | + outputs=node_data.get("outputs", {}), |
| 284 | + node_class=( |
| 285 | + NodeClass.INTERACTIVE |
| 286 | + if "dynamic" in node_data["key"] |
| 287 | + else NodeClass.COMPUTATIONAL |
| 288 | + ), |
| 289 | + inputs=node_data.get("inputs", InputsDict()), |
| 290 | + ) |
| 291 | + for node_id, node_data in fake_2connected_jupyterlabs_workbench.items() |
| 292 | + ] |
| 293 | + assert len(tasks) == 2, "Expected two tasks for the two JupyterLab nodes" |
| 294 | + first_jupyter_task = tasks[0] |
| 295 | + second_jupyter_task = tasks[1] |
| 296 | + assert ( |
| 297 | + len(second_jupyter_task["inputs"]) > 0 |
| 298 | + ), "Expected inputs for the second JupyterLab task" |
| 299 | + number_of_inputs_linked = len(second_jupyter_task["inputs"]) |
| 300 | + |
| 301 | +    # simulate concurrent changes to all the outputs of the first JupyterLab
| 302 | + async def _update_first_jupyter_task_output( |
| 303 | + port_index: int, data: dict[str, Any] |
| 304 | + ) -> None: |
| 305 | + with log_context(logging.INFO, msg=f"Updating output {port_index + 1}"): |
| 306 | + async with sqlalchemy_async_engine.begin() as conn: |
| 307 | + result = await conn.execute( |
| 308 | + comp_tasks.select() |
| 309 | +                    .with_only_columns(comp_tasks.c.outputs)  # positional form; passing a list is deprecated since SQLAlchemy 1.4
| 310 | + .where(comp_tasks.c.task_id == first_jupyter_task["task_id"]) |
| 311 | + .with_for_update() |
| 312 | + ) |
| 313 | + row = result.first() |
| 314 | + current_outputs = row[0] if row and row[0] else {} |
| 315 | + |
| 316 | + # Update/add the new key while preserving existing keys |
| 317 | + current_outputs[f"output_{port_index + 1}"] = data |
| 318 | + |
| 319 | + # Write back the updated outputs |
| 320 | + await conn.execute( |
| 321 | + comp_tasks.update() |
| 322 | + .values(outputs=current_outputs) |
| 323 | + .where(comp_tasks.c.task_id == first_jupyter_task["task_id"]) |
| 324 | + ) |
| 325 | + |
| 326 | + @delayed_start(timedelta(seconds=2)) |
| 327 | + async def _change_outputs_sequentially(sleep: float = 0.1) -> None: |
| 328 | + """ |
| 329 | +        Sequentially updates the outputs of the first JupyterLab task so that the listener triggers the dynamic service retrieve RPC for the connected second task.
| 330 | + """ |
| 331 | + for i in range(number_of_inputs_linked): |
| 332 | + await _update_first_jupyter_task_output(i, {"data": i}) |
| 333 | + await asyncio.sleep(sleep) |
| 334 | + |
| 335 | +    # run the sequential output updates concurrently in a background task
| 336 | + sequential_task = asyncio.create_task(_change_outputs_sequentially(5)) |
| 337 | + assert sequential_task is not None, "Failed to create the sequential task" |
| 338 | + |
| 339 | + async def _check_retrieve_rpc_called(expected_ports_retrieved: int) -> None: |
| 340 | + async for attempt in AsyncRetrying( |
| 341 | + stop=stop_after_delay(60), |
| 342 | + wait=wait_fixed(1), |
| 343 | + retry=retry_if_exception_type(AssertionError), |
| 344 | + reraise=True, |
| 345 | + ): |
| 346 | + with attempt: # noqa: SIM117 |
| 347 | + with log_context( |
| 348 | + logging.INFO, |
| 349 | + msg=f"Checking if dynamic service retrieve RPC was called and " |
| 350 | + f"all expected ports were retrieved {expected_ports_retrieved} " |
| 351 | + f"times, {attempt.retry_state.retry_object.statistics}", |
| 352 | + ) as log_ctx: |
| 353 | + if mock_dynamic_service_rpc.call_count > 0: |
| 354 | + log_ctx.logger.info( |
| 355 | + "call arguments: %s", |
| 356 | + mock_dynamic_service_rpc.call_args_list, |
| 357 | + ) |
| 358 | + # Assert that the dynamic service RPC was called |
| 359 | + assert ( |
| 360 | + mock_dynamic_service_rpc.call_count > 0 |
| 361 | + ), "Dynamic service retrieve RPC was not called" |
| 362 | +                    # now check which ports were retrieved; we expect all of them
| 363 | + all_ports = set() |
| 364 | + for call in mock_dynamic_service_rpc.call_args_list: |
| 365 | + retrieved_ports = call[1]["port_keys"] |
| 366 | + all_ports.update(retrieved_ports) |
| 367 | + assert len(all_ports) == expected_ports_retrieved, ( |
| 368 | + f"Expected {expected_ports_retrieved} ports to be retrieved, " |
| 369 | + f"but got {len(all_ports)}: {all_ports}" |
| 370 | + ) |
| 371 | + log_ctx.logger.info( |
| 372 | + "Dynamic service retrieve RPC was called with all expected ports!" |
| 373 | + ) |
| 374 | + |
| 375 | + await _check_for_stability(_check_retrieve_rpc_called, number_of_inputs_linked) |
| 376 | + await asyncio.wait_for(sequential_task, timeout=60) |
| 377 | + assert sequential_task.done(), "Sequential task did not complete" |
| 378 | + assert not sequential_task.cancelled(), "Sequential task was cancelled unexpectedly" |
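
Side note on `delayed_start`: the test relies on it to postpone the first `comp_tasks.outputs` update until roughly two seconds after `asyncio.create_task(...)` is called, so `_check_retrieve_rpc_called` is already polling the mocked RPC when the updates begin. The sketch below shows the assumed semantics (sleep for the given `timedelta`, then run the wrapped coroutine); it is an illustrative approximation, not the actual `common_library.async_tools` implementation.

```python
# Illustrative sketch of the assumed delayed_start semantics (not the real
# common_library.async_tools implementation): sleep first, then run the coroutine.
import asyncio
import functools
from collections.abc import Awaitable, Callable
from datetime import timedelta
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def delayed_start(
    delay: timedelta,
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
    def _decorator(coro_fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
        @functools.wraps(coro_fn)
        async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # delay only the start; the wrapped coroutine then runs normally
            await asyncio.sleep(delay.total_seconds())
            return await coro_fn(*args, **kwargs)

        return _wrapper

    return _decorator
```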