diff --git a/dask_kubernetes/operator/controller/controller.py b/dask_kubernetes/operator/controller/controller.py index db47a75a7..7f3c29802 100644 --- a/dask_kubernetes/operator/controller/controller.py +++ b/dask_kubernetes/operator/controller/controller.py @@ -481,6 +481,29 @@ async def get_desired_workers(scheduler_service_name, namespace, logger): worker_group_scale_locks = defaultdict(lambda: asyncio.Lock()) +@kopf.on.field("daskcluster.kubernetes.dask.org", field="spec.worker.replicas") +async def daskcluster_default_worker_group_replica_update( + name, namespace, meta, spec, old, new, body, logger, **kwargs +): + if old is None: + return + worker_group_name = f"{name}-default" + + async with kubernetes.client.api_client.ApiClient() as api_client: + custom_objects_api = kubernetes.client.CustomObjectsApi(api_client) + custom_objects_api.api_client.set_default_header( + "content-type", "application/merge-patch+json" + ) + await custom_objects_api.patch_namespaced_custom_object_scale( + group="kubernetes.dask.org", + version="v1", + plural="daskworkergroups", + namespace=namespace, + name=worker_group_name, + body={"spec": {"replicas": new}}, + ) + + @kopf.on.field("daskworkergroup.kubernetes.dask.org", field="spec.worker.replicas") async def daskworkergroup_replica_update( name, namespace, meta, spec, new, body, logger, **kwargs diff --git a/dask_kubernetes/operator/controller/tests/test_controller.py b/dask_kubernetes/operator/controller/tests/test_controller.py index f0b1c676c..fadc8882c 100644 --- a/dask_kubernetes/operator/controller/tests/test_controller.py +++ b/dask_kubernetes/operator/controller/tests/test_controller.py @@ -140,6 +140,54 @@ async def test_scalesimplecluster(k8s_cluster, kopf_runner, gen_cluster): "daskworkergroup.kubernetes.dask.org", "simple-default", ) + # TODO: Currently, doesn't test anything. Need to add optional + # argument to wait when removing workers once distributed + # PR github.com/dask/distributed/pull/6377 is merged. + await client.wait_for_workers(3) + + +@pytest.mark.asyncio +async def test_scalesimplecluster_from_cluster_spec( + k8s_cluster, kopf_runner, gen_cluster +): + with kopf_runner as runner: + async with gen_cluster() as cluster_name: + scheduler_pod_name = "simple-scheduler" + worker_pod_name = "simple-default-worker" + service_name = "simple-scheduler" + while scheduler_pod_name not in k8s_cluster.kubectl("get", "pods"): + await asyncio.sleep(0.1) + while service_name not in k8s_cluster.kubectl("get", "svc"): + await asyncio.sleep(0.1) + while worker_pod_name not in k8s_cluster.kubectl("get", "pods"): + await asyncio.sleep(0.1) + k8s_cluster.kubectl( + "wait", + "pods", + "--for=condition=Ready", + scheduler_pod_name, + "--timeout=120s", + ) + with k8s_cluster.port_forward(f"service/{service_name}", 8786) as port: + async with Client( + f"tcp://localhost:{port}", asynchronous=True + ) as client: + k8s_cluster.kubectl( + "scale", + "--replicas=5", + "daskcluster.kubernetes.dask.org", + cluster_name, + ) + await client.wait_for_workers(5) + k8s_cluster.kubectl( + "scale", + "--replicas=3", + "daskcluster.kubernetes.dask.org", + cluster_name, + ) + # TODO: Currently, doesn't test anything. Need to add optional + # argument to wait when removing workers once distributed + # PR github.com/dask/distributed/pull/6377 is merged. await client.wait_for_workers(3)