Skip to content

Commit ee1a6f2

Browse files
committed
fix(RHOAIENG-37842): fix heterogeneous oauth test
1 parent 67df4f5 commit ee1a6f2

File tree

3 files changed

+120
-14
lines changed

3 files changed

+120
-14
lines changed

tests/e2e/heterogeneous_clusters_oauth_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from time import sleep
2-
import time
31
from codeflare_sdk import (
42
Cluster,
53
ClusterConfiguration,
@@ -55,22 +53,24 @@ def run_heterogeneous_clusters(
5553
namespace=self.namespace,
5654
name=cluster_name,
5755
num_workers=1,
58-
head_cpu_requests=1,
59-
head_cpu_limits=1,
60-
worker_cpu_requests=1,
56+
head_cpu_requests="500m",
57+
head_cpu_limits="500m",
58+
head_memory_requests=2,
59+
head_memory_limits=4,
60+
worker_cpu_requests="500m",
6161
worker_cpu_limits=1,
62-
worker_memory_requests=1,
62+
worker_memory_requests=2,
6363
worker_memory_limits=4,
6464
image=ray_image,
6565
verify_tls=False,
6666
local_queue=queue_name,
6767
)
6868
)
6969
cluster.apply()
70-
sleep(5)
70+
# Wait for the cluster to be scheduled and ready, we don't need the dashboard for this check
71+
cluster.wait_ready(dashboard_check=False)
7172
node_name = get_pod_node(self, self.namespace, cluster_name)
7273
print(f"Cluster {cluster_name}-{flavor} is running on node: {node_name}")
73-
sleep(5)
7474
assert (
7575
node_name in expected_nodes
7676
), f"Node {node_name} is not in the expected nodes for flavor {flavor}."

tests/e2e/local_interactive_sdk_oauth_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from support import *
1313

1414

15+
@pytest.mark.skip(reason="Remote ray.init() is temporarily unsupported")
1516
@pytest.mark.openshift
1617
class TestRayLocalInteractiveOauth:
1718
def setup_method(self):

tests/e2e/support.py

Lines changed: 111 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from time import sleep
66
from codeflare_sdk import get_cluster
77
from kubernetes import client, config
8+
from kubernetes.client import V1Toleration
89
from codeflare_sdk.common.kubernetes_cluster.kube_api_helpers import (
910
_kube_api_error_handling,
1011
)
@@ -146,6 +147,92 @@ def random_choice():
146147
return "".join(random.choices(alphabet, k=5))
147148

148149

150+
def _parse_label_env(env_var, default):
151+
"""Parse label from environment variable (format: 'key=value')."""
152+
label_str = os.getenv(env_var, default)
153+
return label_str.split("=")
154+
155+
156+
def get_master_taint_key(self):
    """
    Detect the master/control-plane taint key used by this cluster's nodes.

    Checks the TOLERATION_KEY environment variable first, then scans the
    taints of every node for a known master/control-plane key, and finally
    falls back to the modern control-plane key.
    """
    # An explicit env override wins and avoids an API call entirely.
    env_key = os.getenv("TOLERATION_KEY")
    if env_key:
        return env_key

    known_keys = (
        "node-role.kubernetes.io/master",
        "node-role.kubernetes.io/control-plane",
    )

    # Best-effort node scan: any API failure just falls through to the default.
    try:
        node_list = self.api_instance.list_node()
        for node in node_list.items:
            for taint in node.spec.taints or []:
                if taint.key in known_keys:
                    return taint.key
    except Exception as e:
        print(f"Warning: Could not detect master taint key: {e}")

    # Default fallback for clusters with no detectable master taint.
    return "node-role.kubernetes.io/control-plane"
