Skip to content

Commit 93fe171

Browse files
jonded94Jonas Dedden
andauthored
Enable overwrites of default environment variables (#874)
* Enable overwrites of default environment variables * Black formatting * Include test for additional worker group; test overriding of environment variables * Black --------- Co-authored-by: Jonas Dedden <[email protected]>
1 parent b668cc6 commit 93fe171

File tree

3 files changed

+93
-32
lines changed

3 files changed

+93
-32
lines changed

dask_kubernetes/operator/controller/controller.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -153,21 +153,25 @@ def build_worker_deployment_spec(
153153
"metadata": metadata,
154154
"spec": spec,
155155
}
156-
env = [
157-
{
158-
"name": "DASK_WORKER_NAME",
159-
"value": worker_name,
160-
},
161-
{
162-
"name": "DASK_SCHEDULER_ADDRESS",
163-
"value": f"tcp://{cluster_name}-scheduler.{namespace}.svc.cluster.local:8786",
164-
},
165-
]
156+
worker_env = {
157+
"name": "DASK_WORKER_NAME",
158+
"value": worker_name,
159+
}
160+
scheduler_env = {
161+
"name": "DASK_SCHEDULER_ADDRESS",
162+
"value": f"tcp://{cluster_name}-scheduler.{namespace}.svc.cluster.local:8786",
163+
}
166164
for container in deployment_spec["spec"]["template"]["spec"]["containers"]:
167-
if "env" in container:
168-
container["env"].extend(env)
169-
else:
170-
container["env"] = env
165+
if "env" not in container:
166+
container["env"] = [worker_env, scheduler_env]
167+
continue
168+
169+
container_env_names = [env_item["name"] for env_item in container["env"]]
170+
171+
if "DASK_WORKER_NAME" not in container_env_names:
172+
container["env"].append(worker_env)
173+
if "DASK_SCHEDULER_ADDRESS" not in container_env_names:
174+
container["env"].append(scheduler_env)
171175
return deployment_spec
172176

173177

dask_kubernetes/operator/controller/tests/resources/simpleworkergroup.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ metadata:
55
spec:
66
cluster: simple
77
worker:
8-
replicas: 2
8+
replicas: 1
99
spec:
1010
containers:
1111
- name: worker
@@ -23,3 +23,5 @@ spec:
2323
env:
2424
- name: WORKER_ENV
2525
value: hello-world # We dont test the value, just the name
26+
- name: DASK_WORKER_NAME
27+
value: test-worker

dask_kubernetes/operator/controller/tests/test_controller.py

Lines changed: 72 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
DIR = pathlib.Path(__file__).parent.absolute()
2222

23-
2423
_EXPECTED_ANNOTATIONS = {"test-annotation": "annotation-value"}
2524
_EXPECTED_LABELS = {"test-label": "label-value"}
2625
DEFAULT_CLUSTER_NAME = "simple"
@@ -47,7 +46,6 @@ def gen_cluster(k8s_cluster, ns, gen_cluster_manifest):
4746

4847
@asynccontextmanager
4948
async def cm(cluster_name=DEFAULT_CLUSTER_NAME):
50-
5149
cluster_path = gen_cluster_manifest(cluster_name)
5250
# Create cluster resource
5351
k8s_cluster.kubectl("apply", "-n", ns, "-f", cluster_path)
@@ -95,6 +93,36 @@ async def cm(job_file):
9593
yield cm
9694

9795

96+
@pytest.fixture()
97+
def gen_worker_group(k8s_cluster, ns):
98+
"""Yields an instantiated context manager for creating/deleting a worker group."""
99+
100+
@asynccontextmanager
101+
async def cm(worker_group_file):
102+
worker_group_path = os.path.join(DIR, "resources", worker_group_file)
103+
with open(worker_group_path) as f:
104+
worker_group_name = yaml.load(f, yaml.Loader)["metadata"]["name"]
105+
106+
# Create cluster resource
107+
k8s_cluster.kubectl("apply", "-n", ns, "-f", worker_group_path)
108+
while worker_group_name not in k8s_cluster.kubectl(
109+
"get", "daskworkergroups.kubernetes.dask.org", "-n", ns
110+
):
111+
await asyncio.sleep(0.1)
112+
113+
try:
114+
yield worker_group_name, ns
115+
finally:
116+
# Test: remove the wait=True, because I think this is blocking the operator
117+
k8s_cluster.kubectl("delete", "-n", ns, "-f", worker_group_path)
118+
while worker_group_name in k8s_cluster.kubectl(
119+
"get", "daskworkergroups.kubernetes.dask.org", "-n", ns
120+
):
121+
await asyncio.sleep(0.1)
122+
123+
yield cm
124+
125+
98126
def test_customresources(k8s_cluster):
99127
assert "daskclusters.kubernetes.dask.org" in k8s_cluster.kubectl("get", "crd")
100128
assert "daskworkergroups.kubernetes.dask.org" in k8s_cluster.kubectl("get", "crd")
@@ -671,32 +699,59 @@ async def test_object_dask_cluster(k8s_cluster, kopf_runner, gen_cluster):
671699

672700

673701
@pytest.mark.anyio
674-
async def test_object_dask_worker_group(k8s_cluster, kopf_runner, gen_cluster):
702+
async def test_object_dask_worker_group(
703+
k8s_cluster, kopf_runner, gen_cluster, gen_worker_group
704+
):
675705
with kopf_runner:
676-
async with gen_cluster() as (cluster_name, ns):
706+
async with (
707+
gen_cluster() as (cluster_name, ns),
708+
gen_worker_group("simpleworkergroup.yaml") as (
709+
additional_workergroup_name,
710+
_,
711+
),
712+
):
677713
cluster = await DaskCluster.get(cluster_name, namespace=ns)
714+
additional_workergroup = await DaskWorkerGroup.get(
715+
additional_workergroup_name, namespace=ns
716+
)
678717

679718
worker_groups = []
680719
while not worker_groups:
681720
worker_groups = await cluster.worker_groups()
682721
await asyncio.sleep(0.1)
683722
assert len(worker_groups) == 1 # Just the default worker group
684-
wg = worker_groups[0]
685-
assert isinstance(wg, DaskWorkerGroup)
723+
worker_groups = worker_groups + [additional_workergroup]
686724

687-
pods = []
688-
while not pods:
689-
pods = await wg.pods()
690-
await asyncio.sleep(0.1)
691-
assert all([isinstance(p, Pod) for p in pods])
725+
for wg in worker_groups:
726+
assert isinstance(wg, DaskWorkerGroup)
692727

693-
deployments = []
694-
while not deployments:
695-
deployments = await wg.deployments()
696-
await asyncio.sleep(0.1)
697-
assert all([isinstance(d, Deployment) for d in deployments])
728+
deployments = []
729+
while not deployments:
730+
deployments = await wg.deployments()
731+
await asyncio.sleep(0.1)
732+
assert all([isinstance(d, Deployment) for d in deployments])
698733

699-
assert (await wg.cluster()).name == cluster.name
734+
pods = []
735+
while not pods:
736+
pods = await wg.pods()
737+
await asyncio.sleep(0.1)
738+
assert all([isinstance(p, Pod) for p in pods])
739+
740+
assert (await wg.cluster()).name == cluster.name
741+
742+
for deployment in deployments:
743+
assert deployment.labels["dask.org/cluster-name"] == cluster.name
744+
for env in deployment.spec["template"]["spec"]["containers"][0][
745+
"env"
746+
]:
747+
if env["name"] == "DASK_WORKER_NAME":
748+
if wg.name == additional_workergroup_name:
749+
assert env["value"] == "test-worker"
750+
else:
751+
assert env["value"] == deployment.name
752+
if env["name"] == "DASK_SCHEDULER_ADDRESS":
753+
scheduler_service = await cluster.scheduler_service()
754+
assert f"{scheduler_service.name}.{ns}" in env["value"]
700755

701756

702757
@pytest.mark.anyio

0 commit comments

Comments
 (0)