1111import asyncio
1212import datetime
1313from collections .abc import AsyncIterator , Awaitable , Callable
14+ from concurrent .futures import ThreadPoolExecutor
1415from copy import deepcopy
1516from dataclasses import dataclass
1617from typing import Any , cast
@@ -245,6 +246,7 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState]
245246 )
 # tasks were sent to the backend
247248 assert published_project .project .prj_owner is not None
249+ assert isinstance (mocked_dask_client .send_computation_tasks , mock .Mock )
248250 mocked_dask_client .send_computation_tasks .assert_has_calls (
249251 calls = [
250252 mock .call (
@@ -353,9 +355,9 @@ async def _assert_message_received(
353355 return parsed_messages
354356
355357
356- def _mock_send_computation_tasks (
358+ def _with_mock_send_computation_tasks (
357359 tasks : list [CompTaskAtDB ], mocked_dask_client : mock .MagicMock
358- ) -> None :
360+ ) -> mock . Mock :
359361 node_id_to_job_id_map = {task .node_id : task .job_id for task in tasks }
360362
361363 async def _send_computation_tasks (
@@ -372,6 +374,7 @@ async def _send_computation_tasks(
372374 ] # type: ignore
373375
374376 mocked_dask_client .send_computation_tasks .side_effect = _send_computation_tasks
377+ return mocked_dask_client .send_computation_tasks
375378
376379
377380async def _trigger_progress_event (
@@ -414,7 +417,7 @@ async def test_proper_pipeline_is_scheduled( # noqa: PLR0915
414417 run_metadata : RunMetadataDict ,
415418):
416419 with_disabled_auto_scheduling .assert_called_once ()
417- _mock_send_computation_tasks (published_project .tasks , mocked_dask_client )
420+ _with_mock_send_computation_tasks (published_project .tasks , mocked_dask_client )
418421
419422 #
420423 # Initiate new pipeline run
@@ -863,6 +866,198 @@ def _parser(x) -> RabbitResourceTrackingMessages:
863866 assert isinstance (messages [1 ], RabbitResourceTrackingStoppedMessage )
864867
865868
869+ @pytest .fixture
870+ async def with_started_project (
871+ with_disabled_auto_scheduling : mock .Mock ,
872+ with_disabled_scheduler_publisher : mock .Mock ,
873+ initialized_app : FastAPI ,
874+ sqlalchemy_async_engine : AsyncEngine ,
875+ publish_project : Callable [[], Awaitable [PublishedProject ]],
876+ mocked_dask_client : mock .Mock ,
877+ run_metadata : RunMetadataDict ,
878+ scheduler_api : BaseCompScheduler ,
879+ instrumentation_rabbit_client_parser : mock .AsyncMock ,
880+ resource_tracking_rabbit_client_parser : mock .AsyncMock ,
881+ ) -> RunningProject :
882+ with_disabled_auto_scheduling .assert_called_once ()
883+ published_project = await publish_project ()
884+ #
885+ # 1. Initiate new pipeline run
886+ #
887+ run_in_db , expected_published_tasks = await _assert_start_pipeline (
888+ initialized_app ,
889+ sqlalchemy_async_engine = sqlalchemy_async_engine ,
890+ published_project = published_project ,
891+ run_metadata = run_metadata ,
892+ )
893+ with_disabled_scheduler_publisher .assert_called_once ()
894+
895+ #
896+ # 2. This runs the scheduler until the project is started scheduled in the back-end
897+ #
898+ expected_pending_tasks = await _assert_schedule_pipeline_PENDING (
899+ sqlalchemy_async_engine ,
900+ published_project ,
901+ expected_published_tasks ,
902+ mocked_dask_client ,
903+ scheduler_api ,
904+ )
905+
906+ #
907+ # The dask-worker can take a job when it is PENDING, but the dask scheduler makes
908+ # no difference between PENDING and STARTED
909+ #
910+ exp_started_task = expected_pending_tasks [0 ]
911+ expected_pending_tasks .remove (exp_started_task )
912+
913+ async def _return_1st_task_running (job_ids : list [str ]) -> list [DaskClientTaskState ]:
914+ return [
915+ (
916+ DaskClientTaskState .PENDING_OR_STARTED
917+ if job_id == exp_started_task .job_id
918+ else DaskClientTaskState .PENDING
919+ )
920+ for job_id in job_ids
921+ ]
922+
923+ mocked_dask_client .get_tasks_status .side_effect = _return_1st_task_running
924+ await scheduler_api .schedule_pipeline (
925+ user_id = run_in_db .user_id ,
926+ project_id = run_in_db .project_uuid ,
927+ iteration = run_in_db .iteration ,
928+ )
929+ await assert_comp_runs (
930+ sqlalchemy_async_engine ,
931+ expected_total = 1 ,
932+ expected_state = RunningState .PENDING ,
933+ where_statement = and_ (
934+ comp_runs .c .user_id == published_project .project .prj_owner ,
935+ comp_runs .c .project_uuid == f"{ published_project .project .uuid } " ,
936+ ),
937+ )
938+ await assert_comp_tasks (
939+ sqlalchemy_async_engine ,
940+ project_uuid = published_project .project .uuid ,
941+ task_ids = [exp_started_task .node_id ]
942+ + [p .node_id for p in expected_pending_tasks ],
943+ expected_state = RunningState .PENDING ,
944+ expected_progress = None ,
945+ )
946+ await assert_comp_tasks (
947+ sqlalchemy_async_engine ,
948+ project_uuid = published_project .project .uuid ,
949+ task_ids = [p .node_id for p in expected_published_tasks ],
950+ expected_state = RunningState .PUBLISHED ,
951+ expected_progress = None , # since we bypass the API entrypoint this is correct
952+ )
953+ mocked_dask_client .send_computation_tasks .assert_not_called ()
954+ mocked_dask_client .get_tasks_status .assert_called_once_with (
955+ [p .job_id for p in (exp_started_task , * expected_pending_tasks )],
956+ )
957+ mocked_dask_client .get_tasks_status .reset_mock ()
958+ mocked_dask_client .get_task_result .assert_not_called ()
959+
960+ # -------------------------------------------------------------------------------
961+ # 4. the dask-worker starts processing a task here we simulate a progress event
962+ assert exp_started_task .job_id
963+ assert exp_started_task .project_id
964+ assert exp_started_task .node_id
965+ assert published_project .project .prj_owner
966+ await _trigger_progress_event (
967+ scheduler_api ,
968+ job_id = exp_started_task .job_id ,
969+ user_id = published_project .project .prj_owner ,
970+ project_id = exp_started_task .project_id ,
971+ node_id = exp_started_task .node_id ,
972+ )
973+
974+ await scheduler_api .schedule_pipeline (
975+ user_id = run_in_db .user_id ,
976+ project_id = run_in_db .project_uuid ,
977+ iteration = run_in_db .iteration ,
978+ )
979+ # comp_run, the comp_task switch to STARTED
980+ await assert_comp_runs (
981+ sqlalchemy_async_engine ,
982+ expected_total = 1 ,
983+ expected_state = RunningState .STARTED ,
984+ where_statement = and_ (
985+ comp_runs .c .user_id == published_project .project .prj_owner ,
986+ comp_runs .c .project_uuid == f"{ published_project .project .uuid } " ,
987+ ),
988+ )
989+ await assert_comp_tasks (
990+ sqlalchemy_async_engine ,
991+ project_uuid = published_project .project .uuid ,
992+ task_ids = [exp_started_task .node_id ],
993+ expected_state = RunningState .STARTED ,
994+ expected_progress = 0 ,
995+ )
996+ await assert_comp_tasks (
997+ sqlalchemy_async_engine ,
998+ project_uuid = published_project .project .uuid ,
999+ task_ids = [p .node_id for p in expected_pending_tasks ],
1000+ expected_state = RunningState .PENDING ,
1001+ expected_progress = None ,
1002+ )
1003+ await assert_comp_tasks (
1004+ sqlalchemy_async_engine ,
1005+ project_uuid = published_project .project .uuid ,
1006+ task_ids = [p .node_id for p in expected_published_tasks ],
1007+ expected_state = RunningState .PUBLISHED ,
1008+ expected_progress = None ,
1009+ )
1010+ mocked_dask_client .send_computation_tasks .assert_not_called ()
1011+ mocked_dask_client .get_tasks_status .assert_called_once_with (
1012+ [p .job_id for p in (exp_started_task , * expected_pending_tasks )],
1013+ )
1014+ mocked_dask_client .get_tasks_status .reset_mock ()
1015+ mocked_dask_client .get_task_result .assert_not_called ()
1016+ # check the metrics are properly published
1017+ messages = await _assert_message_received (
1018+ instrumentation_rabbit_client_parser ,
1019+ 1 ,
1020+ InstrumentationRabbitMessage .model_validate_json ,
1021+ )
1022+ assert messages [0 ].metrics == "service_started"
1023+ assert messages [0 ].service_uuid == exp_started_task .node_id
1024+
1025+ # check the RUT messages are properly published
1026+ messages = await _assert_message_received (
1027+ resource_tracking_rabbit_client_parser ,
1028+ 1 ,
1029+ RabbitResourceTrackingStartedMessage .model_validate_json ,
1030+ )
1031+ assert messages [0 ].node_id == exp_started_task .node_id
1032+
1033+ return RunningProject (
1034+ published_project .project ,
1035+ published_project .pipeline ,
1036+ published_project .tasks ,
1037+ runs = run_in_db ,
1038+ )
1039+
1040+
1041+ async def test_completed_task_triggers_new_scheduling_task (
1042+ with_disabled_scheduler_publisher : mock .Mock ,
1043+ with_started_project : RunningProject ,
1044+ initialized_app : FastAPI ,
1045+ mocked_dask_client : mock .MagicMock ,
1046+ scheduler_api : BaseCompScheduler ,
1047+ sqlalchemy_async_engine : AsyncEngine ,
1048+ mocker : MockerFixture ,
1049+ ):
1050+ """When a pipeline job completes, the Dask backend provides a callback
1051+ that runs in a separate thread. We use that callback to ask the
1052+ director-v2 computational scheduler manager to ask for a new schedule
1053+ After fiddling in distributed source code, here is a similar way to trigger that callback
1054+ """
1055+ with ThreadPoolExecutor (
1056+ max_workers = 1 , thread_name_prefix = "pytest-callback-thread"
1057+ ) as executor :
1058+ ...
1059+
1060+
8661061async def test_broken_pipeline_configuration_is_not_scheduled_and_aborted (
8671062 with_disabled_auto_scheduling : mock .Mock ,
8681063 with_disabled_scheduler_publisher : mock .Mock ,
@@ -938,15 +1133,15 @@ async def test_task_progress_triggers(
9381133 mocked_clean_task_output_and_log_files_if_invalid : mock .Mock ,
9391134 run_metadata : RunMetadataDict ,
9401135):
941- _mock_send_computation_tasks (published_project .tasks , mocked_dask_client )
1136+ _with_mock_send_computation_tasks (published_project .tasks , mocked_dask_client )
9421137 _run_in_db , expected_published_tasks = await _assert_start_pipeline (
9431138 initialized_app ,
9441139 sqlalchemy_async_engine = sqlalchemy_async_engine ,
9451140 published_project = published_project ,
9461141 run_metadata = run_metadata ,
9471142 )
9481143 # -------------------------------------------------------------------------------
949- # 1. first run will move comp_tasks to PENDING so the worker can take them
 # 1. first run will move comp_tasks to PENDING so the dask-worker can take them
9501145 expected_pending_tasks = await _assert_schedule_pipeline_PENDING (
9511146 sqlalchemy_async_engine ,
9521147 published_project ,
@@ -1400,15 +1595,15 @@ async def test_running_pipeline_triggers_heartbeat(
14001595 resource_tracking_rabbit_client_parser : mock .AsyncMock ,
14011596 run_metadata : RunMetadataDict ,
14021597):
1403- _mock_send_computation_tasks (published_project .tasks , mocked_dask_client )
1598+ _with_mock_send_computation_tasks (published_project .tasks , mocked_dask_client )
14041599 run_in_db , expected_published_tasks = await _assert_start_pipeline (
14051600 initialized_app ,
14061601 sqlalchemy_async_engine = sqlalchemy_async_engine ,
14071602 published_project = published_project ,
14081603 run_metadata = run_metadata ,
14091604 )
14101605 # -------------------------------------------------------------------------------
1411- # 1. first run will move comp_tasks to PENDING so the worker can take them
 # 1. first run will move comp_tasks to PENDING so the dask-worker can take them
14121607 expected_pending_tasks = await _assert_schedule_pipeline_PENDING (
14131608 sqlalchemy_async_engine ,
14141609 published_project ,
0 commit comments