|
5 | 5 | from time import sleep |
6 | 6 | from codeflare_sdk import get_cluster |
7 | 7 | from kubernetes import client, config |
| 8 | +from kubernetes.client import V1Toleration |
8 | 9 | from codeflare_sdk.common.kubernetes_cluster.kube_api_helpers import ( |
9 | 10 | _kube_api_error_handling, |
10 | 11 | ) |
@@ -146,6 +147,92 @@ def random_choice(): |
146 | 147 | return "".join(random.choices(alphabet, k=5)) |
147 | 148 |
|
148 | 149 |
|
| 150 | +def _parse_label_env(env_var, default): |
| 151 | + """Parse label from environment variable (format: 'key=value').""" |
| 152 | + label_str = os.getenv(env_var, default) |
| 153 | + return label_str.split("=") |
| 154 | + |
| 155 | + |
def get_master_taint_key(self):
    """
    Detect the master/control-plane taint key applied to cluster nodes.

    Resolution order:
      1. The ``TOLERATION_KEY`` environment variable, if set.
      2. The first known master/control-plane taint found on any node.
      3. A hard-coded control-plane default.

    Args:
        self: Test context exposing ``api_instance`` (CoreV1Api).

    Returns:
        str: The taint key a workload should tolerate.
    """
    # Env var override takes priority and avoids an API round-trip;
    # read it once instead of calling os.getenv twice.
    env_key = os.getenv("TOLERATION_KEY")
    if env_key:
        return env_key

    known_keys = (
        "node-role.kubernetes.io/master",
        "node-role.kubernetes.io/control-plane",
    )

    # Best-effort detection from live nodes; any API failure falls
    # through to the default rather than failing the caller.
    try:
        nodes = self.api_instance.list_node()
        for node in nodes.items:
            for taint in node.spec.taints or []:
                if taint.key in known_keys:
                    return taint.key
    except Exception as e:
        print(f"Warning: Could not detect master taint key: {e}")

    # Default fallback
    return "node-role.kubernetes.io/control-plane"
| 189 | + |
| 190 | + |
def ensure_nodes_labeled_for_flavors(self, num_flavors, with_labels):
    """
    Verify that the node labels required for ResourceFlavor targeting exist.

    Covers both the default flavor label (worker-1=true) and, when more
    than one flavor is requested, the non-default one (ingress-ready=true).

    NOTE: This function does NOT modify cluster nodes. It only checks if required labels exist.
    If labels don't exist, the test will use whatever labels are available on the cluster.
    For shared clusters, set WORKER_LABEL and CONTROL_LABEL env vars to match existing labels.
    """
    if not with_labels:
        return

    worker_key, worker_val = _parse_label_env("WORKER_LABEL", "worker-1=true")
    control_key, control_val = _parse_label_env(
        "CONTROL_LABEL", "ingress-ready=true"
    )

    try:
        nodes = self.api_instance.list_node(
            label_selector="node-role.kubernetes.io/worker"
        )

        if not nodes.items:
            print("Warning: No worker nodes found")
            return

        # Only the worker label matters for a single flavor; a second
        # flavor additionally needs the control label.
        expected = [("WORKER_LABEL", worker_key, worker_val)]
        if num_flavors > 1:
            expected.append(("CONTROL_LABEL", control_key, control_val))

        for env_name, key, val in expected:
            found = False
            for node in nodes.items:
                labels = node.metadata.labels
                if labels and labels.get(key) == val:
                    found = True
                    break
            if not found:
                print(
                    f"Warning: Label {key}={val} not found (set {env_name} env var to match existing labels)"
                )

    except Exception as e:
        print(f"Warning: Could not check existing labels: {e}")
| 234 | + |
| 235 | + |
149 | 236 | def create_namespace(self): |
150 | 237 | try: |
151 | 238 | self.namespace = f"test-ns-{random_choice()}" |
@@ -280,14 +367,13 @@ def create_cluster_queue(self, cluster_queue, flavor): |
280 | 367 | def create_resource_flavor( |
281 | 368 | self, flavor, default=True, with_labels=False, with_tolerations=False |
282 | 369 | ): |
283 | | - worker_label, worker_value = os.getenv("WORKER_LABEL", "worker-1=true").split("=") |
284 | | - control_label, control_value = os.getenv( |
| 370 | + worker_label, worker_value = _parse_label_env("WORKER_LABEL", "worker-1=true") |
| 371 | + control_label, control_value = _parse_label_env( |
285 | 372 | "CONTROL_LABEL", "ingress-ready=true" |
286 | | - ).split("=") |
287 | | - toleration_key = os.getenv( |
288 | | - "TOLERATION_KEY", "node-role.kubernetes.io/control-plane" |
289 | 373 | ) |
290 | 374 |
|
| 375 | + toleration_key = os.getenv("TOLERATION_KEY") or get_master_taint_key(self) |
| 376 | + |
291 | 377 | node_labels = {} |
292 | 378 | if with_labels: |
293 | 379 | node_labels = ( |
@@ -451,6 +537,25 @@ def get_nodes_by_label(self, node_labels): |
451 | 537 | return [node.metadata.name for node in nodes.items] |
452 | 538 |
|
453 | 539 |
|
def get_tolerations_from_flavor(self, flavor_name):
    """
    Read the tolerations declared on a ResourceFlavor and build the
    equivalent V1Toleration objects.

    Args:
        self: Test context used to fetch the flavor spec.
        flavor_name (str): Name of the ResourceFlavor to inspect.

    Returns:
        list[V1Toleration]: One entry per toleration in the flavor spec;
        empty when the flavor declares none.
    """
    spec = get_flavor_spec(self, flavor_name)
    raw_tolerations = spec.get("spec", {}).get("tolerations", [])

    tolerations = []
    for raw in raw_tolerations:
        tolerations.append(
            V1Toleration(
                key=raw.get("key"),
                # Kueue omits the operator for the common case; mirror the
                # Kubernetes default of "Equal".
                operator=raw.get("operator", "Equal"),
                value=raw.get("value"),
                effect=raw.get("effect"),
            )
        )
    return tolerations
| 557 | + |
| 558 | + |
454 | 559 | def assert_get_cluster_and_jobsubmit( |
455 | 560 | self, cluster_name, accelerator=None, number_of_gpus=None |
456 | 561 | ): |
@@ -514,7 +619,7 @@ def wait_for_kueue_admission(self, job_api, job_name, namespace, timeout=120): |
514 | 619 | workload = get_kueue_workload_for_job(self, job_name, namespace) |
515 | 620 | if workload: |
516 | 621 | conditions = workload.get("status", {}).get("conditions", []) |
517 | | - print(f" DEBUG: Workload conditions for '{job_name}':") |
| 622 | + print(f"DEBUG: Workload conditions for '{job_name}':") |
518 | 623 | for condition in conditions: |
519 | 624 | print( |
520 | 625 | f" - {condition.get('type')}: {condition.get('status')} - {condition.get('reason', '')} - {condition.get('message', '')}" |
|
0 commit comments