|
# pylint:disable=too-many-arguments
# pylint:disable=protected-access

+import asyncio
+from datetime import timedelta
import json
import logging
+from pathlib import Path
import secrets
|
from unittest import mock

import pytest
+from common_library.async_tools import delayed_start
+from models_library.projects_nodes import InputsDict
+from pytest_simcore.helpers.logging_tools import log_context
import simcore_service_webserver
import simcore_service_webserver.db_listener
import simcore_service_webserver.db_listener._db_comp_tasks_listening_task
|
+from tenacity import stop_after_attempt
from tenacity.stop import stop_after_delay
from tenacity.wait import wait_fixed

+from simcore_service_webserver.projects.models import ProjectDict
+
logger = logging.getLogger(__name__)


|
@@ -205,6 +214,181 @@ async def test_db_listener_triggers_on_event_with_multiple_tasks( |
        assert any(
            call.args[1] == updated_task_id
            for call in spied_get_changed_comp_task_row.call_args_list
-        ), f"_get_changed_comp_task_row was not called with task_id={updated_task_id}. Calls: {spied_get_changed_comp_task_row.call_args_list}"
+        ), (
+            f"_get_changed_comp_task_row was not called with task_id={updated_task_id}. Calls: {spied_get_changed_comp_task_row.call_args_list}"
+        )
    else:
        spied_get_changed_comp_task_row.assert_not_called()
+
+
+@pytest.fixture
+def fake_2connected_jupyterlabs_workbench(tests_data_dir: Path) -> dict[str, Any]:
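+    """Loads a workbench made of two JupyterLab nodes whose ports are connected."""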
+    fpath = tests_data_dir / "workbench_2connected_jupyterlabs.json"
+    assert fpath.exists()
+    return json.loads(fpath.read_text())
+
+
+@pytest.fixture
+async def mock_dynamic_service_rpc(
+    mocker: MockerFixture,
+) -> mock.AsyncMock:
+    """
+    Mocks the dynamic service RPC calls to avoid actual service calls during tests.
+    """
+    return mocker.patch(
+        "servicelib.rabbitmq.rpc_interfaces.dynamic_scheduler.services.retrieve_inputs",
+        autospec=True,
+    )
+
+
+async def _check_for_stability(
+    function: Callable[..., Awaitable[None]], *args, **kwargs
+) -> None:
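+    # NOTE: retries `function` every second, up to 5 times, until it runs without
+    # raising (with the default retry predicate any exception triggers a retry);
+    # the last error is re-raised if it never succeeds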
+    async for attempt in AsyncRetrying(
+        stop=stop_after_attempt(5),
+        wait=wait_fixed(1),
+        retry=retry_if_exception_type(),
+        reraise=True,
+    ):
+        with attempt:  # noqa: SIM117
+            with log_context(
+                logging.INFO,
+                msg=f"check stability of {function.__name__} {attempt.retry_state.retry_object.statistics}",
+            ) as log_ctx:
+                await function(*args, **kwargs)
+                log_ctx.logger.info(
+                    "stable for %s...", attempt.retry_state.seconds_since_start
+                )
+
+
+@pytest.mark.parametrize("user_role", [UserRole.USER])
+async def test_db_listener_upgrades_projects_row_correctly(
+    with_started_listening_task: None,
+    mock_dynamic_service_rpc: mock.AsyncMock,
+    sqlalchemy_async_engine: AsyncEngine,
+    logged_user: UserInfoDict,
+    project: Callable[..., Awaitable[ProjectAtDB]],
+    fake_2connected_jupyterlabs_workbench: dict[str, Any],
+    pipeline: Callable[..., dict[str, Any]],
+    comp_task: Callable[..., dict[str, Any]],
+    spied_get_changed_comp_task_row: MockType,
+    faker: Faker,
+):
+    some_project = await project(
+        logged_user, workbench=fake_2connected_jupyterlabs_workbench
+    )
+
+    # create the corresponding comp_task entries for the project workbench
+    pipeline(project_id=f"{some_project.uuid}")
+    tasks = [
+        comp_task(
+            project_id=f"{some_project.uuid}",
+            node_id=node_id,
+            outputs=node_data.get("outputs", {}),
+            node_class=NodeClass.INTERACTIVE
+            if "dynamic" in node_data["key"]
+            else NodeClass.COMPUTATIONAL,
+            inputs=node_data.get("inputs", InputsDict()),
+        )
+        for node_id, node_data in fake_2connected_jupyterlabs_workbench.items()
+    ]
+    assert len(tasks) == 2, "Expected two tasks for the two JupyterLab nodes"
+    first_jupyter_task = tasks[0]
+    second_jupyter_task = tasks[1]
+    assert len(second_jupyter_task["inputs"]) > 0, (
+        "Expected inputs for the second JupyterLab task"
+    )
+    number_of_inputs_linked = len(second_jupyter_task["inputs"])
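+    # NOTE: each of these inputs is linked to an output of the first JupyterLab,
+    # so every output update is expected to end up in a retrieve RPC call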
+
+    # simulate a concurrent change in all the outputs of the first jupyterlab
+    async def _update_first_jupyter_task_output(
+        port_index: int, data: dict[str, Any]
+    ) -> None:
+        with log_context(logging.INFO, msg=f"Updating output {port_index + 1}"):
+            async with sqlalchemy_async_engine.begin() as conn:
+                # the outputs column is JSON (not JSONB), so jsonb_set is not
+                # available; use the safer fetch-modify-update cycle instead,
+                # with SELECT ... FOR UPDATE locking the row against concurrent access
+                result = await conn.execute(
+                    comp_tasks.select()
+                    .with_only_columns(comp_tasks.c.outputs)
+                    .where(comp_tasks.c.task_id == first_jupyter_task["task_id"])
+                    .with_for_update()
+                )
+                row = result.first()
+                current_outputs = row[0] if row and row[0] else {}
+
+                # update/add the new key while preserving existing keys
+                current_outputs[f"output_{port_index + 1}"] = data
+
+                # write back the updated outputs
+                await conn.execute(
+                    comp_tasks.update()
+                    .values(outputs=current_outputs)
+                    .where(comp_tasks.c.task_id == first_jupyter_task["task_id"])
+                )
+
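+    # NOTE: the 2s delayed start presumably leaves the DB listener time to process
+    # the initial comp_tasks inserts before the outputs start changing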
+    @delayed_start(timedelta(seconds=2))
+    async def _change_outputs_sequentially(sleep: float = 0.1) -> None:
+        """
+        Sequentially updates the outputs of the first JupyterLab task, so that each
+        change triggers a retrieve RPC for the linked inputs of the second one.
+        """
+        for i in range(number_of_inputs_linked):
+            await _update_first_jupyter_task_output(i, {"data": i})
+            await asyncio.sleep(sleep)
+
+    # run the updates in a background task, waiting 5 seconds between two updates
+    sequential_task = asyncio.create_task(_change_outputs_sequentially(5))
+
+    async def _check_retrieve_rpc_called(expected_ports_retrieved: int) -> None:
+        async for attempt in AsyncRetrying(
+            stop=stop_after_delay(60),
+            wait=wait_fixed(1),
+            retry=retry_if_exception_type(AssertionError),
+            reraise=True,
+        ):
+            with attempt:  # noqa: SIM117
+                with log_context(
+                    logging.INFO,
+                    msg=f"Checking if the dynamic service retrieve RPC was called "
+                    f"and all {expected_ports_retrieved} expected ports were "
+                    f"retrieved, {attempt.retry_state.retry_object.statistics}",
+                ) as log_ctx:
+                    if mock_dynamic_service_rpc.call_count > 0:
+                        log_ctx.logger.info(
+                            "call arguments: %s",
+                            mock_dynamic_service_rpc.call_args_list,
+                        )
+                    # assert that the dynamic service RPC was called
+                    assert mock_dynamic_service_rpc.call_count > 0, (
+                        "Dynamic service retrieve RPC was not called"
+                    )
+                    # now check which ports were retrieved; we expect all of them
+                    all_ports = set()
+                    for call in mock_dynamic_service_rpc.call_args_list:
+                        retrieved_ports = call[1]["port_keys"]
+                        all_ports.update(retrieved_ports)
+                    assert len(all_ports) == expected_ports_retrieved, (
+                        f"Expected {expected_ports_retrieved} ports to be retrieved, "
+                        f"but got {len(all_ports)}: {all_ports}"
+                    )
+                    log_ctx.logger.info(
+                        "Dynamic service retrieve RPC was called with all expected ports!"
+                    )
+
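+    # wait until the retrieve RPC was called with every linked port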
+    await _check_for_stability(_check_retrieve_rpc_called, number_of_inputs_linked)
+
+    assert sequential_task.done(), "Sequential task did not complete"
+    assert not sequential_task.cancelled(), "Sequential task was cancelled unexpectedly"