197 changes: 113 additions & 84 deletions dask_kubernetes/operator/controller/controller.py
@@ -40,9 +40,7 @@

KUBERNETES_DATETIME_FORMAT: Final[str] = "%Y-%m-%dT%H:%M:%SZ"

DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: Final[
str
] = "kubernetes.dask.org/cooldown-until"
DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: Final[str] = "kubernetes.dask.org/cooldown-until"

# Load operator plugins from other packages
PLUGINS: list[Any] = []
@@ -59,20 +57,15 @@ def _get_annotations(meta: kopf.Meta) -> dict[str, str]:
return {
annotation_key: annotation_value
for annotation_key, annotation_value in meta.annotations.items()
if not any(
annotation_key.startswith(namespace)
for namespace in _ANNOTATION_NAMESPACES_TO_IGNORE
)
if not any(annotation_key.startswith(namespace) for namespace in _ANNOTATION_NAMESPACES_TO_IGNORE)
}


def _get_labels(meta: kopf.Meta) -> dict[str, str]:
return {
label_key: label_value
for label_key, label_value in meta.labels.items()
if not any(
label_key.startswith(namespace) for namespace in _LABEL_NAMESPACES_TO_IGNORE
)
if not any(label_key.startswith(namespace) for namespace in _LABEL_NAMESPACES_TO_IGNORE)
}


@@ -351,21 +344,15 @@ async def daskcluster_create_components(
annotations.update(**scheduler_spec["metadata"]["annotations"])
if "labels" in scheduler_spec["metadata"]:
labels.update(**scheduler_spec["metadata"]["labels"])
data = build_scheduler_deployment_spec(
name, scheduler_spec.get("spec"), annotations, labels
)
data = build_scheduler_deployment_spec(name, scheduler_spec.get("spec"), annotations, labels)
kopf.adopt(data)
scheduler_deployment = await Deployment(data, namespace=namespace)
if not await scheduler_deployment.exists():
await scheduler_deployment.create()
logger.info(
f"Scheduler deployment {scheduler_deployment.name} created in {namespace}."
)
logger.info(f"Scheduler deployment {scheduler_deployment.name} created in {namespace}.")

# Create scheduler service
data = build_scheduler_service_spec(
name, scheduler_spec.get("service"), annotations, labels
)
data = build_scheduler_service_spec(name, scheduler_spec.get("service"), annotations, labels)
kopf.adopt(data)
scheduler_service = await Service(data, namespace=namespace)
if not await scheduler_service.exists():
@@ -389,6 +376,92 @@ async def daskcluster_create_components(

patch.status["phase"] = "Pending"

@kopf.on.update("daskcluster.kubernetes.dask.org")
async def daskcluster_update(
spec: kopf.Spec,
status: kopf.Status,
meta: kopf.Meta,
name: str | None,
namespace: str | None,
diff: kopf.Diff,
patch: kopf.Patch,
logger: kopf.Logger,
**__: Any
):
"""When the DaskCluster resource is updated update all the components."""
assert name
assert namespace
Comment on lines +392 to +393
If either of these assertions fails, kopf will retry the handler automatically forever. We probably need to catch the assertion error and just return if this isn't what we want to happen.
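A minimal sketch of that guard, reusing the handler's `name`, `namespace`, and `logger` arguments; the early return and the log message are illustrative, not part of this PR:

if not name or not namespace:
    # Bail out instead of raising, so kopf does not retry a permanently broken event forever.
    logger.warning("Update event for DaskCluster is missing name/namespace; skipping reconcile.")
    return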

logger.info(f"Handling update for DaskCluster '{name}'")

# kopf.Diff items are (operation, field, old, new) tuples whose field is a path tuple, not a JSON-pointer string.
scheduler_changed = any(d.field[:2] == ("spec", "scheduler") for d in diff)
worker_changed = any(d.field[:2] == ("spec", "worker") for d in diff)

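# Cluster-level annotations and labels form the base; per-component metadata below overlays them.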
base_annotations = _get_annotations(meta)
base_labels = _get_labels(meta)

if scheduler_changed:
logger.info("Scheduler spec changed, reconciling scheduler components.")
scheduler_spec_part = spec.get("scheduler", {})

scheduler_annotations = base_annotations.copy()
scheduler_labels = base_labels.copy()
if "metadata" in scheduler_spec_part:
scheduler_annotations.update(scheduler_spec_part.get("metadata", {}).get("annotations", {}))
scheduler_labels.update(scheduler_spec_part.get("metadata", {}).get("labels", {}))

desired_dep_spec = build_scheduler_deployment_spec(
name, scheduler_spec_part.get("spec"), scheduler_annotations, scheduler_labels
)
scheduler_deployment = await Deployment(
SCHEDULER_NAME_TEMPLATE.format(cluster_name=name), namespace=namespace  # Look up by generated scheduler name
)
if await scheduler_deployment.exists():
await scheduler_deployment.patch(desired_dep_spec)
logger.info(f"Scheduler deployment {scheduler_deployment.name} patched.")
else:
logger.warning(f"Scheduler deployment {scheduler_deployment.name} not found. Recreating.")
kopf.adopt(desired_dep_spec)
scheduler_deployment = await Deployment(desired_dep_spec, namespace=namespace)
await scheduler_deployment.create()

desired_svc_spec = build_scheduler_service_spec(
name, scheduler_spec_part.get("service"), scheduler_annotations, scheduler_labels
)
scheduler_service = await Service(
SCHEDULER_NAME_TEMPLATE.format(cluster_name=name), namespace=namespace  # Look up by generated scheduler name
)
if await scheduler_service.exists():
await scheduler_service.patch(desired_svc_spec)
logger.info(f"Scheduler service {scheduler_service.name} patched.")
else:
logger.warning(f"Scheduler service {scheduler_service.name} not found. Recreating.")
kopf.adopt(desired_svc_spec)
scheduler_service = await Service(desired_svc_spec, namespace=namespace)
await scheduler_service.create()

if worker_changed:
logger.info("Worker spec changed, reconciling default worker group.")
worker_spec_part = spec.get("worker", {})

worker_annotations = base_annotations.copy()
worker_labels = base_labels.copy()
if "metadata" in worker_spec_part:
worker_annotations.update(worker_spec_part.get("metadata", {}).get("annotations", {}))
worker_labels.update(worker_spec_part.get("metadata", {}).get("labels", {}))

desired_wg_spec = build_default_worker_group_spec(
name, worker_spec_part, worker_annotations, worker_labels
)
worker_group = await DaskWorkerGroup(f"{name}-default", namespace=namespace)

if await worker_group.exists():
await worker_group.patch(desired_wg_spec)
logger.info(f"Worker group {worker_group.name} patched.")
else:
logger.warning(f"Worker group {worker_group.name} not found. Recreating.")
kopf.adopt(desired_wg_spec)
worker_group = await DaskWorkerGroup(desired_wg_spec, namespace=namespace)
await worker_group.create()

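# Record which spec generation this handler has reconciled, so staleness can be detected from status.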
patch.status["observedGeneration"] = meta.generation
logger.info(f"Update handler finished for DaskCluster '{name}'.")

@kopf.on.field("service", field="status", labels={"dask.org/component": "scheduler"})
async def handle_scheduler_service_status(
@@ -400,23 +473,17 @@ async def handle_scheduler_service_status(
) -> None:
assert namespace
# If the Service is a LoadBalancer with no ingress endpoints mark the cluster as Pending
if spec["type"] == "LoadBalancer" and not len(
status.get("loadBalancer", {}).get("ingress", [])
):
if spec["type"] == "LoadBalancer" and not len(status.get("loadBalancer", {}).get("ingress", [])):
phase = "Pending"
# Otherwise mark it as Running
else:
phase = "Running"
cluster = await DaskCluster.get(
labels["dask.org/cluster-name"], namespace=namespace
)
cluster = await DaskCluster.get(labels["dask.org/cluster-name"], namespace=namespace)
await cluster.patch({"status": {"phase": phase}})


@kopf.on.create("daskworkergroup.kubernetes.dask.org")
async def daskworkergroup_create(
body: kopf.Body, namespace: str | None, logger: kopf.Logger, **kwargs: Any
) -> None:
async def daskworkergroup_create(body: kopf.Body, namespace: str | None, logger: kopf.Logger, **kwargs: Any) -> None:
assert namespace
wg = await DaskWorkerGroup(body, namespace=namespace)
cluster = await wg.cluster()
@@ -463,9 +530,7 @@ async def retire_workers(
)

# Otherwise try gracefully retiring via the RPC
logger.debug(
f"Scaling {worker_group_name} failed via the HTTP API, falling back to the Dask RPC"
)
logger.debug(f"Scaling {worker_group_name} failed via the HTTP API, falling back to the Dask RPC")
# Dask version mismatches between the operator and scheduler may cause this to fail in any number of unexpected ways
with suppress(Exception):
comm_address = await get_scheduler_address(
@@ -499,9 +564,7 @@ def retire_workers_lifo(workers, n_workers: int) -> list[str]:
return [w.name for w in workers[-n_workers:]]


async def check_scheduler_idle(
scheduler_service_name: str, namespace: str | None, logger: kopf.Logger
) -> float:
async def check_scheduler_idle(scheduler_service_name: str, namespace: str | None, logger: kopf.Logger) -> float:
assert namespace
# Try getting idle time via HTTP API
dashboard_address = await get_scheduler_address(
@@ -525,9 +588,7 @@
)

# Otherwise try gracefully checking via the RPC
logger.debug(
f"Checking {scheduler_service_name} idleness failed via the HTTP API, falling back to the Dask RPC"
)
logger.debug(f"Checking {scheduler_service_name} idleness failed via the HTTP API, falling back to the Dask RPC")
# Dask version mismatches between the operator and scheduler may cause this to fail in any number of unexpected ways
with suppress(Exception):
comm_address = await get_scheduler_address(
@@ -573,9 +634,7 @@ def idle_since_func(dask_scheduler: Scheduler) -> float:
return float(idle_since)


async def get_desired_workers(
scheduler_service_name: str, namespace: str | None
) -> Any:
async def get_desired_workers(scheduler_service_name: str, namespace: str | None) -> Any:
assert namespace
# Try gracefully retiring via the HTTP API
dashboard_address = await get_scheduler_address(
@@ -602,9 +661,7 @@ async def get_desired_workers(
async with rpc(comm_address) as scheduler_comm:
return await scheduler_comm.adaptive_target()
except Exception as e:
raise SchedulerCommError(
"Unable to get number of desired workers from scheduler"
) from e
raise SchedulerCommError("Unable to get number of desired workers from scheduler") from e


worker_group_scale_locks: dict[str, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
@@ -669,13 +726,9 @@ async def daskworkergroup_replica_update(
if "labels" in worker_spec["metadata"]:
labels.update(**worker_spec["metadata"]["labels"])

batch_size = int(
dask.config.get("kubernetes.controller.worker-allocation.batch-size") or 0
)
batch_size = int(dask.config.get("kubernetes.controller.worker-allocation.batch-size") or 0)
batch_size = min(workers_needed, batch_size) if batch_size else workers_needed
batch_delay = int(
dask.config.get("kubernetes.controller.worker-allocation.delay") or 0
)
batch_delay = int(dask.config.get("kubernetes.controller.worker-allocation.delay") or 0)
if workers_needed > 0:
for _ in range(batch_size):
data = build_worker_deployment_spec(
@@ -701,9 +754,7 @@ async def daskworkergroup_replica_update(
if workers_needed < 0:
worker_ids = await retire_workers(
n_workers=-workers_needed,
scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(
cluster_name=cluster_name
),
scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(cluster_name=cluster_name),
worker_group_name=name,
namespace=namespace,
logger=logger,
@@ -712,15 +763,11 @@ async def daskworkergroup_replica_update(
for wid in worker_ids:
worker_deployment = await Deployment(wid, namespace=namespace)
await worker_deployment.delete()
logger.info(
f"Scaled worker group {name} down to {desired_workers} workers."
)
logger.info(f"Scaled worker group {name} down to {desired_workers} workers.")


@kopf.on.delete("daskworkergroup.kubernetes.dask.org", optional=True)
async def daskworkergroup_remove(
name: str | None, namespace: str | None, **__: Any
) -> None:
async def daskworkergroup_remove(name: str | None, namespace: str | None, **__: Any) -> None:
assert name
assert namespace
lock_key = f"{name}/{namespace}"
@@ -742,9 +789,7 @@ async def daskjob_create(
patch.status["jobStatus"] = "JobCreated"


@kopf.on.field(
"daskjob.kubernetes.dask.org", field="status.jobStatus", new="JobCreated"
)
@kopf.on.field("daskjob.kubernetes.dask.org", field="status.jobStatus", new="JobCreated")
async def daskjob_create_components(
spec: kopf.Spec,
name: str | None,
@@ -776,9 +821,7 @@ async def daskjob_create_components(
kopf.adopt(cluster_spec)
cluster = await DaskCluster(cluster_spec, namespace=namespace)
await cluster.create()
logger.info(
f"Cluster {cluster_spec['metadata']['name']} for job {name} created in {namespace}."
)
logger.info(f"Cluster {cluster_spec['metadata']['name']} for job {name} created in {namespace}.")

labels = _get_labels(meta)
annotations = _get_annotations(meta)
@@ -881,9 +924,7 @@ async def handle_runner_status_change_failed(


@kopf.on.create("daskautoscaler.kubernetes.dask.org")
async def daskautoscaler_create(
body: kopf.Body, logger: kopf.Logger, **__: Any
) -> None:
async def daskautoscaler_create(body: kopf.Body, logger: kopf.Logger, **__: Any) -> None:
"""When an autoscaler is created make it a child of the associated cluster for cascade deletion."""
autoscaler = await DaskAutoscaler(body)
cluster = await autoscaler.cluster()
@@ -916,16 +957,10 @@ async def daskautoscaler_adapt(
return

autoscaler = await DaskAutoscaler.get(name, namespace=namespace)
worker_group = await DaskWorkerGroup.get(
f"{spec['cluster']}-default", namespace=namespace
)
worker_group = await DaskWorkerGroup.get(f"{spec['cluster']}-default", namespace=namespace)

current_replicas = worker_group.replicas
cooldown_until = float(
autoscaler.annotations.get(
DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION, time.time()
)
)
cooldown_until = float(autoscaler.annotations.get(DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION, time.time()))

# Cooldown autoscaling to prevent thrashing
if time.time() < cooldown_until:
@@ -957,9 +992,7 @@ async def daskautoscaler_adapt(

cooldown_until = time.time() + 15

await autoscaler.annotate(
{DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: str(cooldown_until)}
)
await autoscaler.annotate({DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: str(cooldown_until)})

logger.info(
"Autoscaler updated %s worker count from %d to %d",
Expand All @@ -968,9 +1001,7 @@ async def daskautoscaler_adapt(
desired_workers,
)
else:
logger.debug(
"Not autoscaling %s with %d workers", spec["cluster"], current_replicas
)
logger.debug("Not autoscaling %s with %d workers", spec["cluster"], current_replicas)


@kopf.timer("daskcluster.kubernetes.dask.org", interval=5.0)
@@ -990,9 +1021,7 @@ async def daskcluster_autoshutdown(
logger=logger,
)
except Exception: # TODO: Not use broad "Exception" catch here
logger.warning(
"Unable to connect to scheduler, skipping autoshutdown check."
)
logger.warning("Unable to connect to scheduler, skipping autoshutdown check.")
return
if idle_since and time.time() > idle_since + idle_timeout:
cluster = await DaskCluster.get(name, namespace=namespace)