189+
190+
191+
def ensure_nodes_labeled_for_flavors(self, num_flavors, with_labels):
    """
    Verify that the node labels required for ResourceFlavor targeting exist.

    Covers the default flavor label (worker-1=true) and, when more than one
    flavor is requested, the non-default one (ingress-ready=true).

    NOTE: This is a read-only check — no cluster nodes are modified. Missing
    labels only produce warnings; the test then uses whatever labels the
    cluster already has. On shared clusters, set the WORKER_LABEL and
    CONTROL_LABEL env vars to match existing labels.
    """
    if not with_labels:
        return

    worker_key, worker_val = _parse_label_env("WORKER_LABEL", "worker-1=true")
    control_key, control_val = _parse_label_env(
        "CONTROL_LABEL", "ingress-ready=true"
    )

    try:
        nodes = self.api_instance.list_node(
            label_selector="node-role.kubernetes.io/worker"
        )
        if not nodes.items:
            print("Warning: No worker nodes found")
            return

        # Only the worker label matters for a single flavor; the control
        # label is needed once multiple flavors are in play.
        required = [("WORKER_LABEL", worker_key, worker_val)]
        if num_flavors > 1:
            required.append(("CONTROL_LABEL", control_key, control_val))

        for env_var, key, val in required:
            found = False
            for node in nodes.items:
                labels = node.metadata.labels
                if labels and labels.get(key) == val:
                    found = True
                    break
            if not found:
                print(
                    f"Warning: Label {key}={val} not found (set {env_var} env var to match existing labels)"
                )

    except Exception as e:
        print(f"Warning: Could not check existing labels: {e}")
234+
235+
149236
def create_namespace(self):
150237
try:
151238
self.namespace = f"test-ns-{random_choice()}"
@@ -280,14 +367,13 @@ def create_cluster_queue(self, cluster_queue, flavor):
280367
def create_resource_flavor(
281368
self, flavor, default=True, with_labels=False, with_tolerations=False
282369
):
283-
worker_label, worker_value = os.getenv("WORKER_LABEL", "worker-1=true").split("=")
284-
control_label, control_value = os.getenv(
370+
worker_label, worker_value = _parse_label_env("WORKER_LABEL", "worker-1=true")
371+
control_label, control_value = _parse_label_env(
285372
"CONTROL_LABEL", "ingress-ready=true"
286-
).split("=")
287-
toleration_key = os.getenv(
288-
"TOLERATION_KEY", "node-role.kubernetes.io/control-plane"
289373
)
290374

375+
toleration_key = os.getenv("TOLERATION_KEY") or get_master_taint_key(self)
376+
291377
node_labels = {}
292378
if with_labels:
293379
node_labels = (
@@ -451,6 +537,25 @@ def get_nodes_by_label(self, node_labels):
451537
return [node.metadata.name for node in nodes.items]
452538

453539

540+
def get_tolerations_from_flavor(self, flavor_name):
    """
    Build V1Toleration objects from a ResourceFlavor's toleration entries.

    Returns a list of V1Toleration objects, or an empty list when the flavor
    declares no tolerations.
    """
    spec = get_flavor_spec(self, flavor_name)
    raw_tolerations = spec.get("spec", {}).get("tolerations", [])

    tolerations = []
    for entry in raw_tolerations:
        tolerations.append(
            V1Toleration(
                key=entry.get("key"),
                # Kubernetes defaults a toleration's operator to "Equal".
                operator=entry.get("operator", "Equal"),
                value=entry.get("value"),
                effect=entry.get("effect"),
            )
        )
    return tolerations
557+
558+
454559
def assert_get_cluster_and_jobsubmit(
455560
self, cluster_name, accelerator=None, number_of_gpus=None
456561
):
@@ -514,7 +619,7 @@ def wait_for_kueue_admission(self, job_api, job_name, namespace, timeout=120):
514619
workload = get_kueue_workload_for_job(self, job_name, namespace)
515620
if workload:
516621
conditions = workload.get("status", {}).get("conditions", [])
517-
print(f" DEBUG: Workload conditions for '{job_name}':")
622+
print(f"DEBUG: Workload conditions for '{job_name}':")
518623
for condition in conditions:
519624
print(
520625
f" - {condition.get('type')}: {condition.get('status')} - {condition.get('reason', '')} - {condition.get('message', '')}"

0 commit comments

Comments (0)