
Commit fb3debd

Merge branch 'master' into enh/user-notifications
2 parents a7a7170 + cba3237 commit fb3debd

18 files changed: +231 -137 lines

packages/service-library/src/servicelib/background_task.py

Lines changed: 24 additions & 1 deletion

@@ -25,15 +25,36 @@ class PeriodicTaskCancellationError(PydanticErrorMixin, Exception):
     msg_template: str = "Could not cancel task '{task_name}'"


+class SleepUsingAsyncioEvent:
+    """Sleep strategy that waits on an event to be set."""
+
+    def __init__(self, event: "asyncio.Event") -> None:
+        self.event = event
+
+    async def __call__(self, timeout: float | None) -> None:
+        with contextlib.suppress(TimeoutError):
+            await asyncio.wait_for(self.event.wait(), timeout=timeout)
+            self.event.clear()
+
+
 async def _periodic_scheduled_task(
     task: Callable[..., Awaitable[None]],
     *,
     interval: datetime.timedelta,
     task_name: str,
+    early_wake_up_event: asyncio.Event | None,
     **task_kwargs,
 ) -> None:
     # NOTE: This retries forever unless cancelled
-    async for attempt in AsyncRetrying(wait=wait_fixed(interval.total_seconds())):
+    nap = (
+        asyncio.sleep
+        if early_wake_up_event is None
+        else SleepUsingAsyncioEvent(early_wake_up_event)
+    )
+    async for attempt in AsyncRetrying(
+        sleep=nap,
+        wait=wait_fixed(interval.total_seconds()),
+    ):
         with attempt:
             with log_context(
                 _logger,

@@ -51,6 +72,7 @@ def start_periodic_task(
     interval: datetime.timedelta,
     task_name: str,
     wait_before_running: datetime.timedelta = datetime.timedelta(0),
+    early_wake_up_event: asyncio.Event | None = None,
     **kwargs,
 ) -> asyncio.Task:
     with log_context(

@@ -64,6 +86,7 @@ def start_periodic_task(
                task,
                interval=interval,
                task_name=task_name,
+               early_wake_up_event=early_wake_up_event,
                **kwargs,
            ),
            name=task_name,
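
The net effect of this change: passing an asyncio.Event as early_wake_up_event lets a caller interrupt the fixed-interval sleep and force the next run immediately, while the event is cleared again after each wake-up. A minimal usage sketch (the collect_metrics coroutine, the 30-second interval and the task name are made up for illustration; start_periodic_task and stop_periodic_task are the servicelib helpers changed above):

    import asyncio
    import datetime

    from servicelib.background_task import start_periodic_task, stop_periodic_task


    async def collect_metrics() -> None:  # hypothetical periodic job
        print("collecting metrics ...")


    async def main() -> None:
        wake_up = asyncio.Event()
        task = start_periodic_task(
            collect_metrics,
            interval=datetime.timedelta(seconds=30),
            task_name="metrics-collector",
            early_wake_up_event=wake_up,
        )
        await asyncio.sleep(1)  # the first run happens right away
        wake_up.set()  # forces the next run now instead of after ~30 s
        await asyncio.sleep(1)
        await stop_periodic_task(task)


    if __name__ == "__main__":
        asyncio.run(main())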

packages/service-library/tests/test_background_task.py

Lines changed: 64 additions & 14 deletions

@@ -6,7 +6,8 @@

 import asyncio
 import datetime
-from typing import AsyncIterator, Awaitable, Callable
+from collections.abc import AsyncIterator, Awaitable, Callable
+from typing import Final
 from unittest import mock

 import pytest

@@ -18,38 +19,50 @@
     stop_periodic_task,
 )

-_FAST_POLL_INTERVAL = 1
+_FAST_POLL_INTERVAL: Final[int] = 1
+_VERY_SLOW_POLL_INTERVAL: Final[int] = 100


 @pytest.fixture
 def mock_background_task(mocker: MockerFixture) -> mock.AsyncMock:
-    mocked_task = mocker.AsyncMock(return_value=None)
-    return mocked_task
+    return mocker.AsyncMock(return_value=None)


 @pytest.fixture
 def task_interval() -> datetime.timedelta:
     return datetime.timedelta(seconds=_FAST_POLL_INTERVAL)


-@pytest.fixture(params=[None, 1])
+@pytest.fixture
+def very_long_task_interval() -> datetime.timedelta:
+    return datetime.timedelta(seconds=_VERY_SLOW_POLL_INTERVAL)
+
+
+@pytest.fixture(params=[None, 1], ids=lambda x: f"stop-timeout={x}")
 def stop_task_timeout(request: pytest.FixtureRequest) -> float | None:
     return request.param


 @pytest.fixture
 async def create_background_task(
     faker: Faker, stop_task_timeout: float | None
-) -> AsyncIterator[Callable[[datetime.timedelta, Callable], Awaitable[asyncio.Task]]]:
+) -> AsyncIterator[
+    Callable[
+        [datetime.timedelta, Callable, asyncio.Event | None], Awaitable[asyncio.Task]
+    ]
+]:
     created_tasks = []

     async def _creator(
-        interval: datetime.timedelta, task: Callable[..., Awaitable]
+        interval: datetime.timedelta,
+        task: Callable[..., Awaitable],
+        early_wake_up_event: asyncio.Event | None,
     ) -> asyncio.Task:
         background_task = start_periodic_task(
             task,
             interval=interval,
             task_name=faker.pystr(),
+            early_wake_up_event=early_wake_up_event,
         )
         assert background_task
         created_tasks.append(background_task)

@@ -62,33 +75,69 @@ async def _creator(
     )


+@pytest.mark.parametrize(
+    "wake_up_event", [None, asyncio.Event], ids=lambda x: f"wake-up-event: {x}"
+)
 async def test_background_task_created_and_deleted(
     mock_background_task: mock.AsyncMock,
     task_interval: datetime.timedelta,
     create_background_task: Callable[
-        [datetime.timedelta, Callable], Awaitable[asyncio.Task]
+        [datetime.timedelta, Callable, asyncio.Event | None], Awaitable[asyncio.Task]
     ],
+    wake_up_event: Callable | None,
 ):
-    task = await create_background_task(
+    event = wake_up_event() if wake_up_event else None
+    _task = await create_background_task(
         task_interval,
         mock_background_task,
+        event,
     )
     await asyncio.sleep(5 * task_interval.total_seconds())
     mock_background_task.assert_called()
-    assert mock_background_task.call_count > 1
+    assert mock_background_task.call_count > 2
+
+
+async def test_background_task_wakes_up_early(
+    mock_background_task: mock.AsyncMock,
+    very_long_task_interval: datetime.timedelta,
+    create_background_task: Callable[
+        [datetime.timedelta, Callable, asyncio.Event | None], Awaitable[asyncio.Task]
+    ],
+):
+    wake_up_event = asyncio.Event()
+    _task = await create_background_task(
+        very_long_task_interval,
+        mock_background_task,
+        wake_up_event,
+    )
+    await asyncio.sleep(5 * _FAST_POLL_INTERVAL)
+    # now the task should have run only once
+    mock_background_task.assert_called_once()
+    await asyncio.sleep(5 * _FAST_POLL_INTERVAL)
+    mock_background_task.assert_called_once()
+    # this should wake up the task
+    wake_up_event.set()
+    await asyncio.sleep(5 * _FAST_POLL_INTERVAL)
+    mock_background_task.assert_called()
+    assert mock_background_task.call_count == 2
+    # no change: the task now waits again for a very long time
+    await asyncio.sleep(5 * _FAST_POLL_INTERVAL)
+    mock_background_task.assert_called()
+    assert mock_background_task.call_count == 2


 async def test_background_task_raises_restarts(
     mock_background_task: mock.AsyncMock,
     task_interval: datetime.timedelta,
     create_background_task: Callable[
-        [datetime.timedelta, Callable], Awaitable[asyncio.Task]
+        [datetime.timedelta, Callable, asyncio.Event | None], Awaitable[asyncio.Task]
     ],
 ):
     mock_background_task.side_effect = RuntimeError("pytest faked runtime error")
-    task = await create_background_task(
+    _task = await create_background_task(
         task_interval,
         mock_background_task,
+        None,
     )
     await asyncio.sleep(5 * task_interval.total_seconds())
     mock_background_task.assert_called()

@@ -99,13 +148,14 @@ async def test_background_task_correctly_cancels(
     mock_background_task: mock.AsyncMock,
     task_interval: datetime.timedelta,
     create_background_task: Callable[
-        [datetime.timedelta, Callable], Awaitable[asyncio.Task]
+        [datetime.timedelta, Callable, asyncio.Event | None], Awaitable[asyncio.Task]
     ],
 ):
     mock_background_task.side_effect = asyncio.CancelledError
-    task = await create_background_task(
+    _task = await create_background_task(
         task_interval,
         mock_background_task,
+        None,
     )
     await asyncio.sleep(5 * task_interval.total_seconds())
     # the task will be called once, and then stop

services/director-v2/src/simcore_service_director_v2/api/dependencies/scheduler.py

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 from fastapi import Depends, FastAPI, Request

 from ...core.settings import ComputationalBackendSettings
-from ...modules.comp_scheduler.base_scheduler import BaseCompScheduler
+from ...modules.comp_scheduler import BaseCompScheduler
 from . import get_app

services/director-v2/src/simcore_service_director_v2/api/routes/computations.py

Lines changed: 2 additions & 2 deletions

@@ -63,7 +63,7 @@
 from ...models.comp_runs import CompRunsAtDB, ProjectMetadataDict, RunMetadataDict
 from ...models.comp_tasks import CompTaskAtDB
 from ...modules.catalog import CatalogClient
-from ...modules.comp_scheduler.base_scheduler import BaseCompScheduler
+from ...modules.comp_scheduler import BaseCompScheduler
 from ...modules.db.repositories.clusters import ClustersRepository
 from ...modules.db.repositories.comp_pipelines import CompPipelinesRepository
 from ...modules.db.repositories.comp_runs import CompRunsRepository

@@ -288,7 +288,7 @@ async def _try_start_pipeline(
 )
 # NOTE: in case of a burst of calls to that endpoint, we might end up in a weird state.
 @run_sequentially_in_context(target_args=["computation.project_id"])
-async def create_computation(  # noqa: PLR0913 # pylint:disable=too-many-positional-arguments
+async def create_computation(  # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
     computation: ComputationCreate,
     request: Request,
     project_repo: Annotated[

services/director-v2/src/simcore_service_director_v2/core/application.py

Lines changed: 1 addition & 1 deletion

@@ -173,9 +173,9 @@ def init_app(settings: AppSettings | None = None) -> FastAPI:
     )
     if dynamic_scheduler_enabled or computational_backend_enabled:
         rabbitmq.setup(app)
+        redis.setup(app)

     if dynamic_scheduler_enabled:
-        redis.setup(app)
         dynamic_sidecar.setup(app)
         socketio.setup(app)
         notifier.setup(app)

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py

Lines changed: 14 additions & 2 deletions

@@ -1,3 +1,15 @@
-from .background_task import setup
+from fastapi import FastAPI

-__all__: tuple[str, ...] = ("setup",)
+from ._base_scheduler import BaseCompScheduler
+from ._task import on_app_shutdown, on_app_startup
+
+
+def setup(app: FastAPI):
+    app.add_event_handler("startup", on_app_startup(app))
+    app.add_event_handler("shutdown", on_app_shutdown(app))
+
+
+__all__: tuple[str, ...] = (
+    "setup",
+    "BaseCompScheduler",
+)
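
For context, the new setup() is meant to be registered on the FastAPI application like any other director-v2 module. A minimal sketch of the assumed wiring (the create_app wrapper is illustrative; in the real service this kind of setup happens inside init_app in core/application.py):

    from fastapi import FastAPI

    from simcore_service_director_v2.modules import comp_scheduler


    def create_app() -> FastAPI:
        app = FastAPI()
        # in the real service, other modules (rabbitmq, redis, db, ...) are set up first
        comp_scheduler.setup(app)  # adds the "startup"/"shutdown" event handlers
        return app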

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/base_scheduler.py renamed to services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py

File renamed without changes.

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/dask_scheduler.py renamed to services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py

Lines changed: 10 additions & 8 deletions

@@ -30,10 +30,6 @@
 from ...models.comp_runs import RunMetadataDict
 from ...models.comp_tasks import CompTaskAtDB
 from ...models.dask_subsystem import DaskClientTaskState
-from ...modules.dask_client import DaskClient, PublishedComputationTask
-from ...modules.dask_clients_pool import DaskClientsPool
-from ...modules.db.repositories.clusters import ClustersRepository
-from ...modules.db.repositories.comp_runs import CompRunsRepository
 from ...utils.comp_scheduler import Iteration, get_resource_tracking_run_id
 from ...utils.dask import (
     clean_task_output_and_log_files_if_invalid,

@@ -48,8 +44,12 @@
     publish_service_stopped_metrics,
 )
 from ..clusters_keeper import get_or_create_on_demand_cluster
+from ..dask_client import DaskClient, PublishedComputationTask
+from ..dask_clients_pool import DaskClientsPool
+from ..db.repositories.clusters import ClustersRepository
+from ..db.repositories.comp_runs import CompRunsRepository
 from ..db.repositories.comp_tasks import CompTasksRepository
-from .base_scheduler import BaseCompScheduler, ScheduledPipelineParams
+from ._base_scheduler import BaseCompScheduler, ScheduledPipelineParams

 _logger = logging.getLogger(__name__)


@@ -158,9 +158,11 @@ async def _get_tasks_status(
         for dask_task_state, task in zip(tasks_statuses, tasks, strict=True):
             if dask_task_state is DaskClientTaskState.PENDING_OR_STARTED:
                 running_states += [
-                    RunningState.STARTED
-                    if task.progress is not None
-                    else RunningState.PENDING
+                    (
+                        RunningState.STARTED
+                        if task.progress is not None
+                        else RunningState.PENDING
+                    )
                 ]
             else:
                 running_states += [

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/factory.py renamed to services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py

Lines changed: 8 additions & 8 deletions

@@ -2,16 +2,16 @@

 from fastapi import FastAPI
 from models_library.clusters import DEFAULT_CLUSTER_ID
-from simcore_service_director_v2.core.settings import AppSettings

 from ...core.errors import ConfigurationError
+from ...core.settings import AppSettings
 from ...models.comp_runs import CompRunsAtDB
-from ...modules.dask_clients_pool import DaskClientsPool
-from ...modules.rabbitmq import get_rabbitmq_client, get_rabbitmq_rpc_client
 from ...utils.comp_scheduler import SCHEDULED_STATES
+from ..dask_clients_pool import DaskClientsPool
 from ..db.repositories.comp_runs import CompRunsRepository
-from .base_scheduler import BaseCompScheduler, ScheduledPipelineParams
-from .dask_scheduler import DaskScheduler
+from ..rabbitmq import get_rabbitmq_client, get_rabbitmq_rpc_client
+from ._base_scheduler import BaseCompScheduler, ScheduledPipelineParams
+from ._dask_scheduler import DaskScheduler

 logger = logging.getLogger(__name__)


@@ -43,9 +43,9 @@ async def create_from_db(app: FastAPI) -> BaseCompScheduler:
         db_engine=db_engine,
         scheduled_pipelines={
             (r.user_id, r.project_uuid, r.iteration): ScheduledPipelineParams(
-                cluster_id=r.cluster_id
-                if r.cluster_id is not None
-                else DEFAULT_CLUSTER_ID,
+                cluster_id=(
+                    r.cluster_id if r.cluster_id is not None else DEFAULT_CLUSTER_ID
+                ),
                 run_metadata=r.metadata,
                 mark_for_cancellation=False,
                 use_on_demand_clusters=r.use_on_demand_clusters,

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_task.py (new file)

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+import datetime
+import logging
+from collections.abc import Callable, Coroutine
+from typing import Any, Final
+
+from fastapi import FastAPI
+from servicelib.background_task import start_periodic_task, stop_periodic_task
+from servicelib.logging_utils import log_context
+from servicelib.redis import RedisClientsManager
+from servicelib.redis_utils import exclusive
+from settings_library.redis import RedisDatabase
+
+from . import _scheduler_factory
+
+_logger = logging.getLogger(__name__)
+
+_COMPUTATIONAL_SCHEDULER_INTERVAL: Final[datetime.timedelta] = datetime.timedelta(
+    seconds=5
+)
+_TASK_NAME: Final[str] = "computational services scheduler"
+
+
+def on_app_startup(app: FastAPI) -> Callable[[], Coroutine[Any, Any, None]]:
+    async def start_scheduler() -> None:
+        with log_context(
+            _logger, level=logging.INFO, msg="starting computational scheduler"
+        ):
+            redis_clients_manager: RedisClientsManager = app.state.redis_clients_manager
+            lock_key = f"{app.title}:computational_scheduler"
+            app.state.scheduler = scheduler = await _scheduler_factory.create_from_db(
+                app
+            )
+            app.state.computational_scheduler_task = start_periodic_task(
+                exclusive(
+                    redis_clients_manager.client(RedisDatabase.LOCKS),
+                    lock_key=lock_key,
+                )(scheduler.schedule_all_pipelines),
+                interval=_COMPUTATIONAL_SCHEDULER_INTERVAL,
+                task_name=_TASK_NAME,
+                early_wake_up_event=scheduler.wake_up_event,
+            )
+
+    return start_scheduler
+
+
+def on_app_shutdown(app: FastAPI) -> Callable[[], Coroutine[Any, Any, None]]:
+    async def stop_scheduler() -> None:
+        await stop_periodic_task(app.state.computational_scheduler_task)
+
+    return stop_scheduler
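
Two design points in this new task module: schedule_all_pipelines is wrapped with exclusive(...) on the Redis LOCKS database, so when several director-v2 replicas run, only one of them executes a scheduling pass at a time; and the scheduler's wake_up_event is passed as early_wake_up_event, so a pass can be triggered on demand instead of waiting out the 5-second interval. A hypothetical helper to illustrate the second point (not part of this commit):

    from fastapi import FastAPI


    def request_scheduler_wake_up(app: FastAPI) -> None:
        """Ask the computational scheduler to run a pass now (illustrative only)."""
        scheduler = app.state.scheduler  # stored by on_app_startup() above
        # Setting the event makes SleepUsingAsyncioEvent return immediately,
        # so the periodic task skips the remainder of the 5 s wait.
        scheduler.wake_up_event.set()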
