Skip to content

Commit 10069ac

Browse files
committed
test the callback mechanism
1 parent 39ee59f commit 10069ac

File tree

1 file changed

+57
-31
lines changed

1 file changed

+57
-31
lines changed

services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import asyncio
1212
import datetime
1313
from collections.abc import AsyncIterator, Awaitable, Callable
14-
from concurrent.futures import ThreadPoolExecutor
1514
from copy import deepcopy
1615
from dataclasses import dataclass
1716
from typing import Any, cast
@@ -195,13 +194,13 @@ async def _assert_start_pipeline(
195194
return runs[0], exp_published_tasks
196195

197196

198-
async def _assert_schedule_pipeline_PENDING( # noqa: N802
197+
async def _assert_publish_in_dask_backend(
199198
sqlalchemy_async_engine: AsyncEngine,
200199
published_project: PublishedProject,
201200
published_tasks: list[CompTaskAtDB],
202201
mocked_dask_client: mock.MagicMock,
203202
scheduler: BaseCompScheduler,
204-
) -> list[CompTaskAtDB]:
203+
) -> tuple[list[CompTaskAtDB], dict[NodeID, Callable[[], None]]]:
205204
expected_pending_tasks = [
206205
published_tasks[1],
207206
published_tasks[3],
@@ -247,6 +246,8 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState]
247246
# tasks were sent to the backend
248247
assert published_project.project.prj_owner is not None
249248
assert isinstance(mocked_dask_client.send_computation_tasks, mock.Mock)
249+
assert isinstance(mocked_dask_client.get_tasks_status, mock.Mock)
250+
assert isinstance(mocked_dask_client.get_task_result, mock.Mock)
250251
mocked_dask_client.send_computation_tasks.assert_has_calls(
251252
calls=[
252253
mock.call(
@@ -262,6 +263,12 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState]
262263
],
263264
any_order=True,
264265
)
266+
task_to_callback_mapping = {
267+
task.node_id: mocked_dask_client.send_computation_tasks.call_args_list[
268+
i
269+
].kwargs["callback"]
270+
for i, task in enumerate(expected_pending_tasks)
271+
}
265272
mocked_dask_client.send_computation_tasks.reset_mock()
266273
mocked_dask_client.get_tasks_status.assert_not_called()
267274
mocked_dask_client.get_task_result.assert_not_called()
@@ -298,7 +305,7 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState]
298305
)
299306
mocked_dask_client.get_tasks_status.reset_mock()
300307
mocked_dask_client.get_task_result.assert_not_called()
301-
return expected_pending_tasks
308+
return expected_pending_tasks, task_to_callback_mapping
302309

303310

