 import asyncio
 import datetime
 from collections.abc import AsyncIterator, Awaitable, Callable
-from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Any, cast
@@ -195,13 +194,13 @@ async def _assert_start_pipeline(
     return runs[0], exp_published_tasks


-async def _assert_schedule_pipeline_PENDING(  # noqa: N802
+async def _assert_publish_in_dask_backend(
     sqlalchemy_async_engine: AsyncEngine,
     published_project: PublishedProject,
     published_tasks: list[CompTaskAtDB],
     mocked_dask_client: mock.MagicMock,
     scheduler: BaseCompScheduler,
-) -> list[CompTaskAtDB]:
+) -> tuple[list[CompTaskAtDB], dict[NodeID, Callable[[], None]]]:
     expected_pending_tasks = [
         published_tasks[1],
         published_tasks[3],
@@ -247,6 +246,8 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState]
     # tasks were sent to the backend
     assert published_project.project.prj_owner is not None
     assert isinstance(mocked_dask_client.send_computation_tasks, mock.Mock)
+    assert isinstance(mocked_dask_client.get_tasks_status, mock.Mock)
+    assert isinstance(mocked_dask_client.get_task_result, mock.Mock)
     mocked_dask_client.send_computation_tasks.assert_has_calls(
         calls=[
             mock.call(
@@ -262,6 +263,12 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState]
         ],
         any_order=True,
     )
+    task_to_callback_mapping = {
+        task.node_id: mocked_dask_client.send_computation_tasks.call_args_list[
+            i
+        ].kwargs["callback"]
+        for i, task in enumerate(expected_pending_tasks)
+    }
     mocked_dask_client.send_computation_tasks.reset_mock()
     mocked_dask_client.get_tasks_status.assert_not_called()
     mocked_dask_client.get_task_result.assert_not_called()
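
Aside: the new `task_to_callback_mapping` relies on `unittest.mock` recording every call to the mocked `send_computation_tasks`, including its keyword arguments, so the test can later fish out the `callback` each task was submitted with. A minimal, self-contained sketch of that capture pattern (the client and callbacks below are illustrative stand-ins, not the real Dask client API):

```python
# Sketch: pulling a keyword argument back out of a mock's recorded calls.
# `send_computation_tasks` and the callbacks here are stand-ins for illustration.
from unittest import mock

client = mock.MagicMock()
tasks = ["node-a", "node-b"]

# Production-style code passes a per-task callback to the (mocked) client.
for task in tasks:
    client.send_computation_tasks(task, callback=lambda t=task: print(f"{t} done"))

# Each recorded call exposes .kwargs (Python 3.8+), so the test can rebuild
# a node-id -> callback mapping and fire the callbacks later on demand.
task_to_callback = {
    task: client.send_computation_tasks.call_args_list[i].kwargs["callback"]
    for i, task in enumerate(tasks)
}
task_to_callback["node-a"]()  # prints "node-a done"
```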
@@ -298,7 +305,7 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState]
     )
     mocked_dask_client.get_tasks_status.reset_mock()
     mocked_dask_client.get_task_result.assert_not_called()
-    return expected_pending_tasks
+    return expected_pending_tasks, task_to_callback_mapping


 @pytest.fixture
@@ -432,7 +439,7 @@ async def test_proper_pipeline_is_scheduled( # noqa: PLR0915

     # -------------------------------------------------------------------------------
     # 1. first run will move comp_tasks to PENDING so the dask-worker can take them
-    expected_pending_tasks = await _assert_schedule_pipeline_PENDING(
+    expected_pending_tasks, _ = await _assert_publish_in_dask_backend(
         sqlalchemy_async_engine,
         published_project,
         expected_published_tasks,
@@ -895,7 +902,10 @@ async def with_started_project(
     #
     # 2. This runs the scheduler until the project is started and scheduled in the back-end
     #
-    expected_pending_tasks = await _assert_schedule_pipeline_PENDING(
+    (
+        expected_pending_tasks,
+        task_to_callback_mapping,
+    ) = await _assert_publish_in_dask_backend(
         sqlalchemy_async_engine,
         published_project,
         expected_published_tasks,
@@ -920,6 +930,9 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]
             for job_id in job_ids
         ]

+    assert isinstance(mocked_dask_client.get_tasks_status, mock.Mock)
+    assert isinstance(mocked_dask_client.send_computation_tasks, mock.Mock)
+    assert isinstance(mocked_dask_client.get_task_result, mock.Mock)
     mocked_dask_client.get_tasks_status.side_effect = _return_1st_task_running
     await scheduler_api.schedule_pipeline(
         user_id=run_in_db.user_id,
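
Aside: driving `get_tasks_status` through `side_effect` lets each scheduler pass observe a scripted set of task states. A hedged sketch of that idea with `unittest.mock.AsyncMock` (the fixture that actually builds `mocked_dask_client` is outside this diff, so the mock type and the string states below are assumptions):

```python
# Sketch: feeding scripted task states to an awaited mock via side_effect.
import asyncio
from unittest import mock


async def _return_1st_task_running(job_ids: list[str]) -> list[str]:
    # Pretend the first job is already running and the rest are still pending.
    return ["PENDING_OR_STARTED" if i == 0 else "PENDING" for i, _ in enumerate(job_ids)]


async def main() -> None:
    dask_client = mock.AsyncMock()
    dask_client.get_tasks_status.side_effect = _return_1st_task_running

    states = await dask_client.get_tasks_status(["job-1", "job-2", "job-3"])
    assert states == ["PENDING_OR_STARTED", "PENDING", "PENDING"]


asyncio.run(main())
```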
@@ -977,30 +990,32 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]
         iteration=run_in_db.iteration,
     )
     # comp_run, the comp_task switch to STARTED
-    await assert_comp_runs(
-        sqlalchemy_async_engine,
-        expected_total=1,
-        expected_state=RunningState.STARTED,
-        where_statement=and_(
-            comp_runs.c.user_id == published_project.project.prj_owner,
-            comp_runs.c.project_uuid == f"{published_project.project.uuid}",
-        ),
-    )
-    await assert_comp_tasks(
+    run_in_db = (
+        await assert_comp_runs(
+            sqlalchemy_async_engine,
+            expected_total=1,
+            expected_state=RunningState.STARTED,
+            where_statement=and_(
+                comp_runs.c.user_id == published_project.project.prj_owner,
+                comp_runs.c.project_uuid == f"{published_project.project.uuid}",
+            ),
+        )
+    )[0]
+    tasks_in_db = await assert_comp_tasks(
         sqlalchemy_async_engine,
         project_uuid=published_project.project.uuid,
         task_ids=[exp_started_task.node_id],
         expected_state=RunningState.STARTED,
         expected_progress=0,
     )
-    await assert_comp_tasks(
+    tasks_in_db += await assert_comp_tasks(
         sqlalchemy_async_engine,
         project_uuid=published_project.project.uuid,
         task_ids=[p.node_id for p in expected_pending_tasks],
         expected_state=RunningState.PENDING,
         expected_progress=None,
     )
-    await assert_comp_tasks(
+    tasks_in_db += await assert_comp_tasks(
         sqlalchemy_async_engine,
         project_uuid=published_project.project.uuid,
         task_ids=[p.node_id for p in expected_published_tasks],
@@ -1033,29 +1048,40 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]
     return RunningProject(
         published_project.project,
         published_project.pipeline,
-        published_project.tasks,
+        tasks_in_db,
         runs=run_in_db,
+        task_to_callback_mapping=task_to_callback_mapping,
+    )
+
+
+@pytest.fixture
+def mocked_worker_publisher(mocker: MockerFixture) -> mock.Mock:
+    return mocker.patch(
+        "simcore_service_director_v2.modules.comp_scheduler._scheduler_base.request_pipeline_scheduling",
+        autospec=True,
     )


 async def test_completed_task_triggers_new_scheduling_task(
-    with_disabled_scheduler_publisher: mock.Mock,
+    mocked_worker_publisher: mock.Mock,
     with_started_project: RunningProject,
-    initialized_app: FastAPI,
-    mocked_dask_client: mock.MagicMock,
-    scheduler_api: BaseCompScheduler,
-    sqlalchemy_async_engine: AsyncEngine,
-    mocker: MockerFixture,
 ):
     """When a pipeline job completes, the Dask backend provides a callback
     that runs in a separate thread. We use that callback to ask the
     director-v2 computational scheduler manager for a new scheduling pass.
     After fiddling in the distributed source code, here is a similar way to
     trigger that callback.
     """
-    with ThreadPoolExecutor(
-        max_workers=1, thread_name_prefix="pytest-callback-thread"
-    ) as executor:
-        ...
+    completed_node_id = with_started_project.tasks[0].node_id
+    callback = with_started_project.task_to_callback_mapping[completed_node_id]
+    await asyncio.to_thread(callback)
+
+    mocked_worker_publisher.assert_called_once_with(
+        mock.ANY,
+        mock.ANY,
+        user_id=with_started_project.runs.user_id,
+        project_id=with_started_project.runs.project_uuid,
+        iteration=with_started_project.runs.iteration,
+    )


 async def test_broken_pipeline_configuration_is_not_scheduled_and_aborted(
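
Aside: the rewritten test no longer needs a `ThreadPoolExecutor`. It picks the captured callback for one task, runs it off the event loop with `asyncio.to_thread`, and then verifies the patched publisher was asked to schedule again, using `mock.ANY` for the infrastructure arguments it does not care about. A small stand-alone sketch of that trigger-and-assert shape (the publisher and its arguments are placeholders, not the director-v2 internals):

```python
# Sketch: firing a backend-style done-callback from a worker thread and
# asserting on the resulting publish call. All names below are placeholders.
import asyncio
from unittest import mock


async def main() -> None:
    publisher = mock.Mock()

    def dask_done_callback() -> None:
        # In the real scheduler this would publish a "reschedule this pipeline" request.
        publisher("rabbitmq-client", "db-engine", user_id=1, project_id="uuid-1", iteration=1)

    # Simulate the Dask backend invoking the done-callback from its own thread.
    await asyncio.to_thread(dask_done_callback)

    # mock.ANY matches the positional infrastructure arguments we ignore.
    publisher.assert_called_once_with(
        mock.ANY, mock.ANY, user_id=1, project_id="uuid-1", iteration=1
    )


asyncio.run(main())
```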
@@ -1142,7 +1168,7 @@ async def test_task_progress_triggers(
     )
     # -------------------------------------------------------------------------------
     # 1. first run will move comp_tasks to PENDING so the dask-worker can take them
-    expected_pending_tasks = await _assert_schedule_pipeline_PENDING(
+    expected_pending_tasks, _ = await _assert_publish_in_dask_backend(
         sqlalchemy_async_engine,
         published_project,
         expected_published_tasks,
@@ -1604,7 +1630,7 @@ async def test_running_pipeline_triggers_heartbeat(
     )
     # -------------------------------------------------------------------------------
     # 1. first run will move comp_tasks to PENDING so the dask-worker can take them
-    expected_pending_tasks = await _assert_schedule_pipeline_PENDING(
+    expected_pending_tasks, _ = await _assert_publish_in_dask_backend(
         sqlalchemy_async_engine,
         published_project,
         expected_published_tasks,