|
11 | 11 | from dask.distributed import Client |
12 | 12 | from kr8s.asyncio.objects import Deployment, Pod, Service |
13 | 13 |
|
| 14 | +from dask_kubernetes.constants import MAX_CLUSTER_NAME_LEN |
14 | 15 | from dask_kubernetes.operator._objects import DaskCluster, DaskJob, DaskWorkerGroup |
15 | 16 | from dask_kubernetes.operator.controller import ( |
16 | 17 | KUBERNETES_DATETIME_FORMAT, |
|
22 | 23 |
|
23 | 24 | _EXPECTED_ANNOTATIONS = {"test-annotation": "annotation-value"} |
24 | 25 | _EXPECTED_LABELS = {"test-label": "label-value"} |
| 26 | +DEFAULT_CLUSTER_NAME = "simple" |
25 | 27 |
|
26 | 28 |
|
27 | 29 | @pytest.fixture() |
28 | | -def gen_cluster(k8s_cluster, ns): |
| 30 | +def gen_cluster_manifest(tmp_path): |
| 31 | + def factory(cluster_name=DEFAULT_CLUSTER_NAME): |
| 32 | + original_manifest_path = os.path.join(DIR, "resources", "simplecluster.yaml") |
| 33 | + with open(original_manifest_path, "r") as original_manifest_file: |
| 34 | + manifest = yaml.safe_load(original_manifest_file) |
| 35 | + |
| 36 | + manifest["metadata"]["name"] = cluster_name |
| 37 | + new_manifest_path = tmp_path / "cluster.yaml" |
| 38 | + new_manifest_path.write_text(yaml.safe_dump(manifest)) |
| 39 | + return tmp_path |
| 40 | + |
| 41 | + return factory |
| 42 | + |
| 43 | + |
| 44 | +@pytest.fixture() |
| 45 | +def gen_cluster(k8s_cluster, ns, gen_cluster_manifest): |
29 | 46 | """Yields an instantiated context manager for creating/deleting a simple cluster.""" |
30 | 47 |
|
31 | 48 | @asynccontextmanager |
32 | | - async def cm(): |
33 | | - cluster_path = os.path.join(DIR, "resources", "simplecluster.yaml") |
34 | | - cluster_name = "simple" |
| 49 | + async def cm(cluster_name=DEFAULT_CLUSTER_NAME): |
35 | 50 |
|
| 51 | + cluster_path = gen_cluster_manifest(cluster_name) |
36 | 52 | # Create cluster resource |
37 | 53 | k8s_cluster.kubectl("apply", "-n", ns, "-f", cluster_path) |
38 | 54 | while cluster_name not in k8s_cluster.kubectl( |
@@ -695,3 +711,42 @@ async def test_object_dask_job(k8s_cluster, kopf_runner, gen_job): |
695 | 711 |
|
696 | 712 | cluster = await job.cluster() |
697 | 713 | assert isinstance(cluster, DaskCluster) |
| 714 | + |
| 715 | + |
| 716 | +async def _get_cluster_status(k8s_cluster, ns, cluster_name): |
| 717 | + """ |
| 718 | + Will loop infinitely in search of non-falsey cluster status. |
| 719 | + Make sure there is a timeout on any test which calls this. |
| 720 | + """ |
| 721 | + while True: |
| 722 | + cluster_status = k8s_cluster.kubectl( |
| 723 | + "get", |
| 724 | + "-n", |
| 725 | + ns, |
| 726 | + "daskcluster.kubernetes.dask.org", |
| 727 | + cluster_name, |
| 728 | + "-o", |
| 729 | + "jsonpath='{.status.phase}'", |
| 730 | + ).strip("'") |
| 731 | + if cluster_status: |
| 732 | + return cluster_status |
| 733 | + await asyncio.sleep(0.1) |
| 734 | + |
| 735 | + |
| 736 | +@pytest.mark.timeout(180) |
| 737 | +@pytest.mark.anyio |
| 738 | +@pytest.mark.parametrize( |
| 739 | + "cluster_name,expected_status", |
| 740 | + [ |
| 741 | + ("valid-name", "Created"), |
| 742 | + ((MAX_CLUSTER_NAME_LEN + 1) * "a", "Error"), |
| 743 | + ("invalid.chars.in.name", "Error"), |
| 744 | + ], |
| 745 | +) |
| 746 | +async def test_create_cluster_validates_name( |
| 747 | + cluster_name, expected_status, k8s_cluster, kopf_runner, gen_cluster |
| 748 | +): |
| 749 | + with kopf_runner: |
| 750 | + async with gen_cluster(cluster_name=cluster_name) as (_, ns): |
| 751 | + actual_status = await _get_cluster_status(k8s_cluster, ns, cluster_name) |
| 752 | + assert expected_status == actual_status |
0 commit comments