304311
@pytest.fixture
@@ -432,7 +439,7 @@ async def test_proper_pipeline_is_scheduled( # noqa: PLR0915
432439

433440
# -------------------------------------------------------------------------------
434441
# 1. first run will move comp_tasks to PENDING so the dask-worker can take them
435-
expected_pending_tasks = await _assert_schedule_pipeline_PENDING(
442+
expected_pending_tasks, _ = await _assert_publish_in_dask_backend(
436443
sqlalchemy_async_engine,
437444
published_project,
438445
expected_published_tasks,
@@ -895,7 +902,10 @@ async def with_started_project(
895902
#
896903
# 2. This runs the scheduler until the project is scheduled and started in the back-end
897904
#
898-
expected_pending_tasks = await _assert_schedule_pipeline_PENDING(
905+
(
906+
expected_pending_tasks,
907+
task_to_callback_mapping,
908+
) = await _assert_publish_in_dask_backend(
899909
sqlalchemy_async_engine,
900910
published_project,
901911
expected_published_tasks,
@@ -920,6 +930,9 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta
920930
for job_id in job_ids
921931
]
922932

933+
assert isinstance(mocked_dask_client.get_tasks_status, mock.Mock)
934+
assert isinstance(mocked_dask_client.send_computation_tasks, mock.Mock)
935+
assert isinstance(mocked_dask_client.get_task_result, mock.Mock)
923936
mocked_dask_client.get_tasks_status.side_effect = _return_1st_task_running
924937
await scheduler_api.schedule_pipeline(
925938
user_id=run_in_db.user_id,
@@ -977,30 +990,32 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta
977990
iteration=run_in_db.iteration,
978991
)
979992
# comp_run, the comp_task switch to STARTED
980-
await assert_comp_runs(
981-
sqlalchemy_async_engine,
982-
expected_total=1,
983-
expected_state=RunningState.STARTED,
984-
where_statement=and_(
985-
comp_runs.c.user_id == published_project.project.prj_owner,
986-
comp_runs.c.project_uuid == f"{published_project.project.uuid}",
987-
),
988-
)
989-
await assert_comp_tasks(
993+
run_in_db = (
994+
await assert_comp_runs(
995+
sqlalchemy_async_engine,
996+
expected_total=1,
997+
expected_state=RunningState.STARTED,
998+
where_statement=and_(
999+
comp_runs.c.user_id == published_project.project.prj_owner,
1000+
comp_runs.c.project_uuid == f"{published_project.project.uuid}",
1001+
),
1002+
)
1003+
)[0]
1004+
tasks_in_db = await assert_comp_tasks(
9901005
sqlalchemy_async_engine,
9911006
project_uuid=published_project.project.uuid,
9921007
task_ids=[exp_started_task.node_id],
9931008
expected_state=RunningState.STARTED,
9941009
expected_progress=0,
9951010
)
996-
await assert_comp_tasks(
1011+
tasks_in_db += await assert_comp_tasks(
9971012
sqlalchemy_async_engine,
9981013
project_uuid=published_project.project.uuid,
9991014
task_ids=[p.node_id for p in expected_pending_tasks],
10001015
expected_state=RunningState.PENDING,
10011016
expected_progress=None,
10021017
)
1003-
await assert_comp_tasks(
1018+
tasks_in_db += await assert_comp_tasks(
10041019
sqlalchemy_async_engine,
10051020
project_uuid=published_project.project.uuid,
10061021
task_ids=[p.node_id for p in expected_published_tasks],
@@ -1033,29 +1048,40 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta
10331048
return RunningProject(
10341049
published_project.project,
10351050
published_project.pipeline,
1036-
published_project.tasks,
1051+
tasks_in_db,
10371052
runs=run_in_db,
1053+
task_to_callback_mapping=task_to_callback_mapping,
1054+
)
1055+
1056+
1057+
@pytest.fixture
1058+
def mocked_worker_publisher(mocker: MockerFixture) -> mock.Mock:
1059+
return mocker.patch(
1060+
"simcore_service_director_v2.modules.comp_scheduler._scheduler_base.request_pipeline_scheduling",
1061+
autospec=True,
10381062
)
10391063

10401064

10411065
async def test_completed_task_triggers_new_scheduling_task(
1042-
with_disabled_scheduler_publisher: mock.Mock,
1066+
mocked_worker_publisher: mock.Mock,
10431067
with_started_project: RunningProject,
1044-
initialized_app: FastAPI,
1045-
mocked_dask_client: mock.MagicMock,
1046-
scheduler_api: BaseCompScheduler,
1047-
sqlalchemy_async_engine: AsyncEngine,
1048-
mocker: MockerFixture,
10491068
):
10501069
"""When a pipeline job completes, the Dask backend provides a callback
10511070
that runs in a separate thread. We use that callback to ask the
10521071
director-v2 computational scheduler manager for a new schedule.
10531072
After fiddling in the distributed source code, here is a similar way to trigger that callback.
10541073
"""
1055-
with ThreadPoolExecutor(
1056-
max_workers=1, thread_name_prefix="pytest-callback-thread"
1057-
) as executor:
1058-
...
1074+
completed_node_id = with_started_project.tasks[0].node_id
1075+
callback = with_started_project.task_to_callback_mapping[completed_node_id]
1076+
await asyncio.to_thread(callback)
1077+
1078+
mocked_worker_publisher.assert_called_once_with(
1079+
mock.ANY,
1080+
mock.ANY,
1081+
user_id=with_started_project.runs.user_id,
1082+
project_id=with_started_project.runs.project_uuid,
1083+
iteration=with_started_project.runs.iteration,
1084+
)
10591085

10601086

10611087
async def test_broken_pipeline_configuration_is_not_scheduled_and_aborted(
@@ -1142,7 +1168,7 @@ async def test_task_progress_triggers(
11421168
)
11431169
# -------------------------------------------------------------------------------
11441170
# 1. first run will move comp_tasks to PENDING so the dask-worker can take them
1145-
expected_pending_tasks = await _assert_schedule_pipeline_PENDING(
1171+
expected_pending_tasks, _ = await _assert_publish_in_dask_backend(
11461172
sqlalchemy_async_engine,
11471173
published_project,
11481174
expected_published_tasks,
@@ -1604,7 +1630,7 @@ async def test_running_pipeline_triggers_heartbeat(
16041630
)
16051631
# -------------------------------------------------------------------------------
16061632
# 1. first run will move comp_tasks to PENDING so the dask-worker can take them
1607-
expected_pending_tasks = await _assert_schedule_pipeline_PENDING(
1633+
expected_pending_tasks, _ = await _assert_publish_in_dask_backend(
16081634
sqlalchemy_async_engine,
16091635
published_project,
16101636
expected_published_tasks,

0 commit comments

Comments
 (0)