1313from unittest import mock
1414
1515import pytest
16+ import simcore_service_webserver
17+ import simcore_service_webserver .db_listener
18+ import simcore_service_webserver .db_listener ._db_comp_tasks_listening_task
1619from aiohttp .test_utils import TestClient
1720from faker import Faker
1821from models_library .projects import ProjectAtDB
22+ from pytest_mock import MockType
1923from pytest_mock .plugin import MockerFixture
2024from pytest_simcore .helpers .webserver_login import UserInfoDict
2125from simcore_postgres_database .models .comp_pipeline import StateType
@@ -75,6 +79,16 @@ async def with_started_listening_task(client: TestClient) -> AsyncIterator:
7579 yield
7680
7781
82+ @pytest .fixture
83+ async def spied_get_changed_comp_task_row (
84+ mocker : MockerFixture ,
85+ ) -> MockType :
86+ return mocker .spy (
87+ simcore_service_webserver .db_listener ._db_comp_tasks_listening_task , # noqa: SLF001
88+ "_get_changed_comp_task_row" ,
89+ )
90+
91+
7892@dataclass (frozen = True , slots = True )
7993class _CompTaskChangeParams :
8094 update_values : dict [str , Any ]
@@ -151,6 +165,7 @@ async def _assert_listener_triggers(
151165async def test_db_listener_triggers_on_event_with_multiple_tasks (
152166 sqlalchemy_async_engine : AsyncEngine ,
153167 mock_project_subsystem : dict [str , mock .Mock ],
168+ spied_get_changed_comp_task_row : MockType ,
154169 logged_user : UserInfoDict ,
155170 project : Callable [..., Awaitable [ProjectAtDB ]],
156171 pipeline : Callable [..., dict [str , Any ]],
@@ -159,6 +174,7 @@ async def test_db_listener_triggers_on_event_with_multiple_tasks(
159174 params : _CompTaskChangeParams ,
160175 task_class : NodeClass ,
161176 faker : Faker ,
177+ mocker : MockerFixture ,
162178):
163179 some_project = await project (logged_user )
164180 pipeline (project_id = f"{ some_project .uuid } " )
@@ -173,10 +189,21 @@ async def test_db_listener_triggers_on_event_with_multiple_tasks(
173189 for _ in range (3 )
174190 ]
175191 random_task_to_update = tasks [secrets .randbelow (len (tasks ))]
192+ updated_task_id = random_task_to_update ["task_id" ]
193+
176194 async with sqlalchemy_async_engine .begin () as conn :
177195 await conn .execute (
178196 comp_tasks .update ()
179197 .values (** params .update_values )
180- .where (comp_tasks .c .task_id == random_task_to_update [ "task_id" ] )
198+ .where (comp_tasks .c .task_id == updated_task_id )
181199 )
182200 await _assert_listener_triggers (mock_project_subsystem , params .expected_calls )
201+
202+ # Assert the spy was called with the correct task_id
203+ if params .expected_calls :
204+ assert any (
205+ call .args [1 ] == updated_task_id
206+ for call in spied_get_changed_comp_task_row .call_args_list
207+ ), f"_get_changed_comp_task_row was not called with task_id={ updated_task_id } . Calls: { spy_get_changed .call_args_list } "
208+ else :
209+ spied_get_changed_comp_task_row .assert_not_called ()
0 commit comments