Skip to content

Commit 66458cf

Browse files
committed
✨ add with_started_project fixture and stub test for rescheduling on task completion
1 parent 9521391 commit 66458cf

File tree

1 file changed

+202
-7
lines changed

1 file changed

+202
-7
lines changed

services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py

Lines changed: 202 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import asyncio
1212
import datetime
1313
from collections.abc import AsyncIterator, Awaitable, Callable
14+
from concurrent.futures import ThreadPoolExecutor
1415
from copy import deepcopy
1516
from dataclasses import dataclass
1617
from typing import Any, cast
@@ -245,6 +246,7 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState]
245246
)
246247
# tasks were sent to the backend
247248
assert published_project.project.prj_owner is not None
249+
assert isinstance(mocked_dask_client.send_computation_tasks, mock.Mock)
248250
mocked_dask_client.send_computation_tasks.assert_has_calls(
249251
calls=[
250252
mock.call(
@@ -353,9 +355,9 @@ async def _assert_message_received(
353355
return parsed_messages
354356

355357

356-
def _mock_send_computation_tasks(
358+
def _with_mock_send_computation_tasks(
357359
tasks: list[CompTaskAtDB], mocked_dask_client: mock.MagicMock
358-
) -> None:
360+
) -> mock.Mock:
359361
node_id_to_job_id_map = {task.node_id: task.job_id for task in tasks}
360362

361363
async def _send_computation_tasks(
@@ -372,6 +374,7 @@ async def _send_computation_tasks(
372374
] # type: ignore
373375

374376
mocked_dask_client.send_computation_tasks.side_effect = _send_computation_tasks
377+
return mocked_dask_client.send_computation_tasks
375378

376379

377380
async def _trigger_progress_event(
@@ -414,7 +417,7 @@ async def test_proper_pipeline_is_scheduled( # noqa: PLR0915
414417
run_metadata: RunMetadataDict,
415418
):
416419
with_disabled_auto_scheduling.assert_called_once()
417-
_mock_send_computation_tasks(published_project.tasks, mocked_dask_client)
420+
_with_mock_send_computation_tasks(published_project.tasks, mocked_dask_client)
418421

419422
#
420423
# Initiate new pipeline run
@@ -863,6 +866,198 @@ def _parser(x) -> RabbitResourceTrackingMessages:
863866
assert isinstance(messages[1], RabbitResourceTrackingStoppedMessage)
864867

865868

869+
@pytest.fixture
870+
async def with_started_project(
871+
with_disabled_auto_scheduling: mock.Mock,
872+
with_disabled_scheduler_publisher: mock.Mock,
873+
initialized_app: FastAPI,
874+
sqlalchemy_async_engine: AsyncEngine,
875+
publish_project: Callable[[], Awaitable[PublishedProject]],
876+
mocked_dask_client: mock.Mock,
877+
run_metadata: RunMetadataDict,
878+
scheduler_api: BaseCompScheduler,
879+
instrumentation_rabbit_client_parser: mock.AsyncMock,
880+
resource_tracking_rabbit_client_parser: mock.AsyncMock,
881+
) -> RunningProject:
882+
with_disabled_auto_scheduling.assert_called_once()
883+
published_project = await publish_project()
884+
#
885+
# 1. Initiate new pipeline run
886+
#
887+
run_in_db, expected_published_tasks = await _assert_start_pipeline(
888+
initialized_app,
889+
sqlalchemy_async_engine=sqlalchemy_async_engine,
890+
published_project=published_project,
891+
run_metadata=run_metadata,
892+
)
893+
with_disabled_scheduler_publisher.assert_called_once()
894+
895+
#
896+
# 2. This runs the scheduler until the project is started scheduled in the back-end
897+
#
898+
expected_pending_tasks = await _assert_schedule_pipeline_PENDING(
899+
sqlalchemy_async_engine,
900+
published_project,
901+
expected_published_tasks,
902+
mocked_dask_client,
903+
scheduler_api,
904+
)
905+
906+
#
907+
# The dask-worker can take a job when it is PENDING, but the dask scheduler makes
908+
# no difference between PENDING and STARTED
909+
#
910+
exp_started_task = expected_pending_tasks[0]
911+
expected_pending_tasks.remove(exp_started_task)
912+
913+
async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]:
914+
return [
915+
(
916+
DaskClientTaskState.PENDING_OR_STARTED
917+
if job_id == exp_started_task.job_id
918+
else DaskClientTaskState.PENDING
919+
)
920+
for job_id in job_ids
921+
]
922+
923+
mocked_dask_client.get_tasks_status.side_effect = _return_1st_task_running
924+
await scheduler_api.schedule_pipeline(
925+
user_id=run_in_db.user_id,
926+
project_id=run_in_db.project_uuid,
927+
iteration=run_in_db.iteration,
928+
)
929+
await assert_comp_runs(
930+
sqlalchemy_async_engine,
931+
expected_total=1,
932+
expected_state=RunningState.PENDING,
933+
where_statement=and_(
934+
comp_runs.c.user_id == published_project.project.prj_owner,
935+
comp_runs.c.project_uuid == f"{published_project.project.uuid}",
936+
),
937+
)
938+
await assert_comp_tasks(
939+
sqlalchemy_async_engine,
940+
project_uuid=published_project.project.uuid,
941+
task_ids=[exp_started_task.node_id]
942+
+ [p.node_id for p in expected_pending_tasks],
943+
expected_state=RunningState.PENDING,
944+
expected_progress=None,
945+
)
946+
await assert_comp_tasks(
947+
sqlalchemy_async_engine,
948+
project_uuid=published_project.project.uuid,
949+
task_ids=[p.node_id for p in expected_published_tasks],
950+
expected_state=RunningState.PUBLISHED,
951+
expected_progress=None, # since we bypass the API entrypoint this is correct
952+
)
953+
mocked_dask_client.send_computation_tasks.assert_not_called()
954+
mocked_dask_client.get_tasks_status.assert_called_once_with(
955+
[p.job_id for p in (exp_started_task, *expected_pending_tasks)],
956+
)
957+
mocked_dask_client.get_tasks_status.reset_mock()
958+
mocked_dask_client.get_task_result.assert_not_called()
959+
960+
# -------------------------------------------------------------------------------
961+
# 4. the dask-worker starts processing a task here we simulate a progress event
962+
assert exp_started_task.job_id
963+
assert exp_started_task.project_id
964+
assert exp_started_task.node_id
965+
assert published_project.project.prj_owner
966+
await _trigger_progress_event(
967+
scheduler_api,
968+
job_id=exp_started_task.job_id,
969+
user_id=published_project.project.prj_owner,
970+
project_id=exp_started_task.project_id,
971+
node_id=exp_started_task.node_id,
972+
)
973+
974+
await scheduler_api.schedule_pipeline(
975+
user_id=run_in_db.user_id,
976+
project_id=run_in_db.project_uuid,
977+
iteration=run_in_db.iteration,
978+
)
979+
# comp_run, the comp_task switch to STARTED
980+
await assert_comp_runs(
981+
sqlalchemy_async_engine,
982+
expected_total=1,
983+
expected_state=RunningState.STARTED,
984+
where_statement=and_(
985+
comp_runs.c.user_id == published_project.project.prj_owner,
986+
comp_runs.c.project_uuid == f"{published_project.project.uuid}",
987+
),
988+
)
989+
await assert_comp_tasks(
990+
sqlalchemy_async_engine,
991+
project_uuid=published_project.project.uuid,
992+
task_ids=[exp_started_task.node_id],
993+
expected_state=RunningState.STARTED,
994+
expected_progress=0,
995+
)
996+
await assert_comp_tasks(
997+
sqlalchemy_async_engine,
998+
project_uuid=published_project.project.uuid,
999+
task_ids=[p.node_id for p in expected_pending_tasks],
1000+
expected_state=RunningState.PENDING,
1001+
expected_progress=None,
1002+
)
1003+
await assert_comp_tasks(
1004+
sqlalchemy_async_engine,
1005+
project_uuid=published_project.project.uuid,
1006+
task_ids=[p.node_id for p in expected_published_tasks],
1007+
expected_state=RunningState.PUBLISHED,
1008+
expected_progress=None,
1009+
)
1010+
mocked_dask_client.send_computation_tasks.assert_not_called()
1011+
mocked_dask_client.get_tasks_status.assert_called_once_with(
1012+
[p.job_id for p in (exp_started_task, *expected_pending_tasks)],
1013+
)
1014+
mocked_dask_client.get_tasks_status.reset_mock()
1015+
mocked_dask_client.get_task_result.assert_not_called()
1016+
# check the metrics are properly published
1017+
messages = await _assert_message_received(
1018+
instrumentation_rabbit_client_parser,
1019+
1,
1020+
InstrumentationRabbitMessage.model_validate_json,
1021+
)
1022+
assert messages[0].metrics == "service_started"
1023+
assert messages[0].service_uuid == exp_started_task.node_id
1024+
1025+
# check the RUT messages are properly published
1026+
messages = await _assert_message_received(
1027+
resource_tracking_rabbit_client_parser,
1028+
1,
1029+
RabbitResourceTrackingStartedMessage.model_validate_json,
1030+
)
1031+
assert messages[0].node_id == exp_started_task.node_id
1032+
1033+
return RunningProject(
1034+
published_project.project,
1035+
published_project.pipeline,
1036+
published_project.tasks,
1037+
runs=run_in_db,
1038+
)
1039+
1040+
1041+
async def test_completed_task_triggers_new_scheduling_task(
1042+
with_disabled_scheduler_publisher: mock.Mock,
1043+
with_started_project: RunningProject,
1044+
initialized_app: FastAPI,
1045+
mocked_dask_client: mock.MagicMock,
1046+
scheduler_api: BaseCompScheduler,
1047+
sqlalchemy_async_engine: AsyncEngine,
1048+
mocker: MockerFixture,
1049+
):
1050+
"""When a pipeline job completes, the Dask backend provides a callback
1051+
that runs in a separate thread. We use that callback to ask the
1052+
director-v2 computational scheduler manager to ask for a new schedule
1053+
After fiddling in distributed source code, here is a similar way to trigger that callback
1054+
"""
1055+
with ThreadPoolExecutor(
1056+
max_workers=1, thread_name_prefix="pytest-callback-thread"
1057+
) as executor:
1058+
...
1059+
1060+
8661061
async def test_broken_pipeline_configuration_is_not_scheduled_and_aborted(
8671062
with_disabled_auto_scheduling: mock.Mock,
8681063
with_disabled_scheduler_publisher: mock.Mock,
@@ -938,15 +1133,15 @@ async def test_task_progress_triggers(
9381133
mocked_clean_task_output_and_log_files_if_invalid: mock.Mock,
9391134
run_metadata: RunMetadataDict,
9401135
):
941-
_mock_send_computation_tasks(published_project.tasks, mocked_dask_client)
1136+
_with_mock_send_computation_tasks(published_project.tasks, mocked_dask_client)
9421137
_run_in_db, expected_published_tasks = await _assert_start_pipeline(
9431138
initialized_app,
9441139
sqlalchemy_async_engine=sqlalchemy_async_engine,
9451140
published_project=published_project,
9461141
run_metadata=run_metadata,
9471142
)
9481143
# -------------------------------------------------------------------------------
949-
# 1. first run will move comp_tasks to PENDING so the worker can take them
1144+
# 1. first run will move comp_tasks to PENDING so the dask-worker can take them
9501145
expected_pending_tasks = await _assert_schedule_pipeline_PENDING(
9511146
sqlalchemy_async_engine,
9521147
published_project,
@@ -1400,15 +1595,15 @@ async def test_running_pipeline_triggers_heartbeat(
14001595
resource_tracking_rabbit_client_parser: mock.AsyncMock,
14011596
run_metadata: RunMetadataDict,
14021597
):
1403-
_mock_send_computation_tasks(published_project.tasks, mocked_dask_client)
1598+
_with_mock_send_computation_tasks(published_project.tasks, mocked_dask_client)
14041599
run_in_db, expected_published_tasks = await _assert_start_pipeline(
14051600
initialized_app,
14061601
sqlalchemy_async_engine=sqlalchemy_async_engine,
14071602
published_project=published_project,
14081603
run_metadata=run_metadata,
14091604
)
14101605
# -------------------------------------------------------------------------------
1411-
# 1. first run will move comp_tasks to PENDING so the worker can take them
1606+
# 1. first run will move comp_tasks to PENDING so the dask-worker can take them
14121607
expected_pending_tasks = await _assert_schedule_pipeline_PENDING(
14131608
sqlalchemy_async_engine,
14141609
published_project,

0 commit comments

Comments
 (0)