diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
index ebd8c6030a65..a5bb20c31e42 100644
--- a/.github/copilot-instructions.md
+++ b/.github/copilot-instructions.md
@@ -23,6 +23,8 @@ This document provides guidelines and best practices for using GitHub Copilot in
 - ensure we use `fastapi` >0.100 compatible code
 - use f-string formatting
 - Only add comments in function if strictly necessary
+- use relative imports
+- imports should be at the top of the file
 
 ### Json serialization
 
diff --git a/packages/postgres-database/src/simcore_postgres_database/migration/versions/278daef7e99d_remove_whole_row_in_payload.py b/packages/postgres-database/src/simcore_postgres_database/migration/versions/278daef7e99d_remove_whole_row_in_payload.py
new file mode 100644
index 000000000000..bd8f730a4b26
--- /dev/null
+++ b/packages/postgres-database/src/simcore_postgres_database/migration/versions/278daef7e99d_remove_whole_row_in_payload.py
@@ -0,0 +1,134 @@
+"""remove whole row in payload
+
+Revision ID: 278daef7e99d
+Revises: 4e7d8719855b
+Create Date: 2025-05-22 21:22:11.084001+00:00
+
+"""
+
+from typing import Final
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "278daef7e99d"
+down_revision = "4e7d8719855b"
+branch_labels = None
+depends_on = None
+
+DB_PROCEDURE_NAME: Final[str] = "notify_comp_tasks_changed"
+DB_TRIGGER_NAME: Final[str] = f"{DB_PROCEDURE_NAME}_event"
+DB_CHANNEL_NAME: Final[str] = "comp_tasks_output_events"
+
+
+def upgrade():
+    drop_trigger = sa.DDL(
+        f"""
+DROP TRIGGER IF EXISTS {DB_TRIGGER_NAME} on comp_tasks;
+"""
+    )
+
+    task_output_changed_procedure = sa.DDL(
+        f"""
+CREATE OR REPLACE FUNCTION {DB_PROCEDURE_NAME}() RETURNS TRIGGER AS $$
+    DECLARE
+        record RECORD;
+        payload JSON;
+        changes JSONB;
+    BEGIN
+        IF (TG_OP = 'DELETE') THEN
+            record = OLD;
+        ELSE
+            record = NEW;
+        END IF;
+
+        SELECT jsonb_agg(pre.key ORDER BY pre.key) INTO changes
+        FROM jsonb_each(to_jsonb(OLD)) AS pre, jsonb_each(to_jsonb(NEW)) AS post
+        WHERE pre.key = post.key AND pre.value IS DISTINCT FROM post.value;
+
+        payload = json_build_object(
+            'table', TG_TABLE_NAME,
+            'changes', changes,
+            'action', TG_OP,
+            'task_id', record.task_id,
+            'project_id', record.project_id,
+            'node_id', record.node_id
+        );
+
+        PERFORM pg_notify('{DB_CHANNEL_NAME}', payload::text);
+
+        RETURN NULL;
+    END;
+$$ LANGUAGE plpgsql;
+"""
+    )
+
+    task_output_changed_trigger = sa.DDL(
+        f"""
+DROP TRIGGER IF EXISTS {DB_TRIGGER_NAME} on comp_tasks;
+CREATE TRIGGER {DB_TRIGGER_NAME}
+AFTER UPDATE OF outputs,state ON comp_tasks
+    FOR EACH ROW
+    WHEN ((OLD.outputs::jsonb IS DISTINCT FROM NEW.outputs::jsonb OR OLD.state IS DISTINCT FROM NEW.state))
+    EXECUTE PROCEDURE {DB_PROCEDURE_NAME}();
+"""
+    )
+
+    op.execute(drop_trigger)
+    op.execute(task_output_changed_procedure)
+    op.execute(task_output_changed_trigger)
+
+
+def downgrade():
+    drop_trigger = sa.DDL(
+        f"""
+DROP TRIGGER IF EXISTS {DB_TRIGGER_NAME} on comp_tasks;
+"""
+    )
+
+    task_output_changed_procedure = sa.DDL(
+        f"""
+CREATE OR REPLACE FUNCTION {DB_PROCEDURE_NAME}() RETURNS TRIGGER AS $$
+    DECLARE
+        record RECORD;
+        payload JSON;
+        changes JSONB;
+    BEGIN
+        IF (TG_OP = 'DELETE') THEN
+            record = OLD;
+        ELSE
+            record = NEW;
+        END IF;
+
+        SELECT jsonb_agg(pre.key ORDER BY pre.key) INTO changes
+        FROM jsonb_each(to_jsonb(OLD)) AS pre, jsonb_each(to_jsonb(NEW)) AS post
+        WHERE pre.key = post.key AND pre.value IS DISTINCT FROM post.value;
+
+        payload = json_build_object('table', TG_TABLE_NAME,
+                                    'changes', changes,
+                                    'action', TG_OP,
+                                    'data', row_to_json(record));
+
+        PERFORM pg_notify('{DB_CHANNEL_NAME}', payload::text);
+
+        RETURN NULL;
+    END;
+$$ LANGUAGE plpgsql;
+"""
+    )
+
+    task_output_changed_trigger = sa.DDL(
+        f"""
+DROP TRIGGER IF EXISTS {DB_TRIGGER_NAME} on comp_tasks;
+CREATE TRIGGER {DB_TRIGGER_NAME}
+AFTER UPDATE OF outputs,state ON comp_tasks
+    FOR EACH ROW
+    WHEN ((OLD.outputs::jsonb IS DISTINCT FROM NEW.outputs::jsonb OR OLD.state IS DISTINCT FROM NEW.state))
+    EXECUTE PROCEDURE {DB_PROCEDURE_NAME}();
+"""
+    )
+
+    op.execute(drop_trigger)
+    op.execute(task_output_changed_procedure)
+    op.execute(task_output_changed_trigger)
diff --git a/packages/postgres-database/src/simcore_postgres_database/models/comp_tasks.py b/packages/postgres-database/src/simcore_postgres_database/models/comp_tasks.py
index 15c3ddbacd1f..3af09bfaa01a 100644
--- a/packages/postgres-database/src/simcore_postgres_database/models/comp_tasks.py
+++ b/packages/postgres-database/src/simcore_postgres_database/models/comp_tasks.py
@@ -152,10 +152,14 @@ class NodeClass(enum.Enum):
         FROM jsonb_each(to_jsonb(OLD)) AS pre, jsonb_each(to_jsonb(NEW)) AS post
         WHERE pre.key = post.key AND pre.value IS DISTINCT FROM post.value;
 
-        payload = json_build_object('table', TG_TABLE_NAME,
-                                    'changes', changes,
-                                    'action', TG_OP,
-                                    'data', row_to_json(record));
+        payload = json_build_object(
+            'table', TG_TABLE_NAME,
+            'changes', changes,
+            'action', TG_OP,
+            'task_id', record.task_id,
+            'project_id', record.project_id,
+            'node_id', record.node_id
+        );
 
         PERFORM pg_notify('{DB_CHANNEL_NAME}', payload::text);
 
diff --git a/packages/postgres-database/tests/test_comp_tasks.py b/packages/postgres-database/tests/test_comp_tasks.py
index 490c0cebc986..4759e074dc93 100644
--- a/packages/postgres-database/tests/test_comp_tasks.py
+++ b/packages/postgres-database/tests/test_comp_tasks.py
@@ -9,7 +9,6 @@
 
 import pytest
 from aiopg.sa.engine import Engine, SAConnection
-from aiopg.sa.result import RowProxy
 from simcore_postgres_database.models.comp_pipeline import StateType
 from simcore_postgres_database.models.comp_tasks import (
     DB_CHANNEL_NAME,
@@ -20,7 +19,7 @@
 
 
 @pytest.fixture()
-async def db_connection(aiopg_engine: Engine) -> SAConnection:
+async def db_connection(aiopg_engine: Engine) -> AsyncIterator[SAConnection]:
     async with aiopg_engine.acquire() as conn:
         yield conn
 
@@ -31,6 +30,7 @@ async def db_notification_queue(
 ) -> AsyncIterator[asyncio.Queue]:
     listen_query = f"LISTEN {DB_CHANNEL_NAME};"
     await db_connection.execute(listen_query)
+    assert db_connection.connection
     notifications_queue: asyncio.Queue = db_connection.connection.notifies
     assert notifications_queue.empty()
     yield notifications_queue
@@ -51,7 +51,8 @@ async def task(
         .values(outputs=json.dumps({}), node_class=task_class)
         .returning(literal_column("*"))
     )
-    row: RowProxy = await result.fetchone()
+    row = await result.fetchone()
+    assert row
     task = dict(row)
 
     assert (
@@ -73,8 +74,15 @@ async def _assert_notification_queue_status(
         assert msg, "notification msg from postgres is empty!"
         task_data = json.loads(msg.payload)
-
-        for k in ["table", "changes", "action", "data"]:
+        expected_keys = [
+            "task_id",
+            "project_id",
+            "node_id",
+            "changes",
+            "action",
+            "table",
+        ]
+        for k in expected_keys:
             assert k in task_data, f"invalid structure, expected [{k}] in {task_data}"
 
         tasks.append(task_data)
 
@@ -110,9 +118,15 @@ async def test_listen_query(
     )
     tasks = await _assert_notification_queue_status(db_notification_queue, 1)
     assert tasks[0]["changes"] == ["modified", "outputs", "state"]
+    assert tasks[0]["action"] == "UPDATE"
+    assert tasks[0]["table"] == "comp_tasks"
+    assert tasks[0]["task_id"] == task["task_id"]
+    assert tasks[0]["project_id"] == task["project_id"]
+    assert tasks[0]["node_id"] == task["node_id"]
+
     assert (
-        tasks[0]["data"]["outputs"] == updated_output
-    ), f"the data received from the database is {tasks[0]}, expected new output is {updated_output}"
+        "data" not in tasks[0]
+    ), "data is not expected in the notification payload anymore"
 
     # setting the exact same data twice triggers only ONCE
     updated_output = {"some new stuff": "it is newer"}
@@ -120,10 +134,11 @@
     await _update_comp_task_with(db_connection, task, outputs=updated_output)
     await _update_comp_task_with(db_connection, task, outputs=updated_output)
     tasks = await _assert_notification_queue_status(db_notification_queue, 1)
     assert tasks[0]["changes"] == ["modified", "outputs"]
-    assert (
-        tasks[0]["data"]["outputs"] == updated_output
-    ), f"the data received from the database is {tasks[0]}, expected new output is {updated_output}"
-
+    assert tasks[0]["action"] == "UPDATE"
+    assert tasks[0]["table"] == "comp_tasks"
+    assert tasks[0]["task_id"] == task["task_id"]
+    assert tasks[0]["project_id"] == task["project_id"]
+    assert tasks[0]["node_id"] == task["node_id"]
     # updating a number of times with different stuff comes out in FIFO order
     NUM_CALLS = 20
     update_outputs = []
@@ -135,7 +150,10 @@
     tasks = await _assert_notification_queue_status(db_notification_queue, NUM_CALLS)
 
     for n, output in enumerate(update_outputs):
+        assert output
         assert tasks[n]["changes"] == ["modified", "outputs"]
-        assert (
-            tasks[n]["data"]["outputs"] == output
-        ), f"the data received from the database is {tasks[n]}, expected new output is {output}"
+        assert tasks[n]["action"] == "UPDATE"
+        assert tasks[n]["table"] == "comp_tasks"
+        assert tasks[n]["task_id"] == task["task_id"]
+        assert tasks[n]["project_id"] == task["project_id"]
+        assert tasks[n]["node_id"] == task["node_id"]
diff --git a/packages/pytest-simcore/src/pytest_simcore/db_entries_mocks.py b/packages/pytest-simcore/src/pytest_simcore/db_entries_mocks.py
index 6159e3d72202..15f28daf3162 100644
--- a/packages/pytest-simcore/src/pytest_simcore/db_entries_mocks.py
+++ b/packages/pytest-simcore/src/pytest_simcore/db_entries_mocks.py
@@ -10,9 +10,10 @@
 import pytest
 import sqlalchemy as sa
 from faker import Faker
-from models_library.projects import ProjectAtDB
+from models_library.projects import ProjectAtDB, ProjectID
 from models_library.projects_nodes_io import NodeID
 from simcore_postgres_database.models.comp_pipeline import StateType, comp_pipeline
+from simcore_postgres_database.models.comp_tasks import comp_tasks
 from simcore_postgres_database.models.projects import ProjectType, projects
 from simcore_postgres_database.models.users import UserRole, UserStatus, users
 from simcore_postgres_database.utils_projects_nodes import (
@@ -142,9 +143,8 @@ def creator(**pipeline_kwargs) -> dict[str, Any]:
             .values(**pipeline_config)
             .returning(sa.literal_column("*"))
         )
-        new_pipeline = result.first()
-        assert new_pipeline
-        new_pipeline = dict(new_pipeline)
+        row = result.one()
+        new_pipeline = row._asdict()
         created_pipeline_ids.append(new_pipeline["project_id"])
         return new_pipeline
 
@@ -157,3 +157,29 @@ def creator(**pipeline_kwargs) -> dict[str, Any]:
             comp_pipeline.c.project_id.in_(created_pipeline_ids)
         )
     )
+
+
+@pytest.fixture
+def comp_task(postgres_db: sa.engine.Engine) -> Iterator[Callable[..., dict[str, Any]]]:
+    created_task_ids: list[int] = []
+
+    def creator(project_id: ProjectID, **task_kwargs) -> dict[str, Any]:
+        task_config = {"project_id": f"{project_id}"} | task_kwargs
+        with postgres_db.connect() as conn:
+            result = conn.execute(
+                comp_tasks.insert()
+                .values(**task_config)
+                .returning(sa.literal_column("*"))
+            )
+            row = result.one()
+            new_task = row._asdict()
+            created_task_ids.append(new_task["task_id"])
+        return new_task
+
+    yield creator
+
+    # cleanup
+    with postgres_db.connect() as conn:
+        conn.execute(
+            comp_tasks.delete().where(comp_tasks.c.task_id.in_(created_task_ids))
+        )
diff --git a/services/web/server/src/simcore_service_webserver/db_listener/_db_comp_tasks_listening_task.py b/services/web/server/src/simcore_service_webserver/db_listener/_db_comp_tasks_listening_task.py
index ea6ee0c2b62b..2c72213b094a 100644
--- a/services/web/server/src/simcore_service_webserver/db_listener/_db_comp_tasks_listening_task.py
+++ b/services/web/server/src/simcore_service_webserver/db_listener/_db_comp_tasks_listening_task.py
@@ -4,36 +4,39 @@
 """
 
 import asyncio
+import datetime
 import logging
 from collections.abc import AsyncIterator
-from contextlib import suppress
-from dataclasses import dataclass
-from typing import Final, NoReturn
+from typing import Final, NoReturn, cast
 
 from aiohttp import web
-from aiopg.sa import Engine
 from aiopg.sa.connection import SAConnection
-from common_library.json_serialization import json_loads
+from aiopg.sa.result import RowProxy
 from models_library.errors import ErrorDict
 from models_library.projects import ProjectID
 from models_library.projects_nodes_io import NodeID
 from models_library.projects_state import RunningState
 from pydantic.types import PositiveInt
+from servicelib.background_task import periodic_task
+from simcore_postgres_database.models.comp_tasks import comp_tasks
 from simcore_postgres_database.webserver_models import DB_CHANNEL_NAME, projects
 from sqlalchemy.sql import select
 
 from ..db.plugin import get_database_engine
 from ..projects import _projects_service, exceptions
 from ..projects.nodes_utils import update_node_outputs
+from ._models import CompTaskNotificationPayload
 from ._utils import convert_state_from_db
 
 _LISTENING_TASK_BASE_SLEEPING_TIME_S: Final[int] = 1
 _logger = logging.getLogger(__name__)
 
 
-async def _get_project_owner(conn: SAConnection, project_uuid: str) -> PositiveInt:
+async def _get_project_owner(
+    conn: SAConnection, project_uuid: ProjectID
+) -> PositiveInt:
     the_project_owner: PositiveInt | None = await conn.scalar(
-        select(projects.c.prj_owner).where(projects.c.uuid == project_uuid)
+        select(projects.c.prj_owner).where(projects.c.uuid == f"{project_uuid}")
     )
     if not the_project_owner:
         raise exceptions.ProjectOwnerNotFoundError(project_uuid=project_uuid)
@@ -58,61 +61,49 @@ async def _update_project_state(
     await _projects_service.notify_project_state_update(app, project)
 
 
-@dataclass(frozen=True)
-class _CompTaskNotificationPayload:
-    action: str
-    data: dict
-    changes: dict
-    table: str
+async def _get_changed_comp_task_row(
+    conn: SAConnection, task_id: PositiveInt
+) -> RowProxy | None:
+    result = await conn.execute(
+        select(comp_tasks).where(comp_tasks.c.task_id == task_id)
+    )
+    return cast(RowProxy | None, await result.fetchone())
 
 
 async def _handle_db_notification(
-    app: web.Application, payload: _CompTaskNotificationPayload, conn: SAConnection
+    app: web.Application, payload: CompTaskNotificationPayload, conn: SAConnection
 ) -> None:
-    task_data = payload.data
-    task_changes = payload.changes
-
-    project_uuid = task_data.get("project_id", None)
-    node_uuid = task_data.get("node_id", None)
-    if any(x is None for x in [project_uuid, node_uuid]):
-        _logger.warning(
-            "comp_tasks row is corrupted. TIP: please check DB entry containing '%s'",
-            f"{task_data=}",
-        )
-        return
-
-    assert project_uuid  # nosec
-    assert node_uuid  # nosec
-
     try:
-        # NOTE: we need someone with the rights to modify that project. the owner is one.
-        # find the user(s) linked to that project
-        the_project_owner = await _get_project_owner(conn, project_uuid)
-
-        if any(f in task_changes for f in ["outputs", "run_hash"]):
-            new_outputs = task_data.get("outputs", {})
-            new_run_hash = task_data.get("run_hash", None)
+        the_project_owner = await _get_project_owner(conn, payload.project_id)
+        changed_row = await _get_changed_comp_task_row(conn, payload.task_id)
+        if not changed_row:
+            _logger.warning(
+                "No comp_tasks row found for project_id=%s node_id=%s",
+                payload.project_id,
+                payload.node_id,
+            )
+            return
+        if any(f in payload.changes for f in ["outputs", "run_hash"]):
             await update_node_outputs(
                 app,
                 the_project_owner,
-                ProjectID(project_uuid),
-                NodeID(node_uuid),
-                new_outputs,
-                new_run_hash,
-                node_errors=task_data.get("errors", None),
+                payload.project_id,
+                payload.node_id,
+                changed_row.outputs,
+                changed_row.run_hash,
+                node_errors=changed_row.errors,
                 ui_changed_keys=None,
             )
 
-        if "state" in task_changes:
-            new_state = convert_state_from_db(task_data["state"])
+        if "state" in payload.changes and (changed_row.state is not None):
             await _update_project_state(
                 app,
                 the_project_owner,
-                ProjectID(project_uuid),
-                NodeID(node_uuid),
-                new_state,
-                node_errors=task_data.get("errors", None),
+                payload.project_id,
+                payload.node_id,
+                convert_state_from_db(changed_row.state),
+                node_errors=changed_row.errors,
             )
 
     except exceptions.ProjectNotFoundError as exc:
@@ -133,9 +124,9 @@ async def _handle_db_notification(
     )
 
 
-async def _listen(app: web.Application, db_engine: Engine) -> NoReturn:
+async def _listen(app: web.Application) -> NoReturn:
     listen_query = f"LISTEN {DB_CHANNEL_NAME};"
-
+    db_engine = get_database_engine(app)
     async with db_engine.acquire() as conn:
         assert conn.connection  # nosec
         await conn.execute(listen_query)
@@ -151,45 +142,18 @@
                 await asyncio.sleep(_LISTENING_TASK_BASE_SLEEPING_TIME_S)
                 continue
             notification = conn.connection.notifies.get_nowait()
-            # get the data and the info on what changed
-            payload = _CompTaskNotificationPayload(**json_loads(notification.payload))
+            payload = CompTaskNotificationPayload.model_validate_json(
+                notification.payload
+            )
             _logger.debug("received update from database: %s", f"{payload=}")
             await _handle_db_notification(app, payload, conn)
 
 
-async def _comp_tasks_listening_task(app: web.Application) -> None:
-    _logger.info("starting comp_task db listening task...")
-    while True:
-        try:
-            # create a special connection here
-            db_engine = get_database_engine(app)
-            _logger.info("listening to comp_task events...")
-            await _listen(app, db_engine)
-        except asyncio.CancelledError:  # noqa: PERF203
-            # we are closing the app..
-            _logger.info("cancelled comp_tasks events")
-            raise
-        except Exception:  # pylint: disable=broad-except
-            _logger.exception(
-                "caught unhandled comp_task db listening task exception, restarting...",
-            )
-            # wait a bit and try restart the task
-            await asyncio.sleep(3)
-
-
 async def create_comp_tasks_listening_task(app: web.Application) -> AsyncIterator[None]:
-    task = asyncio.create_task(
-        _comp_tasks_listening_task(app), name="computation db listener"
-    )
-    _logger.debug("comp_tasks db listening task created %s", f"{task=}")
-
-    yield
-
-    _logger.debug("cancelling comp_tasks db listening %s task...", f"{task=}")
-    task.cancel()
-    _logger.debug("waiting for comp_tasks db listening %s to stop", f"{task=}")
-    with suppress(asyncio.CancelledError):
-        await task
-    _logger.debug(
-        "waiting for comp_tasks db listening %s to stop completed", f"{task=}"
-    )
+    async with periodic_task(
+        _listen,
+        interval=datetime.timedelta(seconds=_LISTENING_TASK_BASE_SLEEPING_TIME_S),
+        task_name="computation db listener",
+        app=app,
+    ):
+        yield
diff --git a/services/web/server/src/simcore_service_webserver/db_listener/_models.py b/services/web/server/src/simcore_service_webserver/db_listener/_models.py
new file mode 100644
index 000000000000..e1f830820946
--- /dev/null
+++ b/services/web/server/src/simcore_service_webserver/db_listener/_models.py
@@ -0,0 +1,16 @@
+from typing import TypeAlias
+
+from models_library.projects import ProjectID
+from models_library.projects_nodes_io import NodeID
+from pydantic import BaseModel
+
+_DB_KEY: TypeAlias = str
+
+
+class CompTaskNotificationPayload(BaseModel):
+    action: str
+    changes: list[_DB_KEY]
+    table: str
+    task_id: int
+    project_id: ProjectID
+    node_id: NodeID
diff --git a/services/web/server/src/simcore_service_webserver/db_listener/plugin.py b/services/web/server/src/simcore_service_webserver/db_listener/plugin.py
index f047491d3a40..423e307f3de2 100644
--- a/services/web/server/src/simcore_service_webserver/db_listener/plugin.py
+++ b/services/web/server/src/simcore_service_webserver/db_listener/plugin.py
@@ -10,7 +10,6 @@
 
 from ..db.plugin import setup_db
 from ..projects._projects_repository_legacy import setup_projects_db
-from ..rabbitmq import setup_rabbitmq
 from ..socketio.plugin import setup_socketio
 from ._db_comp_tasks_listening_task import create_comp_tasks_listening_task
 
@@ -24,7 +23,6 @@
     logger=_logger,
 )
 def setup_db_listener(app: web.Application):
-    setup_rabbitmq(app)
     setup_socketio(app)
     setup_projects_db(app)
     # Creates a task to listen to comp_task pg-db's table events
diff --git a/services/web/server/tests/unit/with_dbs/04/notifications/test_notifications__db_comp_tasks_listening_task.py b/services/web/server/tests/unit/with_dbs/04/notifications/test_notifications__db_comp_tasks_listening_task.py
index c17c5aa1aa69..1d787e86b937 100644
--- a/services/web/server/tests/unit/with_dbs/04/notifications/test_notifications__db_comp_tasks_listening_task.py
+++ b/services/web/server/tests/unit/with_dbs/04/notifications/test_notifications__db_comp_tasks_listening_task.py
@@ -3,28 +3,33 @@
 # pylint:disable=redefined-outer-name
 # pylint:disable=no-value-for-parameter
 # pylint:disable=too-many-arguments
+# pylint:disable=protected-access
 
 import json
 import logging
-from collections.abc import AsyncIterator, Awaitable, Callable, Iterator
+import secrets
+from collections.abc import AsyncIterator, Awaitable, Callable
+from dataclasses import dataclass
 from typing import Any
 from unittest import mock
 
-import aiopg.sa
 import pytest
-import sqlalchemy as sa
+import simcore_service_webserver
+import simcore_service_webserver.db_listener
+import simcore_service_webserver.db_listener._db_comp_tasks_listening_task
 from aiohttp.test_utils import TestClient
 from faker import Faker
-from models_library.projects import ProjectAtDB, ProjectID
+from models_library.projects import ProjectAtDB
+from pytest_mock import MockType
 from pytest_mock.plugin import MockerFixture
 from pytest_simcore.helpers.webserver_login import UserInfoDict
-from servicelib.aiohttp.application_keys import APP_AIOPG_ENGINE_KEY
 from simcore_postgres_database.models.comp_pipeline import StateType
 from simcore_postgres_database.models.comp_tasks import NodeClass, comp_tasks
 from simcore_postgres_database.models.users import UserRole
 from simcore_service_webserver.db_listener._db_comp_tasks_listening_task import (
     create_comp_tasks_listening_task,
 )
+from sqlalchemy.ext.asyncio import AsyncEngine
 from tenacity.asyncio import AsyncRetrying
 from tenacity.before_sleep import before_sleep_log
 from tenacity.retry import retry_if_exception_type
@@ -35,9 +40,7 @@
 
 
 @pytest.fixture
-async def mock_project_subsystem(
-    mocker: MockerFixture,
-) -> AsyncIterator[dict[str, mock.MagicMock]]:
+async def mock_project_subsystem(mocker: MockerFixture) -> dict[str, mock.Mock]:
     mocked_project_calls = {}
 
     mocked_project_calls["update_node_outputs"] = mocker.patch(
@@ -45,22 +48,32 @@
         return_value="",
     )
 
-    mocked_project_calls["_get_project_owner"] = mocker.patch(
-        "simcore_service_webserver.db_listener._db_comp_tasks_listening_task._get_project_owner",
-        return_value="",
+    mocked_project_calls["_update_project_state.update_project_node_state"] = (
+        mocker.patch(
+            "simcore_service_webserver.projects._projects_service.update_project_node_state",
+            autospec=True,
+        )
     )
-    mocked_project_calls["_update_project_state"] = mocker.patch(
-        "simcore_service_webserver.db_listener._db_comp_tasks_listening_task._update_project_state",
-        return_value="",
+
+    mocked_project_calls["_update_project_state.notify_project_node_update"] = (
+        mocker.patch(
+            "simcore_service_webserver.projects._projects_service.notify_project_node_update",
+            autospec=True,
+        )
+    )
+
+    mocked_project_calls["_update_project_state.notify_project_state_update"] = (
+        mocker.patch(
+            "simcore_service_webserver.projects._projects_service.notify_project_state_update",
+            autospec=True,
+        )
     )
 
     return mocked_project_calls
 
 
 @pytest.fixture
-async def comp_task_listening_task(
-    mock_project_subsystem: dict, client: TestClient
-) -> AsyncIterator:
+async def with_started_listening_task(client: TestClient) -> AsyncIterator:
     assert client.app
     async for _comp_task in create_comp_tasks_listening_task(client.app):
         # first call creates the task, second call cleans it
@@ -68,107 +81,130 @@
 
 
 @pytest.fixture
-def comp_task(
-    postgres_db: sa.engine.Engine,
-) -> Iterator[Callable[..., dict[str, Any]]]:
-    created_task_ids: list[int] = []
-
-    def creator(project_id: ProjectID, **task_kwargs) -> dict[str, Any]:
-        task_config = {"project_id": f"{project_id}"} | task_kwargs
-        with postgres_db.connect() as conn:
-            result = conn.execute(
-                comp_tasks.insert()
-                .values(**task_config)
-                .returning(sa.literal_column("*"))
-            )
-            new_task = result.first()
-            assert new_task
-            new_task = dict(new_task)
-            created_task_ids.append(new_task["task_id"])
-        return new_task
-
-    yield creator
-
-    # cleanup
-    with postgres_db.connect() as conn:
-        conn.execute(
-            comp_tasks.delete().where(comp_tasks.c.task_id.in_(created_task_ids))
-        )
+async def spied_get_changed_comp_task_row(
+    mocker: MockerFixture,
+) -> MockType:
+    return mocker.spy(
+        simcore_service_webserver.db_listener._db_comp_tasks_listening_task,  # noqa: SLF001
+        "_get_changed_comp_task_row",
+    )
+
+
+@dataclass(frozen=True, slots=True)
+class _CompTaskChangeParams:
+    update_values: dict[str, Any]
+    expected_calls: list[str]
+
+
+async def _assert_listener_triggers(
+    mock_project_subsystem: dict[str, mock.Mock], expected_calls: list[str]
+) -> None:
+    for call_name, mocked_call in mock_project_subsystem.items():
+        if call_name in expected_calls:
+            async for attempt in AsyncRetrying(
+                wait=wait_fixed(1),
+                stop=stop_after_delay(10),
+                retry=retry_if_exception_type(AssertionError),
+                before_sleep=before_sleep_log(logger, logging.INFO),
+                reraise=True,
+            ):
+                with attempt:
+                    mocked_call.assert_called_once()
+
+        else:
+            mocked_call.assert_not_called()
 
 
 @pytest.mark.parametrize(
     "task_class", [NodeClass.COMPUTATIONAL, NodeClass.INTERACTIVE, NodeClass.FRONTEND]
 )
 @pytest.mark.parametrize(
-    "update_values, expected_calls",
+    "params",
     [
         pytest.param(
-            {
-                "outputs": {"some new stuff": "it is new"},
-            },
-            ["_get_project_owner", "update_node_outputs"],
+            _CompTaskChangeParams(
+                {
+                    "outputs": {"some new stuff": "it is new"},
+                },
+                ["update_node_outputs"],
+            ),
             id="new output shall trigger",
         ),
         pytest.param(
-            {"state": StateType.ABORTED},
-            ["_get_project_owner", "_update_project_state"],
+            _CompTaskChangeParams(
+                {"state": StateType.ABORTED},
+                [
+                    "_update_project_state.update_project_node_state",
+                    "_update_project_state.notify_project_node_update",
+                    "_update_project_state.notify_project_state_update",
+                ],
+            ),
             id="new state shall trigger",
         ),
         pytest.param(
-            {"outputs": {"some new stuff": "it is new"}, "state": StateType.ABORTED},
-            ["_get_project_owner", "update_node_outputs", "_update_project_state"],
+            _CompTaskChangeParams(
+                {
+                    "outputs": {"some new stuff": "it is new"},
+                    "state": StateType.ABORTED,
+                },
+                [
+                    "update_node_outputs",
+                    "_update_project_state.update_project_node_state",
+                    "_update_project_state.notify_project_node_update",
+                    "_update_project_state.notify_project_state_update",
+                ],
+            ),
             id="new output and state shall double trigger",
         ),
         pytest.param(
-            {"inputs": {"should not trigger": "right?"}},
-            [],
+            _CompTaskChangeParams({"inputs": {"should not trigger": "right?"}}, []),
             id="no new output or state shall not trigger",
         ),
     ],
 )
 @pytest.mark.parametrize("user_role", [UserRole.USER])
-async def test_listen_comp_tasks_task(
-    mock_project_subsystem: dict,
+async def test_db_listener_triggers_on_event_with_multiple_tasks(
+    sqlalchemy_async_engine: AsyncEngine,
+    mock_project_subsystem: dict[str, mock.Mock],
+    spied_get_changed_comp_task_row: MockType,
     logged_user: UserInfoDict,
     project: Callable[..., Awaitable[ProjectAtDB]],
     pipeline: Callable[..., dict[str, Any]],
     comp_task: Callable[..., dict[str, Any]],
-    comp_task_listening_task: None,
-    client,
-    update_values: dict[str, Any],
-    expected_calls: list[str],
+    with_started_listening_task: None,
+    params: _CompTaskChangeParams,
    task_class: NodeClass,
     faker: Faker,
+    mocker: MockerFixture,
 ):
-    db_engine: aiopg.sa.Engine = client.app[APP_AIOPG_ENGINE_KEY]
     some_project = await project(logged_user)
     pipeline(project_id=f"{some_project.uuid}")
-    task = comp_task(
-        project_id=f"{some_project.uuid}",
-        node_id=faker.uuid4(),
-        outputs=json.dumps({}),
-        node_class=task_class,
-    )
-    async with db_engine.acquire() as conn:
-        # let's update some values
+    # Create 3 tasks with different node_ids
+    tasks = [
+        comp_task(
+            project_id=f"{some_project.uuid}",
+            node_id=faker.uuid4(),
+            outputs=json.dumps({}),
+            node_class=task_class,
+        )
+        for _ in range(3)
+    ]
+    random_task_to_update = tasks[secrets.randbelow(len(tasks))]
+    updated_task_id = random_task_to_update["task_id"]
+
+    async with sqlalchemy_async_engine.begin() as conn:
         await conn.execute(
             comp_tasks.update()
-            .values(**update_values)
-            .where(comp_tasks.c.task_id == task["task_id"])
+            .values(**params.update_values)
+            .where(comp_tasks.c.task_id == updated_task_id)
         )
-
-    # tests whether listener gets executed
-    for call_name, mocked_call in mock_project_subsystem.items():
-        if call_name in expected_calls:
-            async for attempt in AsyncRetrying(
-                wait=wait_fixed(1),
-                stop=stop_after_delay(10),
-                retry=retry_if_exception_type(AssertionError),
-                before_sleep=before_sleep_log(logger, logging.INFO),
-                reraise=True,
-            ):
-                with attempt:
-                    mocked_call.assert_awaited()
-
-        else:
-            mocked_call.assert_not_called()
+    await _assert_listener_triggers(mock_project_subsystem, params.expected_calls)
+
+    # Assert the spy was called with the correct task_id
+    if params.expected_calls:
+        assert any(
+            call.args[1] == updated_task_id
+            for call in spied_get_changed_comp_task_row.call_args_list
+        ), f"_get_changed_comp_task_row was not called with task_id={updated_task_id}. Calls: {spied_get_changed_comp_task_row.call_args_list}"
+    else:
+        spied_get_changed_comp_task_row.assert_not_called()
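A note on the trigger change above: PostgreSQL caps a `NOTIFY` payload at just under 8000 bytes by default, so embedding the entire row via `row_to_json(record)` could break notifications for tasks with large `outputs`. That is presumably why the payload now carries only identifiers plus the list of changed keys, with `_handle_db_notification` re-reading the row through `_get_changed_comp_task_row`. Below is a minimal, self-contained sketch of a consumer of the new contract using plain `aiopg` (the same driver the PR uses); the DSN is a placeholder and the final `print` stands in for real handling:

```python
import asyncio
import json

import aiopg


async def consume_one_notification(dsn: str) -> None:
    async with aiopg.connect(dsn) as conn:
        async with conn.cursor() as cur:
            # DB_CHANNEL_NAME from the migration above
            await cur.execute("LISTEN comp_tasks_output_events;")
            msg = await conn.notifies.get()  # aiopg queues NOTIFY messages here
            payload = json.loads(msg.payload)
            # the payload holds only identifiers + changed keys, not the row,
            # so re-read the row (this mirrors _get_changed_comp_task_row)
            await cur.execute(
                "SELECT * FROM comp_tasks WHERE task_id = %s",
                (payload["task_id"],),
            )
            row = await cur.fetchone()
            print(payload["action"], payload["changes"], row)


if __name__ == "__main__":
    # placeholder DSN for illustration only
    asyncio.run(consume_one_notification("dbname=simcoredb user=postgres"))
```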
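Why the updated tests expect `"modified"` in `changes` alongside `"outputs"` and `"state"`: the trigger diffs every column of `OLD` against `NEW`, and the `modified` timestamp column is touched on every update too. In Python terms, the `jsonb_each` comparison amounts to the sketch below (with the caveat that SQL's `IS DISTINCT FROM` is NULL-safe, which a plain `!=` only approximates):

```python
def changed_keys(old_row: dict, new_row: dict) -> list[str]:
    # keys present in both rows whose values differ, ordered like `ORDER BY pre.key`
    return sorted(
        k for k in old_row.keys() & new_row.keys() if old_row[k] != new_row[k]
    )


# an update that touches state also bumps the `modified` timestamp column
assert changed_keys(
    {"task_id": 42, "state": "PENDING", "outputs": {}, "modified": "t0"},
    {"task_id": 42, "state": "ABORTED", "outputs": {}, "modified": "t1"},
) == ["modified", "state"]
```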
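On the lifecycle rewrite in `_db_comp_tasks_listening_task.py`: the hand-rolled `while True` loop (restart after a crash, explicit `task.cancel()` on teardown) is replaced by `servicelib.background_task.periodic_task`. The sketch below is not servicelib's actual implementation, only a rough functional equivalent showing what the context manager provides to `_listen` — run, wait `interval` between attempts, retry after crashes, and cancel cleanly on exit:

```python
import asyncio
import contextlib
import datetime
from collections.abc import AsyncIterator, Awaitable, Callable


@contextlib.asynccontextmanager
async def periodic_task_sketch(
    func: Callable[..., Awaitable[None]],
    *,
    interval: datetime.timedelta,
    task_name: str,
    **kwargs,
) -> AsyncIterator[None]:
    async def _runner() -> None:
        while True:
            # crash -> retry; the real servicelib version also logs the error
            with contextlib.suppress(Exception):
                await func(**kwargs)
            await asyncio.sleep(interval.total_seconds())

    task = asyncio.create_task(_runner(), name=task_name)
    try:
        yield
    finally:
        task.cancel()  # replaces the old explicit cancel/suppress teardown
        with contextlib.suppress(asyncio.CancelledError):
            await task
```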
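Smaller but worth flagging: the fixtures move from `result.first()` + `assert` + `dict(row)` to SQLAlchemy 1.4's `result.one()` + `row._asdict()`, where `one()` raises `NoResultFound`/`MultipleResultsFound` instead of silently returning `None`. A sketch of the pattern, assuming a sync engine and the pre-2.0 connection behaviour the fixtures appear to rely on:

```python
from typing import Any

import sqlalchemy as sa


def insert_returning_row(
    engine: sa.engine.Engine, table: sa.Table, **values
) -> dict[str, Any]:
    with engine.connect() as conn:
        result = conn.execute(
            table.insert().values(**values).returning(sa.literal_column("*"))
        )
        row = result.one()  # raises if zero or more than one row came back
        return row._asdict()  # Row -> plain dict, replacing dict(row)
```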