
Commit cd48297: added new test

1 parent 22374df

4 files changed: +278 -47 lines

packages/pytest-simcore/src/pytest_simcore/db_entries_mocks.py

Lines changed: 9 additions & 1 deletion

@@ -10,11 +10,13 @@
 import pytest
 import sqlalchemy as sa
 from faker import Faker
+from models_library.products import ProductName
 from models_library.projects import ProjectAtDB, ProjectID
 from models_library.projects_nodes_io import NodeID
 from simcore_postgres_database.models.comp_pipeline import StateType, comp_pipeline
 from simcore_postgres_database.models.comp_tasks import comp_tasks
 from simcore_postgres_database.models.projects import ProjectType, projects
+from simcore_postgres_database.models.projects_to_products import projects_to_products
 from simcore_postgres_database.models.users import UserRole, UserStatus, users
 from simcore_postgres_database.utils_projects_nodes import (
     ProjectNodeCreate,
@@ -64,7 +66,7 @@ def creator(**user_kwargs) -> dict[str, Any]:
 
 @pytest.fixture
 async def project(
-    sqlalchemy_async_engine: AsyncEngine, faker: Faker
+    sqlalchemy_async_engine: AsyncEngine, faker: Faker, product_name: ProductName
 ) -> AsyncIterator[Callable[..., Awaitable[ProjectAtDB]]]:
     created_project_ids: list[str] = []
 
@@ -112,6 +114,12 @@ async def creator(
                 for node_id in inserted_project.workbench
             ],
         )
+        await con.execute(
+            projects_to_products.insert().values(
+                project_uuid=f"{inserted_project.uuid}",
+                product_name=product_name,
+            )
+        )
         print(f"--> created {inserted_project=}")
         created_project_ids.append(f"{inserted_project.uuid}")
         return inserted_project
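
With this change, every project created through the `project` fixture is also linked to a product in the `projects_to_products` table, mirroring what the webserver does when a user creates a project inside a product. Stripped of the fixture plumbing, the new insert amounts to the following (a minimal sketch; the helper name `link_project_to_product` is illustrative and not part of the commit):

    from sqlalchemy.ext.asyncio import AsyncEngine

    from simcore_postgres_database.models.projects_to_products import (
        projects_to_products,
    )


    async def link_project_to_product(
        engine: AsyncEngine, project_uuid: str, product_name: str
    ) -> None:
        # one row per (project, product) pair, exactly what the fixture now inserts
        async with engine.begin() as conn:
            await conn.execute(
                projects_to_products.insert().values(
                    project_uuid=project_uuid,
                    product_name=product_name,
                )
            )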

services/web/server/src/simcore_service_webserver/db_listener/_db_comp_tasks_listening_task.py

Lines changed: 47 additions & 45 deletions

@@ -18,6 +18,7 @@
 from models_library.projects_state import RunningState
 from pydantic.types import PositiveInt
 from servicelib.background_task import periodic_task
+from servicelib.logging_utils import log_catch
 from simcore_postgres_database.models.comp_tasks import comp_tasks
 from simcore_postgres_database.webserver_models import DB_CHANNEL_NAME, projects
 from sqlalchemy.sql import select
@@ -73,56 +74,57 @@ async def _get_changed_comp_task_row(
 async def _handle_db_notification(
     app: web.Application, payload: CompTaskNotificationPayload, conn: SAConnection
 ) -> None:
-    try:
-        the_project_owner = await _get_project_owner(conn, payload.project_id)
-        changed_row = await _get_changed_comp_task_row(conn, payload.task_id)
-        if not changed_row:
+    with log_catch(_logger, reraise=False):
+        try:
+            the_project_owner = await _get_project_owner(conn, payload.project_id)
+            changed_row = await _get_changed_comp_task_row(conn, payload.task_id)
+            if not changed_row:
+                _logger.warning(
+                    "No comp_tasks row found for project_id=%s node_id=%s",
+                    payload.project_id,
+                    payload.node_id,
+                )
+                return
+
+            if any(f in payload.changes for f in ["outputs", "run_hash"]):
+                await update_node_outputs(
+                    app,
+                    the_project_owner,
+                    payload.project_id,
+                    payload.node_id,
+                    changed_row.outputs,
+                    changed_row.run_hash,
+                    node_errors=changed_row.errors,
+                    ui_changed_keys=None,
+                )
+
+            if "state" in payload.changes and (changed_row.state is not None):
+                await _update_project_state(
+                    app,
+                    the_project_owner,
+                    payload.project_id,
+                    payload.node_id,
+                    convert_state_from_db(changed_row.state),
+                    node_errors=changed_row.errors,
+                )
+
+        except exceptions.ProjectNotFoundError as exc:
             _logger.warning(
-                "No comp_tasks row found for project_id=%s node_id=%s",
-                payload.project_id,
-                payload.node_id,
+                "Project %s was not found and cannot be updated. Maybe was it deleted?",
+                exc.project_uuid,
             )
-            return
-
-        if any(f in payload.changes for f in ["outputs", "run_hash"]):
-            await update_node_outputs(
-                app,
-                the_project_owner,
-                payload.project_id,
-                payload.node_id,
-                changed_row.outputs,
-                changed_row.run_hash,
-                node_errors=changed_row.errors,
-                ui_changed_keys=None,
+        except exceptions.ProjectOwnerNotFoundError as exc:
+            _logger.warning(
+                "Project owner of project %s could not be found, is the project valid?",
+                exc.project_uuid,
             )
-
-        if "state" in payload.changes and (changed_row.state is not None):
-            await _update_project_state(
-                app,
-                the_project_owner,
-                payload.project_id,
-                payload.node_id,
-                convert_state_from_db(changed_row.state),
-                node_errors=changed_row.errors,
+        except exceptions.NodeNotFoundError as exc:
+            _logger.warning(
+                "Node %s of project %s not found and cannot be updated. Maybe was it deleted?",
+                exc.node_uuid,
+                exc.project_uuid,
             )
 
-    except exceptions.ProjectNotFoundError as exc:
-        _logger.warning(
-            "Project %s was not found and cannot be updated. Maybe was it deleted?",
-            exc.project_uuid,
-        )
-    except exceptions.ProjectOwnerNotFoundError as exc:
-        _logger.warning(
-            "Project owner of project %s could not be found, is the project valid?",
-            exc.project_uuid,
-        )
-    except exceptions.NodeNotFoundError as exc:
-        _logger.warning(
-            "Node %s of project %s not found and cannot be updated. Maybe was it deleted?",
-            exc.node_uuid,
-            exc.project_uuid,
-        )
-
 
 async def _listen(app: web.Application) -> NoReturn:
     listen_query = f"LISTEN {DB_CHANNEL_NAME};"
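
The net effect of this refactoring: the handler body now runs inside `log_catch(_logger, reraise=False)`, so an unexpected exception while processing one notification is logged instead of propagating and killing the listener loop, while the known "not found" cases keep their explicit warning-level `except` clauses. The pattern in isolation looks like this (a simplified stand-in for `servicelib.logging_utils.log_catch`, not its actual implementation):

    import logging
    from collections.abc import Iterator
    from contextlib import contextmanager


    @contextmanager
    def log_catch(logger: logging.Logger, *, reraise: bool = True) -> Iterator[None]:
        # log any exception raised inside the block; swallow it when reraise=False
        try:
            yield
        except Exception:  # pylint: disable=broad-except
            logger.exception("Unhandled exception:")
            if reraise:
                raise


    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    with log_catch(logger, reraise=False):
        raise RuntimeError("bad notification payload")  # logged, not propagated
    print("listener task keeps running")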
workbench_2connected_jupyterlabs.json

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+{
+  "e8eae2cd-ae0f-4ba6-ae0b-86eeadf99b42": {
+    "key": "simcore/services/dynamic/jupyter-math",
+    "version": "3.0.5",
+    "label": "JupyterLab Math (Python+Octave)",
+    "inputs": {},
+    "inputsRequired": [],
+    "inputNodes": []
+  },
+  "f7d6dc1e-a6dc-44e1-9588-a2f4b05d3d9c": {
+    "key": "simcore/services/dynamic/jupyter-math",
+    "version": "3.0.5",
+    "label": "JupyterLab Math (Python+Octave)_2",
+    "inputs": {
+      "input_1": {
+        "nodeUuid": "e8eae2cd-ae0f-4ba6-ae0b-86eeadf99b42",
+        "output": "output_1"
+      },
+      "input_2": {
+        "nodeUuid": "e8eae2cd-ae0f-4ba6-ae0b-86eeadf99b42",
+        "output": "output_2"
+      },
+      "input_3": {
+        "nodeUuid": "e8eae2cd-ae0f-4ba6-ae0b-86eeadf99b42",
+        "output": "output_3"
+      },
+      "input_4": {
+        "nodeUuid": "e8eae2cd-ae0f-4ba6-ae0b-86eeadf99b42",
+        "output": "output_4"
+      }
+    },
+    "inputsRequired": [],
+    "inputNodes": [
+      "e8eae2cd-ae0f-4ba6-ae0b-86eeadf99b42"
+    ]
+  }
+}
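
This workbench wires two jupyter-math services together: the second node consumes four outputs of the first, so a full propagation of the first node's outputs should end up retrieving exactly four ports on the second. A standalone sanity check of that topology (the relative path here is an assumption for illustration; the test below resolves the file via its `tests_data_dir` fixture):

    import json
    from pathlib import Path

    workbench = json.loads(Path("workbench_2connected_jupyterlabs.json").read_text())

    source_id, sink_id = workbench  # two nodes, in insertion order
    sink = workbench[sink_id]
    assert sink["inputNodes"] == [source_id]  # the sink depends only on the source
    assert len(sink["inputs"]) == 4  # four linked ports -> four expected retrievals
    assert all(port["nodeUuid"] == source_id for port in sink["inputs"].values())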

services/web/server/tests/unit/with_dbs/04/notifications/test_notifications__db_comp_tasks_listening_task.py

Lines changed: 185 additions & 1 deletion
@@ -5,6 +5,9 @@
 # pylint:disable=too-many-arguments
 # pylint:disable=protected-access
 
+from ast import Assert
+import asyncio
+from datetime import timedelta
 import json
 import logging
 import secrets
@@ -14,6 +17,10 @@
 from unittest import mock
 
 import pytest
+from tenacity import stop_after_attempt
+from common_library.async_tools import delayed_start
+from models_library.projects_nodes import InputsDict
+from pytest_simcore.helpers.logging_tools import log_context
 import simcore_service_webserver
 import simcore_service_webserver.db_listener
 import simcore_service_webserver.db_listener._db_comp_tasks_listening_task
@@ -36,6 +43,8 @@
 from tenacity.stop import stop_after_delay
 from tenacity.wait import wait_fixed
 
+from simcore_service_webserver.projects.models import ProjectDict
+
 logger = logging.getLogger(__name__)
 
 
@@ -205,6 +214,181 @@ async def test_db_listener_triggers_on_event_with_multiple_tasks(
         assert any(
             call.args[1] == updated_task_id
             for call in spied_get_changed_comp_task_row.call_args_list
-        ), f"_get_changed_comp_task_row was not called with task_id={updated_task_id}. Calls: {spied_get_changed_comp_task_row.call_args_list}"
+        ), (
+            f"_get_changed_comp_task_row was not called with task_id={updated_task_id}. Calls: {spied_get_changed_comp_task_row.call_args_list}"
+        )
     else:
         spied_get_changed_comp_task_row.assert_not_called()
+
+
+from pathlib import Path
+
+
+@pytest.fixture
+def fake_2connected_jupyterlabs_workbench(tests_data_dir: Path) -> dict[str, Any]:
+    fpath = tests_data_dir / "workbench_2connected_jupyterlabs.json"
+    assert fpath.exists()
+    return json.loads(fpath.read_text())
+
+
+@pytest.fixture
+async def mock_dynamic_service_rpc(
+    mocker: MockerFixture,
+) -> mock.AsyncMock:
+    """
+    Mocks the dynamic service RPC calls to avoid actual service calls during tests.
+    """
+    return mocker.patch(
+        "servicelib.rabbitmq.rpc_interfaces.dynamic_scheduler.services.retrieve_inputs",
+        autospec=True,
+    )
+
+
+async def _check_for_stability(
+    function: Callable[..., Awaitable[None]], *args, **kwargs
+) -> None:
+    async for attempt in AsyncRetrying(
+        stop=stop_after_attempt(5),
+        wait=wait_fixed(1),
+        retry=retry_if_exception_type(),
+        reraise=True,
+    ):
+        with attempt:  # noqa: SIM117
+            with log_context(
+                logging.INFO,
+                msg=f"check stability of {function.__name__} {attempt.retry_state.retry_object.statistics}",
+            ) as log_ctx:
+                await function(*args, **kwargs)
+                log_ctx.logger.info(
+                    "stable for %s...", attempt.retry_state.seconds_since_start
+                )
+
+
+@pytest.mark.testit
+@pytest.mark.parametrize("user_role", [UserRole.USER])
+async def test_db_listener_upgrades_projects_row_correctly(
+    with_started_listening_task: None,
+    mock_dynamic_service_rpc: mock.AsyncMock,
+    sqlalchemy_async_engine: AsyncEngine,
+    logged_user: UserInfoDict,
+    project: Callable[..., Awaitable[ProjectAtDB]],
+    fake_2connected_jupyterlabs_workbench: dict[str, Any],
+    pipeline: Callable[..., dict[str, Any]],
+    comp_task: Callable[..., dict[str, Any]],
+    spied_get_changed_comp_task_row: MockType,
+    faker: Faker,
+):
+    some_project = await project(
+        logged_user, workbench=fake_2connected_jupyterlabs_workbench
+    )
+
+    # create the corresponding comp_task entries for the project workbench
+    pipeline(project_id=f"{some_project.uuid}")
+    tasks = [
+        comp_task(
+            project_id=f"{some_project.uuid}",
+            node_id=node_id,
+            outputs=node_data.get("outputs", {}),
+            node_class=NodeClass.INTERACTIVE
+            if "dynamic" in node_data["key"]
+            else NodeClass.COMPUTATIONAL,
+            inputs=node_data.get("inputs", InputsDict()),
+        )
+        for node_id, node_data in fake_2connected_jupyterlabs_workbench.items()
+    ]
+    assert len(tasks) == 2, "Expected two tasks for the two JupyterLab nodes"
+    first_jupyter_task = tasks[0]
+    second_jupyter_task = tasks[1]
+    assert len(second_jupyter_task["inputs"]) > 0, (
+        "Expected inputs for the second JupyterLab task"
+    )
+    number_of_inputs_linked = len(second_jupyter_task["inputs"])
+
+    # simulate a concurrent change in all the outputs of first jupyterlab
+    async def _update_first_jupyter_task_output(
+        port_index: int, data: dict[str, Any]
+    ) -> None:
+        with log_context(logging.INFO, msg=f"Updating output {port_index + 1}"):
+            async with sqlalchemy_async_engine.begin() as conn:
+                # For JSON columns, we need to use jsonb_set or fetch-modify-update
+                # Since it's JSON (not JSONB), let's use the safer fetch-modify approach
+                # Use SELECT FOR UPDATE to lock the row for concurrent access
+                result = await conn.execute(
+                    comp_tasks.select()
+                    .with_only_columns([comp_tasks.c.outputs])
+                    .where(comp_tasks.c.task_id == first_jupyter_task["task_id"])
+                    .with_for_update()
+                )
+                row = result.first()
+                current_outputs = row[0] if row and row[0] else {}
+
+                # Update/add the new key while preserving existing keys
+                current_outputs[f"output_{port_index + 1}"] = data
+
+                # Write back the updated outputs
+                await conn.execute(
+                    comp_tasks.update()
+                    .values(outputs=current_outputs)
+                    .where(comp_tasks.c.task_id == first_jupyter_task["task_id"])
+                )
+
+    # await asyncio.gather(
+    #     *(
+    #         _update_first_jupyter_task_output(i, {"data": i})
+    #         for i in range(number_of_inputs_linked)
+    #     )
+    # )
+
+    @delayed_start(timedelta(seconds=2))
+    async def _change_outputs_sequentially(sleep: float = 0.1) -> None:
+        """
+        Sequentially updates the outputs of the second JupyterLab task to trigger the dynamic service RPC.
+        """
+        for i in range(number_of_inputs_linked):
+            await _update_first_jupyter_task_output(i, {"data": i})
+            await asyncio.sleep(sleep)
+
+    # this runs in a task
+    sequential_task = asyncio.create_task(_change_outputs_sequentially(5))
+    assert sequential_task is not None, "Failed to create the sequential task"
+
+    async def _check_retrieve_rpc_called(expected_ports_retrieved: int) -> None:
+        async for attempt in AsyncRetrying(
+            stop=stop_after_delay(60),
+            wait=wait_fixed(1),
+            retry=retry_if_exception_type(AssertionError),
+            reraise=True,
+        ):
+            with attempt:  # noqa: SIM117
+                with log_context(
+                    logging.INFO,
+                    msg=f"Checking if dynamic service retrieve RPC was called and "
+                    f"all expected ports were retrieved {expected_ports_retrieved} "
+                    f"times, {attempt.retry_state.retry_object.statistics}",
+                ) as log_ctx:
+                    if mock_dynamic_service_rpc.call_count > 0:
+                        log_ctx.logger.info(
+                            "call arguments: %s",
+                            mock_dynamic_service_rpc.call_args_list,
+                        )
+                    # Assert that the dynamic service RPC was called
+                    assert mock_dynamic_service_rpc.call_count > 0, (
+                        "Dynamic service retrieve RPC was not called"
+                    )
+                    # now get we check which ports were retrieved, we expect all of them
+                    all_ports = set()
+                    for call in mock_dynamic_service_rpc.call_args_list:
+                        retrieved_ports = call[1]["port_keys"]
+                        all_ports.update(retrieved_ports)
+                    assert len(all_ports) == expected_ports_retrieved, (
+                        f"Expected {expected_ports_retrieved} ports to be retrieved, "
+                        f"but got {len(all_ports)}: {all_ports}"
+                    )
+                    log_ctx.logger.info(
+                        "Dynamic service retrieve RPC was called with all expected ports!"
+                    )
+
+    await _check_for_stability(_check_retrieve_rpc_called, number_of_inputs_linked)
+
+    assert sequential_task.done(), "Sequential task did not complete"
+    assert not sequential_task.cancelled(), "Sequential task was cancelled unexpectedly"
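
Both helper coroutines in this test lean on tenacity's `AsyncRetrying` loop. The same pattern in isolation, for reference (a standalone sketch, not code from the repo):

    import asyncio

    from tenacity import (
        AsyncRetrying,
        retry_if_exception_type,
        stop_after_attempt,
        wait_fixed,
    )


    async def main() -> None:
        attempts = 0

        async def flaky_check() -> None:
            # stands in for _check_retrieve_rpc_called: fails until a condition holds
            nonlocal attempts
            attempts += 1
            assert attempts >= 3, f"not ready yet (attempt {attempts})"

        # re-run the check until it passes, waiting 1s between attempts;
        # reraise the last AssertionError if it still fails on attempt 5
        async for attempt in AsyncRetrying(
            stop=stop_after_attempt(5),
            wait=wait_fixed(1),
            retry=retry_if_exception_type(AssertionError),
            reraise=True,
        ):
            with attempt:
                await flaky_check()
        print(f"check passed after {attempts} attempts")


    asyncio.run(main())

Note that `_check_for_stability` calls `retry_if_exception_type()` with no argument, which defaults to retrying on any `Exception`, not only `AssertionError`.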
