49 changes: 22 additions & 27 deletions dask_kubernetes/classic/tests/test_sync.py
@@ -5,7 +5,7 @@
import dask
import pytest
from dask.distributed import Client, wait
from distributed.utils_test import loop, captured_logger # noqa: F401
from distributed.utils_test import captured_logger
from dask.utils import tmpfile

from dask_kubernetes import KubeCluster, make_pod_spec
@@ -75,17 +75,17 @@ def test_ipython_display(cluster):
sleep(0.5)


def test_env(pod_spec, loop):
with KubeCluster(pod_spec, env={"ABC": "DEF"}, loop=loop) as cluster:
def test_env(pod_spec):
with KubeCluster(pod_spec, env={"ABC": "DEF"}) as cluster:
cluster.scale(1)
with Client(cluster, loop=loop) as client:
with Client(cluster) as client:
while not cluster.scheduler_info["workers"]:
sleep(0.1)
env = client.run(lambda: dict(os.environ))
assert all(v["ABC"] == "DEF" for v in env.values())


def dont_test_pod_template_yaml(docker_image, loop):
def dont_test_pod_template_yaml(docker_image):
test_yaml = {
"kind": "Pod",
"metadata": {"labels": {"app": "dask", "component": "dask-worker"}},
@@ -109,9 +109,9 @@ def dont_test_pod_template_yaml(docker_image, loop):
with tmpfile(extension="yaml") as fn:
with open(fn, mode="w") as f:
yaml.dump(test_yaml, f)
with KubeCluster(f.name, loop=loop) as cluster:
with KubeCluster(f.name) as cluster:
cluster.scale(2)
with Client(cluster, loop=loop) as client:
with Client(cluster) as client:
future = client.submit(lambda x: x + 1, 10)
result = future.result(timeout=10)
assert result == 11
@@ -128,7 +128,7 @@ def dont_test_pod_template_yaml(docker_image, loop):
assert all(client.has_what().values())


def test_pod_template_yaml_expand_env_vars(docker_image, loop):
def test_pod_template_yaml_expand_env_vars(docker_image):
try:
os.environ["FOO_IMAGE"] = docker_image

@@ -155,13 +155,13 @@ def test_pod_template_yaml_expand_env_vars(docker_image, loop):
with tmpfile(extension="yaml") as fn:
with open(fn, mode="w") as f:
yaml.dump(test_yaml, f)
with KubeCluster(f.name, loop=loop) as cluster:
with KubeCluster(f.name) as cluster:
assert cluster.pod_template.spec.containers[0].image == docker_image
finally:
del os.environ["FOO_IMAGE"]


def test_pod_template_dict(docker_image, loop):
def test_pod_template_dict(docker_image):
spec = {
"metadata": {},
"restartPolicy": "Never",
@@ -185,9 +185,9 @@ def test_pod_template_dict(docker_image, loop):
},
}

with KubeCluster(spec, loop=loop) as cluster:
with KubeCluster(spec) as cluster:
cluster.scale(2)
with Client(cluster, loop=loop) as client:
with Client(cluster) as client:
future = client.submit(lambda x: x + 1, 10)
result = future.result()
assert result == 11
@@ -202,7 +202,7 @@ def test_pod_template_dict(docker_image, loop):
assert all(client.has_what().values())


def test_pod_template_minimal_dict(docker_image, loop):
def test_pod_template_minimal_dict(docker_image):
spec = {
"spec": {
"containers": [
@@ -224,9 +224,9 @@
}
}

with KubeCluster(spec, loop=loop) as cluster:
with KubeCluster(spec) as cluster:
cluster.adapt()
with Client(cluster, loop=loop) as client:
with Client(cluster) as client:
future = client.submit(lambda x: x + 1, 10)
result = future.result()
assert result == 11
@@ -264,9 +264,9 @@ def test_bad_args():
KubeCluster({"kind": "Pod"})


def test_constructor_parameters(pod_spec, loop):
def test_constructor_parameters(pod_spec):
env = {"FOO": "BAR", "A": 1}
with KubeCluster(pod_spec, name="myname", loop=loop, env=env) as cluster:
with KubeCluster(pod_spec, name="myname", env=env) as cluster:
pod = cluster.pod_template

var = [v for v in pod.spec.containers[0].env if v.name == "FOO"]
@@ -380,15 +380,14 @@ def test_maximum(cluster):
assert "scale beyond maximum number of workers" in result.lower()


def test_extra_pod_config(docker_image, loop):
def test_extra_pod_config(docker_image):
"""
Test that our pod config merging process works fine
"""
with KubeCluster(
make_pod_spec(
docker_image, extra_pod_config={"automountServiceAccountToken": False}
),
loop=loop,
n_workers=0,
) as cluster:

@@ -397,7 +396,7 @@ def test_extra_pod_config(docker_image, loop):
assert pod.spec.automount_service_account_token is False


def test_extra_container_config(docker_image, loop):
def test_extra_container_config(docker_image):
"""
Test that our container config merging process works fine
"""
@@ -409,7 +408,6 @@ def test_extra_container_config(docker_image, loop):
"securityContext": {"runAsUser": 0},
},
),
loop=loop,
n_workers=0,
) as cluster:

