Skip to content

Commit 3764bf7

Browse files
committed
fixed tests
1 parent ae1d434 commit 3764bf7

File tree

2 files changed

+57
-44
lines changed

2 files changed

+57
-44
lines changed

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
from models_library.users import UserID
2424
from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE
2525
from servicelib.logging_utils import log_catch
26+
from simcore_service_director_v2.modules.comp_scheduler._utils import (
27+
WAITING_FOR_START_STATES,
28+
)
2629

2730
from ...core.errors import (
2831
ComputationalBackendNotConnectedError,
@@ -381,7 +384,7 @@ async def _task_progress_change_handler(
381384
node_id = task_progress_event.task_owner.node_id
382385
comp_tasks_repo = CompTasksRepository(self.db_engine)
383386
task = await comp_tasks_repo.get_task(project_id, node_id)
384-
if task.progress is None:
387+
if task.state in WAITING_FOR_START_STATES:
385388
task.state = RunningState.STARTED
386389
task.progress = task_progress_event.progress
387390
run = await CompRunsRepository(self.db_engine).get(user_id, project_id)

services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py

Lines changed: 53 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
from simcore_service_director_v2.models.comp_pipelines import CompPipelineAtDB
6161
from simcore_service_director_v2.models.comp_runs import CompRunsAtDB, RunMetadataDict
6262
from simcore_service_director_v2.models.comp_tasks import CompTaskAtDB, Image
63-
from simcore_service_director_v2.models.dask_subsystem import DaskClientTaskState
6463
from simcore_service_director_v2.modules.comp_scheduler._manager import (
6564
run_new_pipeline,
6665
stop_pipeline,
@@ -206,8 +205,8 @@ async def _assert_publish_in_dask_backend(
206205
for p in expected_pending_tasks:
207206
published_tasks.remove(p)
208207

209-
async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState]:
210-
return [DaskClientTaskState.PENDING for job_id in job_ids]
208+
async def _return_tasks_pending(job_ids: list[str]) -> list[RunningState]:
209+
return [RunningState.PENDING for job_id in job_ids]
211210

212211
mocked_dask_client.get_tasks_status.side_effect = _return_tasks_pending
213212
assert published_project.project.prj_owner
@@ -445,17 +444,16 @@ async def test_proper_pipeline_is_scheduled( # noqa: PLR0915
445444
)
446445

447446
# -------------------------------------------------------------------------------
448-
# 2.1. the dask-worker might be taking the task, until we get a progress we do not know
449-
# whether it effectively started or it is still queued in the worker process
447+
# 2.1. the dask-worker takes the task
450448
exp_started_task = expected_pending_tasks[0]
451449
expected_pending_tasks.remove(exp_started_task)
452450

453-
async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]:
451+
async def _return_1st_task_running(job_ids: list[str]) -> list[RunningState]:
454452
return [
455453
(
456-
DaskClientTaskState.PENDING_OR_STARTED
454+
RunningState.STARTED
457455
if job_id == exp_started_task.job_id
458-
else DaskClientTaskState.PENDING
456+
else RunningState.PENDING
459457
)
460458
for job_id in job_ids
461459
]
@@ -469,7 +467,7 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta
469467
await assert_comp_runs(
470468
sqlalchemy_async_engine,
471469
expected_total=1,
472-
expected_state=RunningState.PENDING,
470+
expected_state=RunningState.STARTED,
473471
where_statement=and_(
474472
comp_runs.c.user_id == published_project.project.prj_owner,
475473
comp_runs.c.project_uuid == f"{published_project.project.uuid}",
@@ -478,8 +476,14 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta
478476
await assert_comp_tasks(
479477
sqlalchemy_async_engine,
480478
project_uuid=published_project.project.uuid,
481-
task_ids=[exp_started_task.node_id]
482-
+ [p.node_id for p in expected_pending_tasks],
479+
task_ids=[exp_started_task.node_id],
480+
expected_state=RunningState.STARTED,
481+
expected_progress=None,
482+
)
483+
await assert_comp_tasks(
484+
sqlalchemy_async_engine,
485+
project_uuid=published_project.project.uuid,
486+
task_ids=[p.node_id for p in expected_pending_tasks],
483487
expected_state=RunningState.PENDING,
484488
expected_progress=None,
485489
)
@@ -572,12 +576,12 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta
572576

573577
# -------------------------------------------------------------------------------
574578
# 4. the dask-worker completed the task successfully
575-
async def _return_1st_task_success(job_ids: list[str]) -> list[DaskClientTaskState]:
579+
async def _return_1st_task_success(job_ids: list[str]) -> list[RunningState]:
576580
return [
577581
(
578-
DaskClientTaskState.SUCCESS
582+
RunningState.SUCCESS
579583
if job_id == exp_started_task.job_id
580-
else DaskClientTaskState.PENDING
584+
else RunningState.PENDING
581585
)
582586
for job_id in job_ids
583587
]
@@ -679,12 +683,12 @@ async def _return_random_task_result(job_id) -> TaskOutputData:
679683
# 6. the dask-worker starts processing a task
680684
exp_started_task = next_pending_task
681685

682-
async def _return_2nd_task_running(job_ids: list[str]) -> list[DaskClientTaskState]:
686+
async def _return_2nd_task_running(job_ids: list[str]) -> list[RunningState]:
683687
return [
684688
(
685-
DaskClientTaskState.PENDING_OR_STARTED
689+
RunningState.STARTED
686690
if job_id == exp_started_task.job_id
687-
else DaskClientTaskState.PENDING
691+
else RunningState.PENDING
688692
)
689693
for job_id in job_ids
690694
]
@@ -743,12 +747,12 @@ async def _return_2nd_task_running(job_ids: list[str]) -> list[DaskClientTaskSta
743747

744748
# -------------------------------------------------------------------------------
745749
# 7. the task fails
746-
async def _return_2nd_task_failed(job_ids: list[str]) -> list[DaskClientTaskState]:
750+
async def _return_2nd_task_failed(job_ids: list[str]) -> list[RunningState]:
747751
return [
748752
(
749-
DaskClientTaskState.ERRED
753+
RunningState.FAILED
750754
if job_id == exp_started_task.job_id
751-
else DaskClientTaskState.PENDING
755+
else RunningState.PENDING
752756
)
753757
for job_id in job_ids
754758
]
@@ -805,12 +809,12 @@ async def _return_2nd_task_failed(job_ids: list[str]) -> list[DaskClientTaskStat
805809
# 8. the last task shall succeed
806810
exp_started_task = expected_pending_tasks[0]
807811

808-
async def _return_3rd_task_success(job_ids: list[str]) -> list[DaskClientTaskState]:
812+
async def _return_3rd_task_success(job_ids: list[str]) -> list[RunningState]:
809813
return [
810814
(
811-
DaskClientTaskState.SUCCESS
815+
RunningState.SUCCESS
812816
if job_id == exp_started_task.job_id
813-
else DaskClientTaskState.PENDING
817+
else RunningState.PENDING
814818
)
815819
for job_id in job_ids
816820
]
@@ -917,12 +921,12 @@ async def with_started_project(
917921
exp_started_task = expected_pending_tasks[0]
918922
expected_pending_tasks.remove(exp_started_task)
919923

920-
async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]:
924+
async def _return_1st_task_running(job_ids: list[str]) -> list[RunningState]:
921925
return [
922926
(
923-
DaskClientTaskState.PENDING_OR_STARTED
927+
RunningState.STARTED
924928
if job_id == exp_started_task.job_id
925-
else DaskClientTaskState.PENDING
929+
else RunningState.PENDING
926930
)
927931
for job_id in job_ids
928932
]
@@ -939,7 +943,7 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta
939943
await assert_comp_runs(
940944
sqlalchemy_async_engine,
941945
expected_total=1,
942-
expected_state=RunningState.PENDING,
946+
expected_state=RunningState.STARTED,
943947
where_statement=and_(
944948
comp_runs.c.user_id == published_project.project.prj_owner,
945949
comp_runs.c.project_uuid == f"{published_project.project.uuid}",
@@ -948,8 +952,14 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta
948952
await assert_comp_tasks(
949953
sqlalchemy_async_engine,
950954
project_uuid=published_project.project.uuid,
951-
task_ids=[exp_started_task.node_id]
952-
+ [p.node_id for p in expected_pending_tasks],
955+
task_ids=[exp_started_task.node_id],
956+
expected_state=RunningState.STARTED,
957+
expected_progress=None,
958+
)
959+
await assert_comp_tasks(
960+
sqlalchemy_async_engine,
961+
project_uuid=published_project.project.uuid,
962+
task_ids=[p.node_id for p in expected_pending_tasks],
953963
expected_state=RunningState.PENDING,
954964
expected_progress=None,
955965
)
@@ -1308,7 +1318,7 @@ async def test_handling_of_disconnected_scheduler_dask(
13081318

13091319
@dataclass(frozen=True, kw_only=True)
13101320
class RebootState:
1311-
dask_task_status: DaskClientTaskState
1321+
dask_task_status: RunningState
13121322
task_result: Exception | TaskOutputData
13131323
expected_task_state_group1: RunningState
13141324
expected_task_progress_group1: float
@@ -1322,7 +1332,7 @@ class RebootState:
13221332
[
13231333
pytest.param(
13241334
RebootState(
1325-
dask_task_status=DaskClientTaskState.LOST,
1335+
dask_task_status=RunningState.UNKNOWN,
13261336
task_result=ComputationalBackendTaskNotFoundError(job_id="fake_job_id"),
13271337
expected_task_state_group1=RunningState.FAILED,
13281338
expected_task_progress_group1=1,
@@ -1334,7 +1344,7 @@ class RebootState:
13341344
),
13351345
pytest.param(
13361346
RebootState(
1337-
dask_task_status=DaskClientTaskState.ABORTED,
1347+
dask_task_status=RunningState.ABORTED,
13381348
task_result=TaskCancelledError(job_id="fake_job_id"),
13391349
expected_task_state_group1=RunningState.ABORTED,
13401350
expected_task_progress_group1=1,
@@ -1346,7 +1356,7 @@ class RebootState:
13461356
),
13471357
pytest.param(
13481358
RebootState(
1349-
dask_task_status=DaskClientTaskState.ERRED,
1359+
dask_task_status=RunningState.FAILED,
13501360
task_result=ValueError("some error during the call"),
13511361
expected_task_state_group1=RunningState.FAILED,
13521362
expected_task_progress_group1=1,
@@ -1358,7 +1368,7 @@ class RebootState:
13581368
),
13591369
pytest.param(
13601370
RebootState(
1361-
dask_task_status=DaskClientTaskState.PENDING_OR_STARTED,
1371+
dask_task_status=RunningState.STARTED,
13621372
task_result=ComputationalBackendTaskResultsNotReadyError(
13631373
job_id="fake_job_id"
13641374
),
@@ -1372,7 +1382,7 @@ class RebootState:
13721382
),
13731383
pytest.param(
13741384
RebootState(
1375-
dask_task_status=DaskClientTaskState.SUCCESS,
1385+
dask_task_status=RunningState.SUCCESS,
13761386
task_result=TaskOutputData.model_validate({"whatever_output": 123}),
13771387
expected_task_state_group1=RunningState.SUCCESS,
13781388
expected_task_progress_group1=1,
@@ -1399,7 +1409,7 @@ async def test_handling_scheduled_tasks_after_director_reboots(
13991409
shall continue scheduling correctly. Even though the task might have continued to run
14001410
in the dask-scheduler."""
14011411

1402-
async def mocked_get_tasks_status(job_ids: list[str]) -> list[DaskClientTaskState]:
1412+
async def mocked_get_tasks_status(job_ids: list[str]) -> list[RunningState]:
14031413
return [reboot_state.dask_task_status for j in job_ids]
14041414

14051415
mocked_dask_client.get_tasks_status.side_effect = mocked_get_tasks_status
@@ -1514,8 +1524,8 @@ async def test_handling_cancellation_of_jobs_after_reboot(
15141524
)
15151525

15161526
# the backend shall report the tasks as running
1517-
async def mocked_get_tasks_status(job_ids: list[str]) -> list[DaskClientTaskState]:
1518-
return [DaskClientTaskState.PENDING_OR_STARTED for j in job_ids]
1527+
async def mocked_get_tasks_status(job_ids: list[str]) -> list[RunningState]:
1528+
return [RunningState.STARTED for j in job_ids]
15191529

15201530
mocked_dask_client.get_tasks_status.side_effect = mocked_get_tasks_status
15211531
# Running the scheduler, should actually cancel the run now
@@ -1559,8 +1569,8 @@ async def mocked_get_tasks_status(job_ids: list[str]) -> list[DaskClientTaskStat
15591569
# the backend shall now report the tasks as aborted
15601570
async def mocked_get_tasks_status_aborted(
15611571
job_ids: list[str],
1562-
) -> list[DaskClientTaskState]:
1563-
return [DaskClientTaskState.ABORTED for j in job_ids]
1572+
) -> list[RunningState]:
1573+
return [RunningState.ABORTED for j in job_ids]
15641574

15651575
mocked_dask_client.get_tasks_status.side_effect = mocked_get_tasks_status_aborted
15661576

@@ -1641,12 +1651,12 @@ async def test_running_pipeline_triggers_heartbeat(
16411651
exp_started_task = expected_pending_tasks[0]
16421652
expected_pending_tasks.remove(exp_started_task)
16431653

1644-
async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]:
1654+
async def _return_1st_task_running(job_ids: list[str]) -> list[RunningState]:
16451655
return [
16461656
(
1647-
DaskClientTaskState.PENDING_OR_STARTED
1657+
RunningState.STARTED
16481658
if job_id == exp_started_task.job_id
1649-
else DaskClientTaskState.PENDING
1659+
else RunningState.PENDING
16501660
)
16511661
for job_id in job_ids
16521662
]

0 commit comments

Comments (0)