diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py
index 233fa0d969..11391df2b6 100644
--- a/distributed/deploy/cluster.py
+++ b/distributed/deploy/cluster.py
@@ -252,7 +252,8 @@ def _update_worker_status(self, op, msg):
             self.scheduler_info["workers"].update(workers)
             self.scheduler_info.update(msg)
         elif op == "remove":
-            del self.scheduler_info["workers"][msg]
+            worker = msg["worker"]
+            self.scheduler_info["workers"].pop(worker, None)
         else:  # pragma: no cover
             raise ValueError("Invalid op", op, msg)
 
diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py
index 7da310a2e6..113847a2d0 100644
--- a/distributed/deploy/spec.py
+++ b/distributed/deploy/spec.py
@@ -397,12 +397,12 @@ async def _correct_state_internal(self) -> None:
 
     def _update_worker_status(self, op, msg):
         if op == "remove":
-            name = self.scheduler_info["workers"][msg]["name"]
+            name = msg["name"]
 
             def f():
                 if (
                     name in self.workers
-                    and msg not in self.scheduler_info["workers"]
+                    and msg["worker"] not in self.scheduler_info["workers"]
                     and not any(
                         d["name"] == name
                         for d in self.scheduler_info["workers"].values()
diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py
index 2c74e494e6..2e37e36afe 100644
--- a/distributed/diagnostics/plugin.py
+++ b/distributed/diagnostics/plugin.py
@@ -11,7 +11,7 @@
 import tempfile
 import uuid
 import zipfile
-from collections.abc import Awaitable
+from collections.abc import Awaitable, Hashable
 from typing import TYPE_CHECKING, Any, Callable, ClassVar
 
 from dask.typing import Key
@@ -190,7 +190,13 @@ def add_worker(self, scheduler: Scheduler, worker: str) -> None | Awaitable[None
         """
 
     def remove_worker(
-        self, scheduler: Scheduler, worker: str, *, stimulus_id: str, **kwargs: Any
+        self,
+        scheduler: Scheduler,
+        worker: str,
+        *,
+        name: Hashable,
+        stimulus_id: str,
+        **kwargs: Any,
     ) -> None | Awaitable[None]:
         """Run when a worker leaves the cluster
 
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 64c943f090..642d897f98 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -5563,7 +5563,10 @@ async def remove_worker(
             try:
                 try:
                     result = plugin.remove_worker(
-                        scheduler=self, worker=address, stimulus_id=stimulus_id
+                        scheduler=self,
+                        worker=address,
+                        name=ws.name,
+                        stimulus_id=stimulus_id,
                     )
                 except TypeError:
                     parameters = inspect.signature(plugin.remove_worker).parameters
@@ -9410,9 +9413,15 @@ def add_worker(self, scheduler: Scheduler, worker: str) -> None:
         except CommClosedError:
             scheduler.remove_plugin(name=self.name)
 
-    def remove_worker(self, scheduler: Scheduler, worker: str, **kwargs: Any) -> None:
+    def remove_worker(
+        self, scheduler: Scheduler, worker: str, name: Hashable, **kwargs: Any
+    ) -> None:
         try:
-            self.bcomm.send(["remove", worker])
+            msg = {
+                "worker": worker,
+                "name": name,
+            }
+            self.bcomm.send(["remove", msg])
         except CommClosedError:
             scheduler.remove_plugin(name=self.name)
 
diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py
index 565f3d9792..ac0eba4a68 100644
--- a/distributed/tests/test_stress.py
+++ b/distributed/tests/test_stress.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import copy
 import random
 from contextlib import suppress
 from operator import add
@@ -11,7 +12,7 @@
 
 from dask import delayed
 
-from distributed import Client, Nanny, Worker, wait
+from distributed import Client, Nanny, Scheduler, SpecCluster, Worker, wait
 from distributed.chaos import KillWorker
 from distributed.compatibility import WINDOWS
 from distributed.metrics import time
@@ -337,3 +338,59 @@ async def test_chaos_rechunk(c, s, *workers):
         await asyncio.sleep(0.1)
 
     await z.cancel()
+
+
+@pytest.mark.slow
+def test_stress_scale(monkeypatch):
+    cluster_kwargs = {}
+    client_kwargs = {
+        "set_as_default": False,
+    }
+    # Avoid contention on the default dashboard port 8787
+    scheduler_spec = {
+        "cls": Scheduler,
+        "options": {"dashboard": False, "dashboard_address": 9876},
+    }
+    spec = {}
+    template = {"cls": Nanny}
+    N = 5
+    for i in range(N):
+        spec[f"worker-{i}"] = copy.copy(template)
+    cluster = SpecCluster(
+        scheduler=scheduler_spec,
+        workers=spec,
+        worker=template,  # <- template for newly scaled-up workers
+        **cluster_kwargs,
+    )
+    try:
+        # Delay the processing of worker status messages, letting other
+        # async code run in the meantime, by monkeypatching the comm's
+        # read() with an asyncio.sleep(). This slight delay greatly
+        # increases the likelihood of discrepancies in worker inventory
+        # tracking. The chosen duration is an empirical compromise: long
+        # enough to cause discrepancies, short enough that messages still
+        # arrive in time.
+        comm = cluster._watch_worker_status_comm
+        old_read = comm.read
+
+        async def new_read():
+            res = await old_read()
+            await asyncio.sleep(0.2)
+            return res
+
+        monkeypatch.setattr(comm, "read", new_read)
+
+        client = Client(cluster, **client_kwargs)
+        client.wait_for_workers(N)
+
+        # scale down:
+        client.cluster.scale(n=1)
+
+        # scale up again:
+        client.cluster.scale(n=len(spec))
+        client.wait_for_workers(len(spec))
+
+        # shutdown:
+        client.close()
+    finally:
+        cluster.close(timeout=30)
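
Note for downstream code: SchedulerPlugin.remove_worker now also receives the
worker's spec name as a keyword argument (the TypeError fallback in
scheduler.py above still accepts plugins written against the old signature).
A minimal sketch of a plugin written against the new hook; the plugin class,
its name, and the print-based logging are illustrative, not part of this diff:

    from collections.abc import Hashable
    from typing import Any

    from distributed import Scheduler
    from distributed.diagnostics.plugin import SchedulerPlugin


    class WorkerDepartureLogger(SchedulerPlugin):
        """Hypothetical plugin that logs departing workers by spec name."""

        name = "worker-departure-logger"

        def remove_worker(
            self,
            scheduler: Scheduler,
            worker: str,
            *,
            name: Hashable,
            stimulus_id: str,
            **kwargs: Any,
        ) -> None:
            # `worker` is the address; `name` is the spec name, which is
            # what SpecCluster keys its worker inventory by.
            print(f"worker {name!r} at {worker} left ({stimulus_id})")

Such a plugin would be registered with the scheduler in the usual way, e.g.
via Client.register_scheduler_plugin or Scheduler.add_plugin.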