diff --git a/dask_kubernetes/operator/controller/controller.py b/dask_kubernetes/operator/controller/controller.py
index 05ce0900..fa58bcc6 100644
--- a/dask_kubernetes/operator/controller/controller.py
+++ b/dask_kubernetes/operator/controller/controller.py
@@ -40,9 +40,7 @@
 KUBERNETES_DATETIME_FORMAT: Final[str] = "%Y-%m-%dT%H:%M:%SZ"
 
-DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: Final[
-    str
-] = "kubernetes.dask.org/cooldown-until"
+DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: Final[str] = "kubernetes.dask.org/cooldown-until"
 
 # Load operator plugins from other packages
 PLUGINS: list[Any] = []
@@ -59,10 +57,7 @@ def _get_annotations(meta: kopf.Meta) -> dict[str, str]:
     return {
         annotation_key: annotation_value
         for annotation_key, annotation_value in meta.annotations.items()
-        if not any(
-            annotation_key.startswith(namespace)
-            for namespace in _ANNOTATION_NAMESPACES_TO_IGNORE
-        )
+        if not any(annotation_key.startswith(namespace) for namespace in _ANNOTATION_NAMESPACES_TO_IGNORE)
     }
 
@@ -70,9 +65,7 @@ def _get_labels(meta: kopf.Meta) -> dict[str, str]:
     return {
         label_key: label_value
         for label_key, label_value in meta.labels.items()
-        if not any(
-            label_key.startswith(namespace) for namespace in _LABEL_NAMESPACES_TO_IGNORE
-        )
+        if not any(label_key.startswith(namespace) for namespace in _LABEL_NAMESPACES_TO_IGNORE)
     }
 
