diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 7da310a2e6..963e526316 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -7,10 +7,10 @@ import logging import math import weakref -from collections.abc import Awaitable, Generator +from collections.abc import Awaitable, Generator, Iterable from contextlib import suppress from inspect import isawaitable -from typing import TYPE_CHECKING, Any, ClassVar, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast from tornado import gen from tornado.ioloop import IOLoop @@ -137,16 +137,53 @@ class does handle all of the logic around asynchronously cleanly setting up and tearing things down at the right times. Hopefully it can form a base for other more user-centric classes. + Terminology + ----------- + **Spec name**: The string key in the ``worker_spec`` dictionary (e.g., ``"0"``, + ``"my-worker"``). This identifies a worker specification entry. + + **Worker name**: The actual name a worker reports to the scheduler (e.g., ``"0"``, + ``"0-0"``, ``"0-1"``). This is what appears in ``scheduler.workers``. + + For **regular workers**: spec name == worker name (one-to-one mapping) + For **grouped workers**: one spec name → multiple worker names (one-to-many mapping) + + **Important**: The ``self.workers`` dict is always keyed by **spec names** (not worker + names), mapping to Worker class instances. When accessing this dict with + a worker name from the scheduler, you must first map it to a spec name using + ``_worker_name_to_spec_name()``. + + Grouped Workers + --------------- + A single spec entry can generate multiple Dask workers by including a ``"group"`` + element with suffixes. This is useful for: + - HPC systems (e.g., SLURM) where multiple processes are allocated together + - Any worker class that manages multiple workers as a unit (e.g., MultiWorker) + + >>> cluster.worker_spec = { + ... "0": {"cls": MultiWorker, "options": {"processes": 3}, "group": ["-0", "-1", "-2"]}, + ... "1": {"cls": MultiWorker, "options": {"processes": 2}, "group": ["-0", "-1"]} + ... } + + The scheduler sees individual workers with concatenated names: + + >>> [ws.name for ws in cluster.scheduler.workers.values()] + ["0-0", "0-1", "0-2", "1-0", "1-1"] + + When any worker in a group fails, the entire spec is removed so the group + can be recreated as a unit (important for HPC where the whole allocation fails). + Parameters ---------- - workers: dict - A dictionary mapping names to worker classes and their specifications - See example below + workers: dict[str, dict], optional + A dictionary mapping spec names (strings) to worker specifications. + Each worker spec is a dict with 'cls' and optionally 'options' and 'group'. + Spec names must be strings. scheduler: dict, optional - A similar mapping for a scheduler - worker: dict - A specification of a single worker. - This is used for any new workers that are created. + A specification for the scheduler with 'cls' and 'options' keys + worker: dict, optional + A worker specification template used when calling scale(). + This template is used to auto-generate new worker specs. asynchronous: bool If this is intended to be used directly within an event loop with async/await @@ -161,17 +198,17 @@ class does handle all of the logic around asynchronously cleanly setting up Examples -------- - To create a SpecCluster you specify how to set up a Scheduler and Workers + To create a SpecCluster you specify worker specifications and a scheduler spec >>> from dask.distributed import Scheduler, Worker, Nanny >>> scheduler = {'cls': Scheduler, 'options': {"dashboard_address": ':8787'}} - >>> workers = { + >>> worker_spec = { ... 'my-worker': {"cls": Worker, "options": {"nthreads": 1}}, ... 'my-nanny': {"cls": Nanny, "options": {"nthreads": 2}}, ... } - >>> cluster = SpecCluster(scheduler=scheduler, workers=workers) + >>> cluster = SpecCluster(scheduler=scheduler, workers=worker_spec) - The worker spec is stored as the ``.worker_spec`` attribute + The worker specs are stored in the ``.worker_spec`` attribute >>> cluster.worker_spec { @@ -179,8 +216,8 @@ class does handle all of the logic around asynchronously cleanly setting up 'my-nanny': {"cls": Nanny, "options": {"nthreads": 2}}, } - While the instantiation of this spec is stored in the ``.workers`` - attribute + The actual Worker instances created from these specs are stored in the + ``.workers`` attribute >>> cluster.workers { @@ -188,53 +225,36 @@ class does handle all of the logic around asynchronously cleanly setting up 'my-nanny': } - Should the spec change, we can await the cluster or call the - ``._correct_state`` method to align the actual state to the specified - state. + Should the worker_spec change, we can await the cluster or call the + ``._correct_state`` method to align the actual Worker instances to the + specified state. - We can also ``.scale(...)`` the cluster, which adds new workers of a given - form. + We can also ``.scale(...)`` the cluster, which adds new worker specs using + the template provided via the ``worker`` parameter. - >>> worker = {'cls': Worker, 'options': {}} - >>> cluster = SpecCluster(scheduler=scheduler, worker=worker) + >>> worker_template = {'cls': Worker, 'options': {}} + >>> cluster = SpecCluster(scheduler=scheduler, worker=worker_template) >>> cluster.worker_spec {} >>> cluster.scale(3) >>> cluster.worker_spec { - 0: {'cls': Worker, 'options': {}}, - 1: {'cls': Worker, 'options': {}}, - 2: {'cls': Worker, 'options': {}}, + "0": {'cls': Worker, 'options': {}}, + "1": {'cls': Worker, 'options': {}}, + "2": {'cls': Worker, 'options': {}}, } Note that above we are using the standard ``Worker`` and ``Nanny`` classes, however in practice other classes could be used that handle resource - management like ``KubernetesPod`` or ``SLURMJob``. The spec does not need - to conform to the expectations of the standard Dask Worker class. It just - needs to be called with the provided options, support ``__await__`` and - ``close`` methods and the ``worker_address`` property.. - - Also note that uniformity of the specification is not required. Other API - could be added externally (in subclasses) that adds workers of different - specifications into the same dictionary. - - If a single entry in the spec will generate multiple dask workers then - please provide a `"group"` element to the spec, that includes the suffixes - that will be added to each name (this should be handled by your worker - class). - - >>> cluster.worker_spec - { - 0: {"cls": MultiWorker, "options": {"processes": 3}, "group": ["-0", "-1", -2"]} - 1: {"cls": MultiWorker, "options": {"processes": 2}, "group": ["-0", "-1"]} - } - - These suffixes should correspond to the names used by the workers when - they deploy. - - >>> [ws.name for ws in cluster.scheduler.workers.values()] - ["0-0", "0-1", "0-2", "1-0", "1-1"] + management like ``KubernetesPod`` or ``SLURMJob``. Worker specs do not need + to conform to the expectations of the standard Dask Worker class. They just + need to be called with the provided options, support ``__await__`` and + ``close`` methods and the ``worker_address`` property. + + Also note that uniformity of worker specs is not required. Other API + could be added externally (in subclasses) that adds worker specs of different + types into the same worker_spec dictionary. """ _instances: ClassVar[weakref.WeakSet[SpecCluster]] = weakref.WeakSet() @@ -260,10 +280,10 @@ def __init__( self._created = weakref.WeakSet() self.scheduler_spec = copy.copy(scheduler) - self.worker_spec = copy.copy(workers) or {} - self.new_spec = copy.copy(worker) + self.worker_spec: dict[str, dict[str, Any]] = copy.copy(workers) or {} + self.new_spec: dict[str, Any] | None = copy.copy(worker) self.scheduler = None - self.workers = {} + self.workers: dict[str, Worker | Nanny] = {} self._i = 0 self.security = security or Security() self._futures = set() @@ -356,7 +376,21 @@ async def _correct_state_internal(self) -> None: to_close = set(self.workers) - set(self.worker_spec) if to_close: if self.scheduler.status == Status.running: - await self.scheduler_comm.retire_workers(workers=list(to_close)) + # Map spec names to worker names for retirement + workers_to_retire: list[str] = [] + for spec_name in to_close: + worker_names = self._spec_name_to_worker_names(spec_name) + # Only retire workers that actually exist in the scheduler + scheduler_worker_names = { + w["name"] for w in self.scheduler_info["workers"].values() + } + workers_to_retire.extend(worker_names & scheduler_worker_names) + + if workers_to_retire: + await self.scheduler_comm.retire_workers( + workers=workers_to_retire + ) + tasks = [ asyncio.create_task(self.workers[w].close()) for w in to_close @@ -400,16 +434,34 @@ def _update_worker_status(self, op, msg): name = self.scheduler_info["workers"][msg]["name"] def f(): - if ( - name in self.workers - and msg not in self.scheduler_info["workers"] - and not any( - d["name"] == name - for d in self.scheduler_info["workers"].values() - ) + # Find the spec this worker belongs to + spec_name = self._worker_name_to_spec_name(name) + + # Check if worker/spec is still missing (not re-registered) + if msg not in self.scheduler_info["workers"] and not any( + d["name"] == name for d in self.scheduler_info["workers"].values() ): - self._futures.add(asyncio.ensure_future(self.workers[name].close())) - del self.workers[name] + # For regular workers: close and remove from self.workers + if spec_name and spec_name == name and name in self.workers: + self._futures.add( + asyncio.ensure_future(self.workers[name].close()) + ) + del self.workers[name] + + # For grouped workers: remove the entire spec + if spec_name and spec_name in self.worker_spec: + spec = self.worker_spec[spec_name] + if "group" in spec: + # Close the MultiWorker instance + if spec_name in self.workers: + self._futures.add( + asyncio.ensure_future( + self.workers[spec_name].close() + ) + ) + del self.workers[spec_name] + # Remove the spec so adaptive can recreate it + del self.worker_spec[spec_name] delay = parse_timedelta( dask.config.get("distributed.deploy.lost-worker-timeout") @@ -486,17 +538,28 @@ async def __aenter__(self): raise def _threads_per_worker(self) -> int: - """Return the number of threads per worker for new workers""" + """Return the number of threads per worker for new workers. + + For grouped workers, this returns the threads per individual worker + (total spec threads divided by number of workers in the group). + """ if not self.new_spec: # pragma: no cover raise ValueError("To scale by cores= you must specify cores per worker") for name in ["nthreads", "ncores", "threads", "cores"]: with suppress(KeyError): - return self.new_spec["options"][name] + total_threads = self.new_spec["options"][name] + # For grouped workers, divide by number of workers in the group + workers_per_spec = self._workers_per_spec(self.new_spec) + return total_threads // workers_per_spec raise RuntimeError("unreachable") def _memory_per_worker(self) -> int: - """Return the memory limit per worker for new workers""" + """Return the memory limit per worker for new workers. + + For grouped workers, this returns the memory per individual worker + (total spec memory divided by number of workers in the group). + """ if not self.new_spec: # pragma: no cover raise ValueError( "to scale by memory= your worker definition must include a " @@ -505,85 +568,349 @@ def _memory_per_worker(self) -> int: for name in ["memory_limit", "memory"]: with suppress(KeyError): - return parse_bytes(self.new_spec["options"][name]) + total_memory = parse_bytes(self.new_spec["options"][name]) + # For grouped workers, divide by number of workers in the group + workers_per_spec = self._workers_per_spec(self.new_spec) + return total_memory // workers_per_spec raise ValueError( "to use scale(memory=...) your worker definition must include a " "memory_limit definition" ) + def _count_workers_in_specs(self) -> int: + """Count total number of workers across all specs. + + For regular workers, each spec = 1 worker. + For grouped workers, each spec = number of group members. + + Returns + ------- + int + Total number of workers that would be created by current worker_spec + """ + total = 0 + for _spec_name, spec in self.worker_spec.items(): + if "group" in spec: + total += len(spec["group"]) + else: + total += 1 + return total + + def _workers_per_spec(self, spec: dict[str, Any]) -> int: + """Get number of workers a single spec will create. + + Parameters + ---------- + spec : dict + Worker specification dict + + Returns + ------- + int + Number of workers this spec creates (1 for regular, len(group) for grouped) + """ + if "group" in spec: + return len(spec["group"]) + return 1 + def scale(self, n=0, memory=None, cores=None): + """Scale cluster to a target number of workers or resource level. + + Parameters + ---------- + n : int, optional + Target maximum number of workers. For grouped workers, rounds down to + complete specs. Default is 0. + memory : str, optional + Target total memory (e.g., "10 GB"). Scales conservatively - will NOT + exceed this limit. For grouped workers, rounds down to complete specs. + cores : int, optional + Target total cores/threads. Scales conservatively - will NOT exceed + this limit. For grouped workers, rounds down to complete specs. + + Notes + ----- + **All scaling is conservative (rounds down to number of complete specs):** + - Ensures limits are not exceeded (except special case below) + - Prevents resource overcommitment and surprises + - For grouped workers, may get fewer workers than requested + + **Special case - minimum viability:** + - If target > 0 and current workers = 0, creates at least 1 spec + - Prevents deadlock where no workers can be created + - Example: `scale(1)` with 2-worker specs → 1 spec = 2 workers (exceeds target!) + + **Examples:** + - `scale(5)` with 2-worker specs → 2 specs = 4 workers (not 5) + - `scale(1)` with 2-worker specs → 1 spec = 2 workers (special case!) + - `scale(memory="6GB")` with 4GB/spec → 1 spec = 4GB (not 8GB) + - `scale(cores=10)` with 4 cores/spec → 2 specs = 8 cores (not 12) + + **Why conservative?** + - User expectation: "at most N" not "at least N" + - Safety: prevents OOM, CPU oversubscription + - Consistency: all parameters use same rounding + + Examples + -------- + Regular workers (1 worker per spec): + >>> cluster.scale(5) # Creates 5 specs = 5 workers + + Grouped workers (2 workers per spec): + >>> cluster.scale(5) # Creates 2 specs = 4 workers (conservative) + >>> cluster.scale(6) # Creates 3 specs = 6 workers (exact match) + >>> cluster.scale(memory="6GB") # With 4GB/spec: creates 1 spec = 4GB + """ if memory is not None: - n = max(n, int(math.ceil(parse_bytes(memory) / self._memory_per_worker()))) + # For grouped workers, scale by complete specs to avoid exceeding limit + # Use floor division to be conservative (never exceed requested memory) + if self.new_spec and "group" in self.new_spec: + memory_per_spec = self._memory_per_worker() * self._workers_per_spec( + self.new_spec + ) + target_specs = int(parse_bytes(memory) // memory_per_spec) + n = max(n, target_specs * self._workers_per_spec(self.new_spec)) + else: + # Regular workers: use ceiling (match old behavior) + n = max( + n, int(math.ceil(parse_bytes(memory) / self._memory_per_worker())) + ) if cores is not None: - n = max(n, int(math.ceil(cores / self._threads_per_worker()))) + # For grouped workers, scale by complete specs to avoid exceeding limit + # Use floor division to be conservative (never exceed requested cores) + if self.new_spec and "group" in self.new_spec: + cores_per_spec = self._threads_per_worker() * self._workers_per_spec( + self.new_spec + ) + target_specs = int(cores // cores_per_spec) + n = max(n, target_specs * self._workers_per_spec(self.new_spec)) + else: + # Regular workers: use ceiling (match old behavior) + n = max(n, int(math.ceil(cores / self._threads_per_worker()))) - if len(self.worker_spec) > n: - not_yet_launched = set(self.worker_spec) - { - v["name"] for v in self.scheduler_info["workers"].values() - } - while len(self.worker_spec) > n and not_yet_launched: - del self.worker_spec[not_yet_launched.pop()] + # n is the target number of workers (not specs) + # For grouped workers, we need to scale by specs, where each spec creates multiple workers - while len(self.worker_spec) > n: - self.worker_spec.popitem() + current_worker_count = self._count_workers_in_specs() + # Scale down if we have too many workers + if current_worker_count > n: + # Build set of launched spec names by mapping worker names back to spec names + scheduler_worker_names = { + v["name"] for v in self.scheduler_info["workers"].values() + } + launched_spec_names = set() + for worker_name in scheduler_worker_names: + spec_name = self._worker_name_to_spec_name(worker_name) + if spec_name: + launched_spec_names.add(spec_name) + + not_yet_launched = set(self.worker_spec) - launched_spec_names + + # Remove unlaunched specs first + while current_worker_count > n and not_yet_launched: + spec_name = not_yet_launched.pop() + spec = self.worker_spec[spec_name] + workers_in_spec = self._workers_per_spec(spec) + del self.worker_spec[spec_name] + current_worker_count -= workers_in_spec + + # Remove launched specs if still over target + while current_worker_count > n and self.worker_spec: + spec_name, spec = self.worker_spec.popitem() + workers_in_spec = self._workers_per_spec(spec) + current_worker_count -= workers_in_spec + + # Scale up if we need more workers if self.status not in (Status.closing, Status.closed): - while len(self.worker_spec) < n: - self.worker_spec.update(self.new_worker_spec()) + while current_worker_count < n: + # For grouped workers, check if adding next spec would exceed target + # This ensures we never exceed the requested worker count (conservative scaling) + workers_in_next_spec = ( + self._workers_per_spec(self.new_spec) if self.new_spec else 1 + ) + if current_worker_count + workers_in_next_spec > n: + # Don't add spec if it would exceed target + # Exception: if we have 0 workers and n > 0, add at least one spec + # This ensures we can always create workers when requested (avoids deadlock) + if current_worker_count == 0 and n > 0: + pass # Add the spec even if it exceeds target + else: + break + + new_spec_dict = self.new_worker_spec() + self.worker_spec.update(new_spec_dict) + # Get the spec we just added to count its workers + spec_name = list(new_spec_dict.keys())[0] + spec = new_spec_dict[spec_name] + workers_in_spec = self._workers_per_spec(spec) + current_worker_count += workers_in_spec self.loop.add_callback(self._correct_state) if self.asynchronous: return NoOpAwaitable() - def _new_worker_name(self, worker_number): - """Returns new worker name. + def _new_spec_name(self, spec_number: int) -> str: + """Returns new spec name (key for worker_spec dict). + + This generates a spec name for auto-created worker specs. For regular + workers, the spec name will also be the worker name. For grouped workers, + the spec name is the prefix, and actual worker names will have suffixes + appended (e.g., spec name "0" with group ["-0", "-1"] creates workers + "0-0" and "0-1"). + + This can be overridden in SpecCluster derived classes to customize spec + naming. - This can be overridden in SpecCluster derived classes to customise the - worker names. + Parameters + ---------- + spec_number : int + The numeric identifier for this spec (typically from self._i) + + Returns + ------- + str + The spec name to use as a key in worker_spec dict """ - return worker_number + return str(spec_number) + + def _spec_name_to_worker_names(self, spec_name: str) -> set[str]: + """Convert a spec name to the set of worker names it generates. - def new_worker_spec(self): - """Return name and spec for the next worker + For regular workers, the spec name equals the worker name (1:1 mapping). + For grouped workers, one spec name maps to multiple worker names (1:many). + + Parameters + ---------- + spec_name : str + The spec name (key in worker_spec dict) Returns ------- - d: dict mapping names to worker specs + set[str] + Set of worker names the scheduler will see for this spec + + Examples + -------- + Regular worker (no "group" key): + >>> cluster.worker_spec = {"0": {"cls": Worker, "options": {}}} + >>> cluster._spec_name_to_worker_names("0") + {"0"} + + Grouped worker (has "group" key): + >>> cluster.worker_spec = { + ... "0": {"cls": MultiWorker, "options": {}, "group": ["-0", "-1", "-2"]} + ... } + >>> cluster._spec_name_to_worker_names("0") + {"0-0", "0-1", "0-2"} + """ + if spec_name not in self.worker_spec: + return set() + + spec = self.worker_spec[spec_name] + if "group" in spec: + # Grouped worker: concatenate spec_name with each suffix + return {spec_name + suffix for suffix in spec["group"]} + else: + # Regular worker: spec name == worker name + return {spec_name} + + def _worker_name_to_spec_name(self, worker_name: str) -> str | None: + """Convert a worker name to its corresponding spec name. + + For regular workers, the worker name equals the spec name. + For grouped workers, extract the spec name prefix from the worker name. + + Parameters + ---------- + worker_name : str + The worker name (as seen by the scheduler) + + Returns + ------- + str | None + The spec name (key in worker_spec dict), or None if not found + + Examples + -------- + Regular worker: + >>> cluster.worker_spec = {"0": {"cls": Worker, "options": {}}} + >>> cluster._worker_name_to_spec_name("0") + "0" + + Grouped worker: + >>> cluster.worker_spec = { + ... "0": {"cls": MultiWorker, "options": {}, "group": ["-0", "-1", "-2"]} + ... } + >>> cluster._worker_name_to_spec_name("0-1") + "0" + + Not found: + >>> cluster._worker_name_to_spec_name("nonexistent") + None + """ + # First check if worker_name is directly a spec name (regular worker) + if worker_name in self.worker_spec: + return worker_name + + # For grouped workers, check each spec to see if this worker belongs to it + for spec_name in self.worker_spec: + worker_names = self._spec_name_to_worker_names(spec_name) + if worker_name in worker_names: + return spec_name + + return None + + def new_worker_spec(self) -> dict[str, dict[str, Any]]: + """Return name and spec for the next worker spec + + Returns + ------- + dict[str, dict] + A dictionary with a single entry mapping a spec name (string) to + a worker specification dict See Also -------- scale """ - new_worker_name = self._new_worker_name(self._i) - while new_worker_name in self.worker_spec: + spec_name = self._new_spec_name(self._i) + while spec_name in self.worker_spec: self._i += 1 - new_worker_name = self._new_worker_name(self._i) + spec_name = self._new_spec_name(self._i) - return {new_worker_name: self.new_spec} + return {spec_name: cast(dict[str, Any], self.new_spec)} @property def _supports_scaling(self): return bool(self.new_spec) - async def scale_down(self, workers): - # We may have groups, if so, map worker addresses to job names - if not all(w in self.worker_spec for w in workers): - mapping = {} - for name, spec in self.worker_spec.items(): - if "group" in spec: - for suffix in spec["group"]: - mapping[str(name) + suffix] = name - else: - mapping[name] = name - - workers = {mapping.get(w, w) for w in workers} - - for w in workers: - if w in self.worker_spec: - del self.worker_spec[w] + async def scale_down(self, workers: Iterable[str]) -> None: + """Scale down by removing worker specs. + + Parameters + ---------- + workers : Iterable[str] + Worker names (as seen by the scheduler) to scale down + """ + # Map worker names to spec names (handles both regular and grouped workers) + spec_names_to_remove = set() + for worker_name in workers: + # First check if it's directly a spec name (for backward compatibility) + if worker_name in self.worker_spec: + spec_names_to_remove.add(worker_name) + else: + # Otherwise, map worker name to spec name + spec_name = self._worker_name_to_spec_name(worker_name) + if spec_name: + spec_names_to_remove.add(spec_name) + + for spec_name in spec_names_to_remove: + if spec_name in self.worker_spec: + del self.worker_spec[spec_name] await self scale_up = scale # backwards compatibility diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index ce39d20bb1..8368fe48a9 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -435,7 +435,8 @@ async def test_MultiWorker(): ) as cluster: s = cluster.scheduler async with Client(cluster, asynchronous=True) as client: - cluster.scale(2) + # Scale to 4 workers (2 specs with 2 workers each) + cluster.scale(4) await cluster assert len(cluster.worker_spec) == 2 await client.wait_for_workers(4) @@ -448,20 +449,26 @@ async def test_MultiWorker(): workers_line = re.search("(Workers.+)", cluster._repr_html_()).group(1) assert re.match("Workers.*4", workers_line) - cluster.scale(1) + # Scale to 2 workers (1 spec with 2 workers) + cluster.scale(2) await cluster assert len(s.workers) == 2 + # Scale to 6 GB memory: 4GB per spec, conservatively scales to 1 spec = 4GB + # (rounds DOWN to avoid exceeding 6GB limit) cluster.scale(memory="6GB") await cluster - assert len(cluster.worker_spec) == 2 - assert len(s.workers) == 4 + assert len(cluster.worker_spec) == 1 + assert len(s.workers) == 2 assert cluster.plan == {ws.name for ws in s.workers.values()} + # Scale to 10 cores: 4 cores per spec, conservatively scales to 2 specs = 8 cores + # (rounds DOWN to avoid exceeding 10 cores limit) cluster.scale(cores=10) await cluster - assert len(cluster.workers) == 3 + assert len(cluster.workers) == 2 + # Adaptive with maximum=4 means maximum 4 workers = 2 specs adapt = cluster.adapt(minimum=0, maximum=4) for _ in range(adapt.wait_count): # relax down to 0 workers @@ -469,9 +476,169 @@ async def test_MultiWorker(): await cluster assert not s.workers + # Submit work - adaptive will request workers based on workload future = client.submit(lambda x: x + 1, 10) await future - assert len(cluster.workers) == 1 + # With 2-worker specs and conservative scaling with minimum viability: + # When adaptive requests 1 worker, scale creates 1 spec = 2 workers + # (special case: creates at least 1 spec when scaling from 0) + assert len(cluster.workers) == 1 # 1 spec created + + +@gen_test() +async def test_grouped_worker_death_removes_spec(): + """Test that when a single worker in a group dies, the entire spec is removed.""" + with dask.config.set({"distributed.deploy.lost-worker-timeout": "100ms"}): + async with SpecCluster( + scheduler=scheduler, + worker={ + "cls": MultiWorker, + "options": {"n": 2, "nthreads": 2, "memory_limit": "2 GB"}, + "group": ["-0", "-1"], + }, + asynchronous=True, + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + # Scale to 4 workers (2 specs with 2 workers each) + cluster.scale(4) + await cluster + assert len(cluster.worker_spec) == 2 + await client.wait_for_workers(4) + + # Get the spec names + spec_names = list(cluster.worker_spec.keys()) + assert len(spec_names) == 2 + + # Get worker names for the first spec + first_spec_name = spec_names[0] + worker_names = cluster._spec_name_to_worker_names(first_spec_name) + assert len(worker_names) == 2 + + # Kill one worker from the first group + worker_to_kill = list(worker_names)[0] + worker_addr = [ + addr + for addr, ws in cluster.scheduler.workers.items() + if ws.name == worker_to_kill + ][0] + + # Simulate abrupt worker death (like HPC pre-emption) + await cluster.scheduler.remove_worker( + address=worker_addr, close=False, stimulus_id="test" + ) + + # Wait for lost-worker-timeout + await asyncio.sleep(0.2) + + # The entire spec should be removed (not just the one worker) + assert first_spec_name not in cluster.worker_spec + # The other spec should still exist + assert spec_names[1] in cluster.worker_spec + + +@gen_test() +async def test_grouped_worker_spec_removal_multiple_rounds(): + """Test that spec removal works correctly for multiple rounds with different spec names. + + This test ensures that the spec removal mechanism in _update_worker_status() correctly: + 1. Maps worker names to spec names (e.g., "2-0" -> "2") + 2. Closes and removes the MultiWorker instance from self.workers + 3. Removes the spec from worker_spec + 4. Works for any spec name, not just "0" + + This catches bugs where worker names ("0-0") were incorrectly checked against + self.workers keys (which are spec names like "0"). + """ + with dask.config.set({"distributed.deploy.lost-worker-timeout": "100ms"}): + async with SpecCluster( + scheduler=scheduler, + worker={ + "cls": MultiWorker, + "options": {"n": 2, "nthreads": 2, "memory_limit": "2 GB"}, + "group": ["-0", "-1"], + }, + asynchronous=True, + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + # Scale to 4 workers (2 specs) + cluster.scale(4) + await cluster + await client.wait_for_workers(4) + + # Initial state: 2 specs, 2 MultiWorker instances + assert len(cluster.worker_spec) == 2 + assert len(cluster.workers) == 2 + initial_specs = set(cluster.worker_spec.keys()) + + # Round 1: Remove spec "0" + spec_to_remove = "0" + assert spec_to_remove in cluster.worker_spec + assert spec_to_remove in cluster.workers + + worker_names = cluster._spec_name_to_worker_names(spec_to_remove) + worker_to_kill = list(worker_names)[0] + worker_addr = [ + addr + for addr, ws in cluster.scheduler.workers.items() + if ws.name == worker_to_kill + ][0] + + await cluster.scheduler.remove_worker( + address=worker_addr, close=False, stimulus_id="test-round-1" + ) + await asyncio.sleep(0.2) + + # Verify spec "0" is completely removed + assert spec_to_remove not in cluster.worker_spec + assert ( + spec_to_remove not in cluster.workers + ) # MultiWorker instance removed + assert len(cluster.worker_spec) == 1 + assert len(cluster.workers) == 1 + + # Scale back up to create a new spec (will be "2" since "0" was removed) + cluster.scale(4) + await client.wait_for_workers(4) + + # Should have 2 specs again, but with different names + assert len(cluster.worker_spec) == 2 + assert len(cluster.workers) == 2 + current_specs = set(cluster.worker_spec.keys()) + + # Specs should be "1" and "2" (not "0") + assert "0" not in current_specs + assert "1" in current_specs + assert "2" in current_specs + + # Round 2: Remove spec "2" (this would fail with the old buggy code) + spec_to_remove = "2" + assert spec_to_remove in cluster.worker_spec + assert spec_to_remove in cluster.workers + + worker_names = cluster._spec_name_to_worker_names(spec_to_remove) + worker_to_kill = list(worker_names)[0] + worker_addr = [ + addr + for addr, ws in cluster.scheduler.workers.items() + if ws.name == worker_to_kill + ][0] + + await cluster.scheduler.remove_worker( + address=worker_addr, close=False, stimulus_id="test-round-2" + ) + await asyncio.sleep(0.2) + + # Verify spec "2" is completely removed + assert spec_to_remove not in cluster.worker_spec + assert ( + spec_to_remove not in cluster.workers + ) # MultiWorker instance removed + assert len(cluster.worker_spec) == 1 + assert len(cluster.workers) == 1 + + # Only spec "1" should remain + assert list(cluster.worker_spec.keys()) == ["1"] + assert list(cluster.workers.keys()) == ["1"] @gen_cluster(client=True, nthreads=[]) @@ -484,30 +651,31 @@ async def test_run_spec(c, s): @gen_test() -async def test_run_spec_cluster_worker_names(): +async def test_run_spec_cluster_custom_spec_names(): + """Test that _new_spec_name() can be overridden to customize spec names""" worker = {"cls": Worker, "options": {"nthreads": 1}} class MyCluster(SpecCluster): - def _new_worker_name(self, worker_number): - return f"prefix-{self.name}-{worker_number}-suffix" + def _new_spec_name(self, spec_number): + return f"prefix-{self.name}-{spec_number}-suffix" async with SpecCluster( asynchronous=True, scheduler=scheduler, worker=worker ) as cluster: cluster.scale(2) await cluster - worker_names = [0, 1] - assert list(cluster.worker_spec) == worker_names - assert sorted(list(cluster.workers)) == worker_names + spec_names = ["0", "1"] + assert list(cluster.worker_spec) == spec_names + assert sorted(list(cluster.workers)) == spec_names async with MyCluster( asynchronous=True, scheduler=scheduler, worker=worker, name="test-name" ) as cluster: - worker_names = ["prefix-test-name-0-suffix", "prefix-test-name-1-suffix"] + spec_names = ["prefix-test-name-0-suffix", "prefix-test-name-1-suffix"] cluster.scale(2) await cluster - assert list(cluster.worker_spec) == worker_names - assert sorted(list(cluster.workers)) == worker_names + assert list(cluster.worker_spec) == spec_names + assert sorted(list(cluster.workers)) == spec_names @gen_test() @@ -544,3 +712,320 @@ async def test_shutdown_scheduler(): assert isinstance(s, Scheduler) assert s.status == Status.closed + + +@gen_test() +async def test_new_spec_name_returns_string(): + """Test that _new_spec_name() returns strings, not integers. + + Spec names (keys in worker_spec dict) should always be strings, whether + auto-generated or user-provided. This ensures type consistency and + eliminates int/str conversion issues throughout the codebase. + """ + async with SpecCluster( + workers={}, scheduler=scheduler, asynchronous=True + ) as cluster: + # Test that _new_spec_name returns a string + name = cluster._new_spec_name(0) + assert isinstance(name, str), f"Expected str, got {type(name).__name__}" + assert name == "0" + + name = cluster._new_spec_name(42) + assert isinstance(name, str), f"Expected str, got {type(name).__name__}" + assert name == "42" + + +@gen_test() +async def test_worker_spec_keys_are_strings(): + """Test that worker_spec keys are strings after scaling. + + When workers are added via scale(), the resulting spec names (keys in + worker_spec dict) should be strings to maintain consistency with + user-provided specs. + """ + worker_template = {"cls": Worker, "options": {"nthreads": 1}} + async with SpecCluster( + workers={}, scheduler=scheduler, worker=worker_template, asynchronous=True + ) as cluster: + # Scale up to create auto-generated worker specs + cluster.scale(2) + await cluster + + # All keys in worker_spec should be strings + for key in cluster.worker_spec.keys(): + assert isinstance( + key, str + ), f"Expected str key, got {type(key).__name__}: {key}" + + +@gen_test() +async def test_spec_name_to_worker_names_regular(): + """Test _spec_name_to_worker_names() with regular (non-grouped) workers.""" + worker_template = {"cls": Worker, "options": {"nthreads": 1}} + async with SpecCluster( + workers={}, scheduler=scheduler, worker=worker_template, asynchronous=True + ) as cluster: + # Scale to create regular workers + cluster.scale(2) + await cluster + + # Regular workers: spec name == worker name (1:1) + assert cluster._spec_name_to_worker_names("0") == {"0"} + assert cluster._spec_name_to_worker_names("1") == {"1"} + + # Non-existent spec returns empty set + assert cluster._spec_name_to_worker_names("nonexistent") == set() + + +@gen_test() +async def test_spec_name_to_worker_names_grouped(): + """Test _spec_name_to_worker_names() with grouped workers.""" + async with SpecCluster( + workers={ + "0": { + "cls": Worker, + "options": {"nthreads": 1}, + "group": ["-0", "-1", "-2"], + }, + "1": { + "cls": Worker, + "options": {"nthreads": 1}, + "group": ["-a", "-b"], + }, + }, + scheduler=scheduler, + asynchronous=True, + ) as cluster: + # Grouped workers: one spec name → multiple worker names + assert cluster._spec_name_to_worker_names("0") == {"0-0", "0-1", "0-2"} + assert cluster._spec_name_to_worker_names("1") == {"1-a", "1-b"} + + +@gen_test() +async def test_worker_name_to_spec_name_regular(): + """Test _worker_name_to_spec_name() with regular (non-grouped) workers.""" + worker_template = {"cls": Worker, "options": {"nthreads": 1}} + async with SpecCluster( + workers={}, scheduler=scheduler, worker=worker_template, asynchronous=True + ) as cluster: + # Scale to create regular workers + cluster.scale(2) + await cluster + + # Regular workers: worker name == spec name + assert cluster._worker_name_to_spec_name("0") == "0" + assert cluster._worker_name_to_spec_name("1") == "1" + + # Non-existent worker returns None + assert cluster._worker_name_to_spec_name("nonexistent") is None + + +@gen_test() +async def test_worker_name_to_spec_name_grouped(): + """Test _worker_name_to_spec_name() with grouped workers.""" + async with SpecCluster( + workers={ + "0": { + "cls": Worker, + "options": {"nthreads": 1}, + "group": ["-0", "-1", "-2"], + }, + "1": { + "cls": Worker, + "options": {"nthreads": 1}, + "group": ["-a", "-b"], + }, + }, + scheduler=scheduler, + asynchronous=True, + ) as cluster: + # All workers from group "0" map back to spec "0" + assert cluster._worker_name_to_spec_name("0-0") == "0" + assert cluster._worker_name_to_spec_name("0-1") == "0" + assert cluster._worker_name_to_spec_name("0-2") == "0" + + # All workers from group "1" map back to spec "1" + assert cluster._worker_name_to_spec_name("1-a") == "1" + assert cluster._worker_name_to_spec_name("1-b") == "1" + + # Non-existent worker returns None + assert cluster._worker_name_to_spec_name("nonexistent") is None + + +@gen_test() +async def test_worker_name_to_spec_name_mixed(): + """Test _worker_name_to_spec_name() with mixed regular and grouped workers.""" + async with SpecCluster( + workers={ + "regular": {"cls": Worker, "options": {"nthreads": 1}}, + "grouped": { + "cls": Worker, + "options": {"nthreads": 1}, + "group": ["-0", "-1"], + }, + }, + scheduler=scheduler, + asynchronous=True, + ) as cluster: + # Regular worker + assert cluster._worker_name_to_spec_name("regular") == "regular" + + # Grouped workers + assert cluster._worker_name_to_spec_name("grouped-0") == "grouped" + assert cluster._worker_name_to_spec_name("grouped-1") == "grouped" + + +@gen_test() +async def test_unexpected_close_whole_worker_group(): + """Test that when all workers in a group die abruptly, the spec is removed and recreated.""" + with dask.config.set({"distributed.deploy.lost-worker-timeout": "100ms"}): + async with SpecCluster( + scheduler=scheduler, + worker={ + "cls": MultiWorker, + "options": {"n": 2, "nthreads": 2, "memory_limit": "2 GB"}, + "group": ["-0", "-1"], + }, + asynchronous=True, + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + # Scale to 4 workers (2 specs with 2 workers each) + cluster.scale(4) + await cluster + assert len(cluster.worker_spec) == 2 + await client.wait_for_workers(4) + + # Get the spec names + spec_names = list(cluster.worker_spec.keys()) + assert len(spec_names) == 2 + + # Get all worker names for the first spec + first_spec_name = spec_names[0] + worker_names = cluster._spec_name_to_worker_names(first_spec_name) + assert len(worker_names) == 2 + + # Kill all workers from the first group (simulate HPC job kill) + for worker_name in worker_names: + worker_addr = [ + addr + for addr, ws in cluster.scheduler.workers.items() + if ws.name == worker_name + ][0] + await cluster.scheduler.remove_worker( + address=worker_addr, close=False, stimulus_id="test" + ) + + # Wait for lost-worker-timeout + await asyncio.sleep(0.2) + + # The entire spec should be removed + assert first_spec_name not in cluster.worker_spec + # The other spec should still exist + assert spec_names[1] in cluster.worker_spec + # Should have 1 spec remaining + assert len(cluster.worker_spec) == 1 + + # With adaptive enabled (minimum=4 workers), the cluster should recreate the missing spec + cluster.adapt(minimum=4, maximum=4) + await client.wait_for_workers(4) + + # Should have 2 specs again (but with a new spec name for the recreated one) + assert len(cluster.worker_spec) == 2 + # Old spec name should not exist + assert first_spec_name not in cluster.worker_spec + # Should have 4 workers total + assert len(cluster.scheduler.workers) == 4 + + +@gen_test() +async def test_scale_down_with_grouped_workers(): + """Test that scale_down correctly maps worker names to spec names.""" + async with SpecCluster( + scheduler=scheduler, + worker={ + "cls": MultiWorker, + "options": {"n": 2, "nthreads": 2, "memory_limit": "2 GB"}, + "group": ["-0", "-1"], + }, + asynchronous=True, + ) as cluster: + # Scale to 4 workers (2 specs with 2 workers each) + cluster.scale(4) + await cluster + assert len(cluster.worker_spec) == 2 + + # Get spec names + spec_names = list(cluster.worker_spec.keys()) + first_spec_name = spec_names[0] + + # Get worker names for the first spec + worker_names = cluster._spec_name_to_worker_names(first_spec_name) + worker_names_list = list(worker_names) + + # Call scale_down with actual worker names (what scheduler knows) + await cluster.scale_down(worker_names_list) + + # The first spec should be removed + assert first_spec_name not in cluster.worker_spec + # The second spec should still exist + assert spec_names[1] in cluster.worker_spec + # Should have 1 spec and 2 workers left + assert len(cluster.worker_spec) == 1 + + +@gen_test() +async def test_mixed_regular_and_grouped_workers(): + """Test cluster with both regular and grouped worker specs.""" + with dask.config.set({"distributed.deploy.lost-worker-timeout": "100ms"}): + async with SpecCluster( + workers={ + "regular-1": {"cls": Worker, "options": {"nthreads": 2}}, + "regular-2": {"cls": Worker, "options": {"nthreads": 2}}, + "grouped": { + "cls": MultiWorker, + "options": {"n": 2, "nthreads": 2}, + "group": ["-0", "-1"], + }, + }, + scheduler=scheduler, + asynchronous=True, + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + # Should have 3 specs, 4 workers total (2 regular + 2 grouped) + await client.wait_for_workers(4) + assert len(cluster.worker_spec) == 3 + + # Test regular worker failure - spec should remain + regular_addr = [ + addr + for addr, ws in cluster.scheduler.workers.items() + if ws.name == "regular-1" + ][0] + await cluster.scheduler.remove_worker( + address=regular_addr, close=False, stimulus_id="test" + ) + await asyncio.sleep(0.2) + + # Regular worker spec should still exist (cluster can recreate it) + assert "regular-1" in cluster.worker_spec + assert len(cluster.worker_spec) == 3 + + # Test grouped worker failure - entire spec should be removed + grouped_worker_names = cluster._spec_name_to_worker_names("grouped") + one_grouped_worker = list(grouped_worker_names)[0] + grouped_addr = [ + addr + for addr, ws in cluster.scheduler.workers.items() + if ws.name == one_grouped_worker + ][0] + await cluster.scheduler.remove_worker( + address=grouped_addr, close=False, stimulus_id="test" + ) + await asyncio.sleep(0.2) + + # Grouped spec should be removed entirely + assert "grouped" not in cluster.worker_spec + # Regular specs should still exist + assert "regular-1" in cluster.worker_spec + assert "regular-2" in cluster.worker_spec + assert len(cluster.worker_spec) == 2