@@ -419,15 +417,14 @@ def test_extra_container_config(docker_image, loop):
assert pod.spec.containers[0].security_context == {"runAsUser": 0}


def test_container_resources_config(docker_image, loop):
def test_container_resources_config(docker_image):
"""
Test container resource requests / limits being set properly
"""
with KubeCluster(
make_pod_spec(
docker_image, memory_request="0.5G", memory_limit="1G", cpu_limit="1"
),
loop=loop,
n_workers=0,
) as cluster:

@@ -439,7 +436,7 @@ def test_container_resources_config(docker_image, loop):
assert "cpu" not in pod.spec.containers[0].resources.requests


def test_extra_container_config_merge(docker_image, loop):
def test_extra_container_config_merge(docker_image):
"""
Test that our container config merging process works recursively fine
"""
@@ -452,7 +449,6 @@ def test_extra_container_config_merge(docker_image, loop):
"args": ["last-item"],
},
),
loop=loop,
n_workers=0,
) as cluster:

@@ -464,7 +460,7 @@ def test_extra_container_config_merge(docker_image, loop):
assert pod.spec.containers[0].args[-1] == "last-item"


def test_worker_args(docker_image, loop):
def test_worker_args(docker_image):
"""
Test that dask-worker arguments are added to the container args
"""
@@ -474,7 +470,6 @@ def test_worker_args(docker_image, loop):
memory_limit="5000M",
resources="FOO=1 BAR=2",
),
loop=loop,
n_workers=0,
) as cluster:

26 changes: 22 additions & 4 deletions dask_kubernetes/common/networking.py
@@ -6,6 +6,9 @@
from weakref import finalize

import kubernetes_asyncio as kubernetes
from tornado.iostream import StreamClosedError

from distributed.core import rpc

from .utils import check_dependency

@@ -15,7 +18,7 @@ async def get_external_address_for_scheduler_service(
service,
port_forward_cluster_ip=None,
service_name_resolution_retries=20,
port_name="comm",
port_name="tcp-comm",
):
"""Take a service object and return the scheduler address."""
[port] = [
@@ -108,7 +111,7 @@ async def port_forward_dashboard(service_name, namespace):
return port


async def get_scheduler_address(service_name, namespace, port_name="comm"):
async def get_scheduler_address(service_name, namespace, port_name="tcp-comm"):
async with kubernetes.client.api_client.ApiClient() as api_client:
api = kubernetes.client.CoreV1Api(api_client)
service = await api.read_namespaced_service(service_name, namespace)
@@ -132,6 +135,21 @@ async def wait_for_scheduler(cluster_name, namespace):
label_selector=f"dask.org/cluster-name={cluster_name},dask.org/component=scheduler",
timeout_seconds=60,
):
if event["object"].status.phase == "Running":
watch.stop()
if event["object"].status.conditions:
conditions = {
c.type: c.status for c in event["object"].status.conditions
}
if "Ready" in conditions and conditions["Ready"] == "True":
watch.stop()
await asyncio.sleep(0.1)


async def wait_for_scheduler_comm(address):
while True:
try:
async with rpc(address) as scheduler_comm:
await scheduler_comm.versions()
except (StreamClosedError, OSError):
await asyncio.sleep(0.1)
continue
break
50 changes: 29 additions & 21 deletions dask_kubernetes/experimental/kubecluster.py
@@ -20,15 +20,16 @@
)

from dask_kubernetes.common.auth import ClusterAuth
from dask_kubernetes.common.utils import namespace_default
from dask_kubernetes.operator import (
build_cluster_spec,
wait_for_service,
)

