Skip to content

Commit a954c4a

Browse files
committed
added life cycle plugins
1 parent 7407140 commit a954c4a

File tree

4 files changed

+114
-1
lines changed

4 files changed

+114
-1
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import logging
2+
from collections.abc import Awaitable
3+
from typing import Any
4+
5+
from dask.typing import Key
6+
from distributed import Worker, WorkerPlugin
7+
from distributed.scheduler import TaskStateState
8+
from servicelib.logging_utils import log_context
9+
10+
_logger = logging.getLogger(__name__)
11+
12+
13+
class TaskLifecycleWorkerPlugin(WorkerPlugin):
14+
def __init__(self) -> None:
15+
with log_context(
16+
_logger,
17+
logging.INFO,
18+
"TaskLifecycleWorkerPlugin init",
19+
):
20+
self.worker = None
21+
22+
def setup(self, worker: Worker) -> Awaitable[None]:
23+
async def _() -> None:
24+
with log_context(
25+
_logger,
26+
logging.INFO,
27+
"TaskLifecycleWorkerPlugin start",
28+
):
29+
self.worker = worker
30+
31+
return _()
32+
33+
def transition(
34+
self,
35+
key: Key,
36+
start: TaskStateState,
37+
finish: TaskStateState,
38+
**kwargs: Any,
39+
):
40+
# Start state: one of released, waiting, processing, memory, error
41+
with log_context(
42+
_logger,
43+
logging.INFO,
44+
f"Task {key} transition from {start} to {finish}",
45+
):
46+
assert self.worker # nosec
47+
self.worker.log_event(
48+
f"task-lifecycle-{key}",
49+
{
50+
"key": key,
51+
"worker": kwargs.get("worker"),
52+
"start": start,
53+
"finish": finish,
54+
},
55+
)

services/dask-sidecar/src/simcore_service_dask_sidecar/worker.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from .computational_sidecar.core import ComputationalSidecar
1919
from .rabbitmq_plugin import RabbitMQPlugin
2020
from .settings import ApplicationSettings
21+
from .task_life_cycle_worker_plugin import (
22+
TaskLifecycleWorkerPlugin,
23+
)
2124
from .utils.dask import (
2225
TaskPublisher,
2326
get_current_task_resources,
@@ -78,6 +81,11 @@ async def dask_setup(worker: distributed.Worker) -> None:
7881
except Exception:
7982
await worker.close(reason="failed to add RabbitMQ plugin")
8083
raise
84+
try:
85+
await worker.plugin_add(TaskLifecycleWorkerPlugin(), catch_errors=False)
86+
except Exception:
87+
await worker.close(reason="failed to add TaskLifecycleWorkerPlugin")
88+
raise
8189

8290
print_dask_sidecar_banner()
8391

services/dask-sidecar/tests/unit/test_rabbitmq_plugin.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
# pylint: disable=unused-variable
44
# pylint: disable=no-member
55

6+
import asyncio
7+
from unittest import mock
8+
69
import distributed
10+
import pytest
11+
from pytest_mock import MockerFixture
712

813
# Selection of core and tool services started in this swarm fixture (integration)
914
pytest_simcore_core_services_selection = [
@@ -14,3 +19,19 @@
1419

1520

1621
def test_rabbitmq_plugin_initializes(dask_client: distributed.Client): ...
22+
23+
24+
@pytest.fixture
25+
def erroring_rabbitmq_plugin(mocker: MockerFixture) -> mock.Mock:
26+
return mocker.patch(
27+
"simcore_service_dask_sidecar.rabbitmq_plugin.RabbitMQPlugin",
28+
autospec=True,
29+
side_effect=RuntimeError("Pytest: RabbitMQ plugin initialization failed"),
30+
)
31+
32+
33+
async def test_dask_worker_closes_if_plugin_fails_on_start(
34+
erroring_rabbitmq_plugin: mock.Mock,
35+
local_cluster: distributed.LocalCluster,
36+
):
37+
await asyncio.sleep(10)
Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,38 @@
1+
import time
2+
13
import distributed
4+
import pytest
25

36
pytest_simcore_core_services_selection = [
47
"rabbit",
58
]
69

710

811
def test_scheduler(dask_client: distributed.Client) -> None:
9-
assert True
12+
def _some_task() -> int:
13+
time.sleep(1)
14+
return 2
15+
16+
def _some_failing_task() -> None:
17+
time.sleep(1)
18+
msg = "Some error"
19+
raise RuntimeError(msg)
20+
21+
future = dask_client.submit(_some_task)
22+
assert future.result(timeout=10) == 2
23+
events = dask_client.get_events(f"task-lifecycle-{future.key}")
24+
print("XXXX received events:")
25+
assert events
26+
assert isinstance(events, tuple)
27+
for event in events:
28+
print(f"\t{event}")
29+
30+
future = dask_client.submit(_some_failing_task)
31+
with pytest.raises(RuntimeError):
32+
future.result(timeout=10)
33+
events = dask_client.get_events(f"task-lifecycle-{future.key}")
34+
print("XXXX received events:")
35+
assert events
36+
assert isinstance(events, tuple)
37+
for event in events:
38+
print(f"\t{event}")

0 commit comments

Comments
 (0)