Skip to content

Commit da2df2f

Browse files
authored
Fix Slurm failures from missing orchestration key (#268)
* slurm-eks-helper-fix * Small fix to test to reflect new changes
1 parent 0b1bc8f commit da2df2f

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

src/sagemaker/hyperpod/cli/commands/cluster.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,11 @@ def timeout_handler(signum, frame):
587587
sm_client = get_sagemaker_client(session, botocore_config)
588588
hp_cluster_details = sm_client.describe_cluster(ClusterName=cluster_name)
589589
logger.debug("Fetched hyperpod cluster details")
590+
591+
# Check if cluster is EKS-orchestrated
592+
if "Orchestrator" not in hp_cluster_details or "Eks" not in hp_cluster_details.get("Orchestrator", {}):
593+
raise ValueError(f"Cluster '{cluster_name}' is not EKS-orchestrated. HyperPod CLI only supports EKS-orchestrated clusters.")
594+
590595
store_current_hyperpod_context(hp_cluster_details)
591596
eks_cluster_arn = hp_cluster_details["Orchestrator"]["Eks"]["ClusterArn"]
592597
logger.debug(

src/sagemaker/hyperpod/common/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def setup_logging(logger, debug=False):
159159

160160
def is_eks_orchestrator(sagemaker_client, cluster_name: str):
161161
response = sagemaker_client.describe_cluster(ClusterName=cluster_name)
162-
return "Eks" in response["Orchestrator"]
162+
return response.get("Orchestrator", {}).get("Eks") is not None
163163

164164

165165
def update_kube_config(
@@ -250,6 +250,9 @@ def set_cluster_context(
250250

251251
client = boto3.client("sagemaker", region_name=region)
252252

253+
if not is_eks_orchestrator(client, cluster_name):
254+
raise ValueError(f"Cluster '{cluster_name}' is not EKS-orchestrated. HyperPod CLI only supports EKS-orchestrated clusters.")
255+
253256
response = client.describe_cluster(ClusterName=cluster_name)
254257
eks_cluster_arn = response["Orchestrator"]["Eks"]["ClusterArn"]
255258
eks_name = get_eks_name_from_arn(eks_cluster_arn)
@@ -300,6 +303,8 @@ def get_current_cluster():
300303
client = boto3.client("sagemaker", region_name=region)
301304

302305
for cluster_name in hyperpod_clusters:
306+
if not is_eks_orchestrator(client, cluster_name):
307+
continue
303308
response = client.describe_cluster(ClusterName=cluster_name)
304309
if response["Orchestrator"]["Eks"]["ClusterArn"] == current_context:
305310
return cluster_name

test/unit_tests/common/test_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,12 @@ def test_set_cluster_context(self, mock_set_context_func, mock_update_config, mo
389389

390390
set_cluster_context("my-cluster", "us-west-2", "test-namespace")
391391

392-
mock_client.describe_cluster.assert_called_once_with(ClusterName="my-cluster")
392+
# Expect 2 calls: one for is_eks_orchestrator validation, one for getting cluster details
393+
self.assertEqual(mock_client.describe_cluster.call_count, 2)
394+
mock_client.describe_cluster.assert_has_calls([
395+
call(ClusterName="my-cluster"),
396+
call(ClusterName="my-cluster")
397+
])
393398
mock_get_name.assert_called_once()
394399
mock_update_config.assert_called_once()
395400
mock_set_context_func.assert_called_once()

0 commit comments

Comments
 (0)