Skip to content
53 changes: 43 additions & 10 deletions dask_kubernetes/operator/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
import kubernetes_asyncio as kubernetes
from importlib_metadata import entry_points
from kubernetes_asyncio.client import ApiException
from kr8s.asyncio.objects import APIObject

from dask_kubernetes.common.auth import ClusterAuth
from dask_kubernetes.common.networking import get_scheduler_address
from dask_kubernetes.aiopykube import HTTPClient, KubeConfig
from dask_kubernetes.aiopykube.dask import DaskCluster
from distributed.core import rpc, clean_exception
from distributed.protocol.pickle import dumps

Expand All @@ -40,6 +39,45 @@ class SchedulerCommError(Exception):
"""Raised when unable to communicate with a scheduler."""


class DaskCluster(APIObject):
    """kr8s ``APIObject`` wrapper for the ``DaskCluster`` custom resource."""

    # Resource identity within the Kubernetes API.
    kind = "DaskCluster"
    version = "kubernetes.dask.org/v1"
    namespaced = True

    # Names under which the resource is addressed.
    singular = "daskcluster"
    plural = "daskclusters"
    endpoint = "daskclusters"

    # Scaling configuration.
    # NOTE(review): presumably the dotted path under ``spec`` that holds the
    # replica count used by kr8s when scaling -- confirm against kr8s docs.
    scalable = True
    scalable_spec = "worker.replicas"


class DaskWorkerGroup(APIObject):
    """kr8s ``APIObject`` wrapper for the ``DaskWorkerGroup`` custom resource."""

    # Resource identity within the Kubernetes API.
    version = "kubernetes.dask.org/v1"
    endpoint = "daskworkergroups"
    kind = "DaskWorkerGroup"
    plural = "daskworkergroups"
    singular = "daskworkergroup"
    namespaced = True

    # Scaling configuration.  The class is marked scalable, so the replica
    # field location is spelled out explicitly rather than left to the kr8s
    # default, matching the ``worker.replicas`` path declared on DaskCluster.
    # NOTE(review): assumes the DaskWorkerGroup CRD stores its replica count
    # at ``spec.worker.replicas`` -- confirm against the CRD schema.
    scalable = True
    scalable_spec = "worker.replicas"


class DaskAutoscaler(APIObject):
    """kr8s ``APIObject`` wrapper for the ``DaskAutoscaler`` custom resource."""

    # Resource identity within the Kubernetes API.
    kind = "DaskAutoscaler"
    version = "kubernetes.dask.org/v1"
    namespaced = True

    # Names under which the resource is addressed.
    singular = "daskautoscaler"
    plural = "daskautoscalers"
    endpoint = "daskautoscalers"


class DaskJob(APIObject):
    """kr8s ``APIObject`` wrapper for the ``DaskJob`` custom resource."""

    # Resource identity within the Kubernetes API.
    kind = "DaskJob"
    version = "kubernetes.dask.org/v1"
    namespaced = True

    # Names under which the resource is addressed.
    singular = "daskjob"
    plural = "daskjobs"
    endpoint = "daskjobs"


def _get_annotations(meta):
return {
annotation_key: annotation_value
Expand Down Expand Up @@ -347,10 +385,8 @@ async def handle_scheduler_service_status(
# Otherwise mark it as Running
else:
phase = "Running"

api = HTTPClient(KubeConfig.from_env())
cluster = await DaskCluster.objects(api, namespace=namespace).get_by_name(
labels["dask.org/cluster-name"]
cluster = await DaskCluster.get(
labels["dask.org/cluster-name"], namespace=namespace
)
await cluster.patch({"status": {"phase": phase}})

Expand Down Expand Up @@ -986,8 +1022,5 @@ async def daskcluster_autoshutdown(spec, name, namespace, logger, **kwargs):
logger.warn("Unable to connect to scheduler, skipping autoshutdown check.")
return
if idle_since and time.time() > idle_since + spec["idleTimeout"]:
api = HTTPClient(KubeConfig.from_env())
cluster = await DaskCluster.objects(api, namespace=namespace).get_by_name(
name
)
cluster = await DaskCluster.get(name, namespace=namespace)
await cluster.delete()
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ kubernetes-asyncio>=12.0.1
kopf>=1.35.3
pykube-ng>=22.9.0
rich>=12.5.1
kr8s==0.5.1