@@ -351,21 +344,15 @@ async def daskcluster_create_components(
             annotations.update(**scheduler_spec["metadata"]["annotations"])
         if "labels" in scheduler_spec["metadata"]:
             labels.update(**scheduler_spec["metadata"]["labels"])
-    data = build_scheduler_deployment_spec(
-        name, scheduler_spec.get("spec"), annotations, labels
-    )
+    data = build_scheduler_deployment_spec(name, scheduler_spec.get("spec"), annotations, labels)
     kopf.adopt(data)
     scheduler_deployment = await Deployment(data, namespace=namespace)
     if not await scheduler_deployment.exists():
         await scheduler_deployment.create()
-    logger.info(
-        f"Scheduler deployment {scheduler_deployment.name} created in {namespace}."
-    )
+    logger.info(f"Scheduler deployment {scheduler_deployment.name} created in {namespace}.")
 
     # Create scheduler service
-    data = build_scheduler_service_spec(
-        name, scheduler_spec.get("service"), annotations, labels
-    )
+    data = build_scheduler_service_spec(name, scheduler_spec.get("service"), annotations, labels)
     kopf.adopt(data)
     scheduler_service = await Service(data, namespace=namespace)
     if not await scheduler_service.exists():
@@ -389,6 +376,92 @@ async def daskcluster_create_components(
     patch.status["phase"] = "Pending"
 
 
+@kopf.on.update("daskcluster.kubernetes.dask.org")
+async def daskcluster_update(
+    spec: kopf.Spec,
+    status: kopf.Status,
+    meta: kopf.Meta,
+    name: str | None,
+    namespace: str | None,
+    diff: kopf.Diff,
+    patch: kopf.Patch,
+    logger: kopf.Logger,
+    **__: Any,
+) -> None:
+    """When the DaskCluster resource is updated, update all of its components."""
+    assert name
+    assert namespace
+    logger.info(f"Handling update for DaskCluster '{name}'")
+
+    # kopf.Diff is a tuple of DiffItem(operation, field, old, new), where field is a
+    # tuple of path components such as ("spec", "scheduler", ...)
+    scheduler_changed = any(d.field[:2] == ("spec", "scheduler") for d in diff)
+    worker_changed = any(d.field[:2] == ("spec", "worker") for d in diff)
+
+    base_annotations = _get_annotations(meta)
+    base_labels = _get_labels(meta)
+
+    if scheduler_changed:
+        logger.info("Scheduler spec changed, reconciling scheduler components.")
+        scheduler_spec_part = spec.get("scheduler", {})
+
+        scheduler_annotations = base_annotations.copy()
+        scheduler_labels = base_labels.copy()
+        if "metadata" in scheduler_spec_part:
+            scheduler_annotations.update(scheduler_spec_part.get("metadata", {}).get("annotations", {}))
+            scheduler_labels.update(scheduler_spec_part.get("metadata", {}).get("labels", {}))
+
+        desired_dep_spec = build_scheduler_deployment_spec(
+            name, scheduler_spec_part.get("spec"), scheduler_annotations, scheduler_labels
+        )
+        # Look the deployment up by its templated name
+        scheduler_deployment = await Deployment(
+            SCHEDULER_NAME_TEMPLATE.format(cluster_name=name), namespace=namespace
+        )
+        if await scheduler_deployment.exists():
+            await scheduler_deployment.patch(desired_dep_spec)
+            logger.info(f"Scheduler deployment {scheduler_deployment.name} patched.")
+        else:
+            # Deployment was deleted out-of-band; recreate it from the desired spec
+            logger.warning(f"Scheduler deployment {scheduler_deployment.name} not found. Recreating.")
+            kopf.adopt(desired_dep_spec)
+            scheduler_deployment = await Deployment(desired_dep_spec, namespace=namespace)
+            await scheduler_deployment.create()
+
+        desired_svc_spec = build_scheduler_service_spec(
+            name, scheduler_spec_part.get("service"), scheduler_annotations, scheduler_labels
+        )
+        scheduler_service = await Service(
+            SCHEDULER_NAME_TEMPLATE.format(cluster_name=name), namespace=namespace
+        )
+        if await scheduler_service.exists():
+            await scheduler_service.patch(desired_svc_spec)
+            logger.info(f"Scheduler service {scheduler_service.name} patched.")
+        else:
+            # Service was deleted out-of-band; recreate it from the desired spec
+            logger.warning(f"Scheduler service {scheduler_service.name} not found. Recreating.")
Recreating.") + kopf.adopt(desired_svc_spec, owner=meta) + await scheduler_service.create(desired_svc_spec) + + if worker_changed: + logger.info("Worker spec changed, reconciling default worker group.") + worker_spec_part = spec.get("worker", {}) + + worker_annotations = base_annotations.copy() + worker_labels = base_labels.copy() + if "metadata" in worker_spec_part: + worker_annotations.update(worker_spec_part.get("metadata", {}).get("annotations", {})) + worker_labels.update(worker_spec_part.get("metadata", {}).get("labels", {})) + + desired_wg_spec = build_default_worker_group_spec( + name, worker_spec_part, worker_annotations, worker_labels + ) + worker_group = await DaskWorkerGroup.get(f"{name}-default", namespace=namespace) + + if await worker_group.exists(): + await worker_group.patch(desired_wg_spec) + logger.info(f"Worker group {worker_group.name} patched.") + else: + logger.warning(f"Worker group {worker_group.name} not found. Recreating.") + kopf.adopt(desired_wg_spec, owner=meta) + await worker_group.create(desired_wg_spec) + + patch.status["observedGeneration"] = meta.generation + logger.info(f"Update handler finished for DaskCluster '{name}'.") @kopf.on.field("service", field="status", labels={"dask.org/component": "scheduler"}) async def handle_scheduler_service_status( @@ -400,23 +473,17 @@ async def handle_scheduler_service_status( ) -> None: assert namespace # If the Service is a LoadBalancer with no ingress endpoints mark the cluster as Pending - if spec["type"] == "LoadBalancer" and not len( - status.get("loadBalancer", {}).get("ingress", []) - ): + if spec["type"] == "LoadBalancer" and not len(status.get("loadBalancer", {}).get("ingress", [])): phase = "Pending" # Otherwise mark it as Running else: phase = "Running" - cluster = await DaskCluster.get( - labels["dask.org/cluster-name"], namespace=namespace - ) + cluster = await DaskCluster.get(labels["dask.org/cluster-name"], namespace=namespace) await cluster.patch({"status": {"phase": phase}}) @kopf.on.create("daskworkergroup.kubernetes.dask.org") -async def daskworkergroup_create( - body: kopf.Body, namespace: str | None, logger: kopf.Logger, **kwargs: Any -) -> None: +async def daskworkergroup_create(body: kopf.Body, namespace: str | None, logger: kopf.Logger, **kwargs: Any) -> None: assert namespace wg = await DaskWorkerGroup(body, namespace=namespace) cluster = await wg.cluster() @@ -463,9 +530,7 @@ async def retire_workers( ) # Otherwise try gracefully retiring via the RPC - logger.debug( - f"Scaling {worker_group_name} failed via the HTTP API, falling back to the Dask RPC" - ) + logger.debug(f"Scaling {worker_group_name} failed via the HTTP API, falling back to the Dask RPC") # Dask version mismatches between the operator and scheduler may cause this to fail in any number of unexpected ways with suppress(Exception): comm_address = await get_scheduler_address( @@ -499,9 +564,7 @@ def retire_workers_lifo(workers, n_workers: int) -> list[str]: return [w.name for w in workers[-n_workers:]] -async def check_scheduler_idle( - scheduler_service_name: str, namespace: str | None, logger: kopf.Logger -) -> float: +async def check_scheduler_idle(scheduler_service_name: str, namespace: str | None, logger: kopf.Logger) -> float: assert namespace # Try getting idle time via HTTP API dashboard_address = await get_scheduler_address( @@ -525,9 +588,7 @@ async def check_scheduler_idle( ) # Otherwise try gracefully checking via the RPC - logger.debug( - f"Checking {scheduler_service_name} idleness failed via the HTTP API, falling 
back to the Dask RPC" - ) + logger.debug(f"Checking {scheduler_service_name} idleness failed via the HTTP API, falling back to the Dask RPC") # Dask version mismatches between the operator and scheduler may cause this to fail in any number of unexpected ways with suppress(Exception): comm_address = await get_scheduler_address( @@ -573,9 +634,7 @@ def idle_since_func(dask_scheduler: Scheduler) -> float: return float(idle_since) -async def get_desired_workers( - scheduler_service_name: str, namespace: str | None -) -> Any: +async def get_desired_workers(scheduler_service_name: str, namespace: str | None) -> Any: assert namespace # Try gracefully retiring via the HTTP API dashboard_address = await get_scheduler_address( @@ -602,9 +661,7 @@ async def get_desired_workers( async with rpc(comm_address) as scheduler_comm: return await scheduler_comm.adaptive_target() except Exception as e: - raise SchedulerCommError( - "Unable to get number of desired workers from scheduler" - ) from e + raise SchedulerCommError("Unable to get number of desired workers from scheduler") from e worker_group_scale_locks: dict[str, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) @@ -669,13 +726,9 @@ async def daskworkergroup_replica_update( if "labels" in worker_spec["metadata"]: labels.update(**worker_spec["metadata"]["labels"]) - batch_size = int( - dask.config.get("kubernetes.controller.worker-allocation.batch-size") or 0 - ) + batch_size = int(dask.config.get("kubernetes.controller.worker-allocation.batch-size") or 0) batch_size = min(workers_needed, batch_size) if batch_size else workers_needed - batch_delay = int( - dask.config.get("kubernetes.controller.worker-allocation.delay") or 0 - ) + batch_delay = int(dask.config.get("kubernetes.controller.worker-allocation.delay") or 0) if workers_needed > 0: for _ in range(batch_size): data = build_worker_deployment_spec( @@ -701,9 +754,7 @@ async def daskworkergroup_replica_update( if workers_needed < 0: worker_ids = await retire_workers( n_workers=-workers_needed, - scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format( - cluster_name=cluster_name - ), + scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(cluster_name=cluster_name), worker_group_name=name, namespace=namespace, logger=logger, @@ -712,15 +763,11 @@ async def daskworkergroup_replica_update( for wid in worker_ids: worker_deployment = await Deployment(wid, namespace=namespace) await worker_deployment.delete() - logger.info( - f"Scaled worker group {name} down to {desired_workers} workers." - ) + logger.info(f"Scaled worker group {name} down to {desired_workers} workers.") @kopf.on.delete("daskworkergroup.kubernetes.dask.org", optional=True) -async def daskworkergroup_remove( - name: str | None, namespace: str | None, **__: Any -) -> None: +async def daskworkergroup_remove(name: str | None, namespace: str | None, **__: Any) -> None: assert name assert namespace lock_key = f"{name}/{namespace}" @@ -742,9 +789,7 @@ async def daskjob_create( patch.status["jobStatus"] = "JobCreated" -@kopf.on.field( - "daskjob.kubernetes.dask.org", field="status.jobStatus", new="JobCreated" -) +@kopf.on.field("daskjob.kubernetes.dask.org", field="status.jobStatus", new="JobCreated") async def daskjob_create_components( spec: kopf.Spec, name: str | None, @@ -776,9 +821,7 @@ async def daskjob_create_components( kopf.adopt(cluster_spec) cluster = await DaskCluster(cluster_spec, namespace=namespace) await cluster.create() - logger.info( - f"Cluster {cluster_spec['metadata']['name']} for job {name} created in {namespace}." 
-    )
+    logger.info(f"Cluster {cluster_spec['metadata']['name']} for job {name} created in {namespace}.")
 
     labels = _get_labels(meta)
     annotations = _get_annotations(meta)
@@ -881,9 +924,7 @@ async def handle_runner_status_change_failed(
 
 
 @kopf.on.create("daskautoscaler.kubernetes.dask.org")
-async def daskautoscaler_create(
-    body: kopf.Body, logger: kopf.Logger, **__: Any
-) -> None:
+async def daskautoscaler_create(body: kopf.Body, logger: kopf.Logger, **__: Any) -> None:
     """When an autoscaler is created make it a child of the associated cluster for cascade deletion."""
     autoscaler = await DaskAutoscaler(body)
     cluster = await autoscaler.cluster()
@@ -916,16 +957,10 @@ async def daskautoscaler_adapt(
         return
 
     autoscaler = await DaskAutoscaler.get(name, namespace=namespace)
-    worker_group = await DaskWorkerGroup.get(
-        f"{spec['cluster']}-default", namespace=namespace
-    )
+    worker_group = await DaskWorkerGroup.get(f"{spec['cluster']}-default", namespace=namespace)
     current_replicas = worker_group.replicas
-    cooldown_until = float(
-        autoscaler.annotations.get(
-            DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION, time.time()
-        )
-    )
+    cooldown_until = float(autoscaler.annotations.get(DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION, time.time()))
 
     # Cooldown autoscaling to prevent thrashing
     if time.time() < cooldown_until:
@@ -957,9 +992,7 @@ async def daskautoscaler_adapt(
 
             cooldown_until = time.time() + 15
 
-            await autoscaler.annotate(
-                {DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: str(cooldown_until)}
-            )
+            await autoscaler.annotate({DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: str(cooldown_until)})
 
         logger.info(
             "Autoscaler updated %s worker count from %d to %d",
@@ -968,9 +1001,7 @@ async def daskautoscaler_adapt(
             desired_workers,
         )
     else:
-        logger.debug(
-            "Not autoscaling %s with %d workers", spec["cluster"], current_replicas
-        )
+        logger.debug("Not autoscaling %s with %d workers", spec["cluster"], current_replicas)
 
 
 @kopf.timer("daskcluster.kubernetes.dask.org", interval=5.0)
@@ -990,9 +1021,7 @@ async def daskcluster_autoshutdown(
             logger=logger,
         )
     except Exception:  # TODO: Not use broad "Exception" catch here
-        logger.warning(
-            "Unable to connect to scheduler, skipping autoshutdown check."
-        )
+        logger.warning("Unable to connect to scheduler, skipping autoshutdown check.")
         return
     if idle_since and time.time() > idle_since + idle_timeout:
         cluster = await DaskCluster.get(name, namespace=namespace)