from dask_kubernetes.common.networking import (
get_scheduler_address,
port_forward_dashboard,
wait_for_scheduler,
wait_for_scheduler_comm,
)


@@ -121,7 +122,7 @@ class KubeCluster(Cluster):
def __init__(
self,
name,
namespace="default",
namespace=None,
image="ghcr.io/dask/dask:latest",
n_workers=3,
resources={},
@@ -133,8 +134,7 @@ def __init__(
**kwargs,
):
self.name = name
# TODO: Set namespace to None and get default namespace from user's context
self.namespace = namespace
self.namespace = namespace or namespace_default()
self.image = image
self.n_workers = n_workers
self.resources = resources
@@ -208,10 +208,15 @@ async def _create_cluster(self):
) from e
await wait_for_scheduler(cluster_name, self.namespace)
await wait_for_service(core_api, f"{cluster_name}-service", self.namespace)
self.scheduler_comm = rpc(await self._get_scheduler_address())
self.forwarded_dashboard_port = await port_forward_dashboard(
f"{self.name}-cluster-service", self.namespace
scheduler_address = await self._get_scheduler_address()
await wait_for_scheduler_comm(scheduler_address)
self.scheduler_comm = rpc(scheduler_address)
dashboard_address = await get_scheduler_address(
f"{self.name}-cluster-service",
self.namespace,
port_name="http-dashboard",
)
self.forwarded_dashboard_port = dashboard_address.split(":")[-1]

async def _connect_cluster(self):
if self.shutdown_on_close is None:
@@ -230,10 +235,15 @@ async def _connect_cluster(self):
service_name = f'{cluster_spec["metadata"]["name"]}-service'
await wait_for_scheduler(self.cluster_name, self.namespace)
await wait_for_service(core_api, service_name, self.namespace)
self.scheduler_comm = rpc(await self._get_scheduler_address())
self.forwarded_dashboard_port = await port_forward_dashboard(
f"{self.name}-cluster-service", self.namespace
scheduler_address = await self._get_scheduler_address()
await wait_for_scheduler_comm(scheduler_address)
self.scheduler_comm = rpc(scheduler_address)
dashboard_address = await get_scheduler_address(
service_name,
self.namespace,
port_name="http-dashboard",
)
self.forwarded_dashboard_port = dashboard_address.split(":")[-1]

async def _get_cluster(self):
async with kubernetes.client.api_client.ApiClient() as api_client:
@@ -465,30 +475,28 @@ def _build_scheduler_spec(self, cluster_name):
{
"name": "scheduler",
"image": self.image,
"args": [
"dask-scheduler",
],
"args": ["dask-scheduler", "--host", "0.0.0.0"],
"env": env,
"resources": self.resources,
"ports": [
{
"name": "comm",
"name": "tcp-comm",
"containerPort": 8786,
"protocol": "TCP",
},
{
"name": "dashboard",
"name": "http-dashboard",
"containerPort": 8787,
"protocol": "TCP",
},
],
"readinessProbe": {
"tcpSocket": {"port": "comm"},
"httpGet": {"port": "http-dashboard", "path": "/health"},
"initialDelaySeconds": 5,
"periodSeconds": 10,
},
"livenessProbe": {
"tcpSocket": {"port": "comm"},
"httpGet": {"port": "http-dashboard", "path": "/health"},
"initialDelaySeconds": 15,
"periodSeconds": 20,
},
Expand All @@ -503,16 +511,16 @@ def _build_scheduler_spec(self, cluster_name):
},
"ports": [
{
"name": "comm",
"name": "tcp-comm",
"protocol": "TCP",
"port": 8786,
"targetPort": "comm",
"targetPort": "tcp-comm",
},
{
"name": "dashboard",
"name": "http-dashboard",
"protocol": "TCP",
"port": 8787,
"targetPort": "dashboard",
"targetPort": "http-dashboard",
},
],
},
4 changes: 2 additions & 2 deletions dask_kubernetes/kubernetes.yaml
@@ -30,11 +30,11 @@ kubernetes:
dask.org/cluster-name: "" # Cluster name will be added automatically
dask.org/component: scheduler
ports:
- name: comm
- name: tcp-comm
protocol: TCP
port: 8786
targetPort: 8786
- name: dashboard
- name: http-dashboard
protocol: TCP
port: 8787
targetPort: 8787