Skip to content

Commit 4fe9a37

Browse files
committed
change: autofill required instance-type label
1 parent e97a318 commit 4fe9a37

File tree

6 files changed

+225
-1
lines changed

6 files changed

+225
-1
lines changed

src/hyperpod_cli/commands/job.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
HYPERPOD_KUBERNETES_JOB_PREFIX,
3333
HYPERPOD_MAX_RETRY_ANNOTATION_KEY,
3434
HYPERPOD_NAMESPACE_PREFIX,
35+
KUBERNETES_INSTANCE_TYPE_LABEL_KEY,
3536
KUEUE_JOB_UID_LABEL_KEY,
3637
KUEUE_QUEUE_NAME_LABEL_KEY,
3738
KUEUE_WORKLOAD_PRIORITY_CLASS_LABEL_KEY,
@@ -671,6 +672,15 @@ def start_job(
671672
NODE_AFFINITY_DICT
672673
)
673674

675+
label_selector = config["cluster"].setdefault("label_selector",{})
676+
required_labels = label_selector.setdefault("required", {})
677+
678+
if not required_labels.get(KUBERNETES_INSTANCE_TYPE_LABEL_KEY):
679+
required_labels[KUBERNETES_INSTANCE_TYPE_LABEL_KEY] = (
680+
[str(instance_type)]
681+
)
682+
683+
674684
if auto_resume:
675685
# Set max_retry default to 1
676686
if max_retry is None:

src/hyperpod_cli/constants/command_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"sagemaker.amazonaws.com/deep-health-check-status": ["Passed"],
2727
},
2828
}
29+
KUBERNETES_INSTANCE_TYPE_LABEL_KEY = "beta.kubernetes.io/instance-type"
2930
KUEUE_QUEUE_NAME_LABEL_KEY = "kueue.x-k8s.io/queue-name"
3031
KUEUE_WORKLOAD_PRIORITY_CLASS_LABEL_KEY = "kueue.x-k8s.io/priority-class"
3132
KUEUE_JOB_UID_LABEL_KEY = "kueue.x-k8s.io/job-uid"

src/hyperpod_cli/validators/job_validator.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
KUEUE_QUEUE_NAME_LABEL_KEY,
2525
HYPERPOD_AUTO_RESUME_ANNOTATION_KEY,
2626
HYPERPOD_MAX_RETRY_ANNOTATION_KEY,
27+
KUBERNETES_INSTANCE_TYPE_LABEL_KEY,
2728
SchedulerType
2829
)
2930
from hyperpod_cli.constants.hyperpod_instance_types import (
@@ -184,6 +185,19 @@ def validate_yaml_content(data):
184185
if custom_labels is not None:
185186
queue_name = custom_labels.get(KUEUE_QUEUE_NAME_LABEL_KEY, None)
186187

188+
label_selector = cluster_config_fields.setdefault("label_selector",{})
189+
required_labels = label_selector.setdefault("required", {})
190+
191+
if not required_labels.get(KUBERNETES_INSTANCE_TYPE_LABEL_KEY):
192+
required_labels[KUBERNETES_INSTANCE_TYPE_LABEL_KEY] = (
193+
[str(instance_type)]
194+
)
195+
if instance_type not in required_labels[KUBERNETES_INSTANCE_TYPE_LABEL_KEY]:
196+
logger.error(
197+
f"Please ensure 'instance-type' in 'cluster' matches with 'instance-type' in 'label_selector.required.beta.kubernetes.io/instance-type' in config file"
198+
)
199+
return False
200+
187201
auto_resume = False
188202
max_retry = None
189203
if annotations is not None:

test/integration_tests/data/basicJob.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ training_cfg:
2222
entry_script: /opt/pytorch-mnist/mnist.py
2323
script_args: []
2424
run:
25-
name: ${JOB_NAME} # Current run name
25+
name: hyperpod-benieric-test # Current run name
2626
nodes: 1 # Number of nodes to use for current training
2727
ntasks_per_node: 1 # Number of devices to use per node
2828
cluster:
@@ -40,6 +40,8 @@ cluster:
4040
required:
4141
sagemaker.amazonaws.com/node-health-status:
4242
- Schedulable
43+
beta.kubernetes.io/instance-type:
44+
- ml.c5.4xlarge
4345
preferred:
4446
sagemaker.amazonaws.com/deep-health-check-status:
4547
- Passed

test/unit_tests/test_job.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,90 @@ def test_start_job_with_cli_args(
493493
print(f"Exception: {result.exception}")
494494
self.assertEqual(result.exit_code, 0)
495495

496+
@mock.patch('subprocess.run')
497+
@mock.patch("yaml.dump")
498+
@mock.patch("os.path.exists", return_value=True)
499+
@mock.patch("os.remove", return_value=None)
500+
@mock.patch("hyperpod_cli.utils.get_cluster_console_url")
501+
@mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
502+
@mock.patch("hyperpod_cli.commands.job.JobValidator")
503+
@mock.patch("boto3.Session")
504+
def test_start_job_default_label_selector_config(
505+
self,
506+
mock_boto3,
507+
mock_validator_cls,
508+
mock_kubernetes_client,
509+
mock_get_console_link,
510+
mock_remove,
511+
mock_exists,
512+
mock_yaml_dump,
513+
mock_subprocess_run,
514+
):
515+
# Setup mocks
516+
mock_validator = mock_validator_cls.return_value
517+
mock_validator.validate_aws_credential.return_value = True
518+
mock_kubernetes_client.get_current_context_namespace.return_value = "kubeflow"
519+
mock_get_console_link.return_value = "test-console-link"
520+
mock_subprocess_run.return_value = subprocess.CompletedProcess(
521+
args=['some_command'],
522+
returncode=0,
523+
stdout='Command executed successfully',
524+
stderr=''
525+
)
526+
527+
expected_default_label_selector_config = {
528+
"required": {
529+
"sagemaker.amazonaws.com/node-health-status": ["Schedulable"],
530+
"beta.kubernetes.io/instance-type": ["ml.c5.xlarge"]
531+
},
532+
"preferred": {"sagemaker.amazonaws.com/deep-health-check-status": ["Passed"]},
533+
"weights": [100],
534+
}
535+
536+
# Capture the yaml.dump calls to inspect the config
537+
configs_dumped = []
538+
def capture_yaml_dump(config, *args, **kwargs):
539+
configs_dumped.append(config)
540+
print(f"Dumped config: {config}")
541+
return None
542+
mock_yaml_dump.side_effect = capture_yaml_dump
543+
544+
# Run the command
545+
result = self.runner.invoke(
546+
start_job,
547+
[
548+
"--job-name", "test-job",
549+
"--instance-type", "ml.c5.xlarge",
550+
"--image", "pytorch:1.9.0-cuda11.1-cudnn8-runtime",
551+
"--node-count", "2",
552+
"--entry-script", "/opt/train/src/train.py",
553+
],
554+
catch_exceptions=False
555+
)
556+
557+
# Verify the command executed successfully
558+
self.assertEqual(result.exit_code, 0)
559+
560+
# Get the config that was generated
561+
self.assertTrue(len(configs_dumped) > 0, "No config was generated")
562+
config = configs_dumped[0] # Get the first config that was dumped
563+
564+
# Verify label_selector configuration
565+
self.assertIn('cluster', config)
566+
self.assertIn('cluster_config', config['cluster'])
567+
self.assertIn('label_selector', config['cluster']['cluster_config'])
568+
569+
self.assertEqual(
570+
config['cluster']['cluster_config']['label_selector'],
571+
expected_default_label_selector_config
572+
)
573+
574+
print(f"Exit code: {result.exit_code}")
575+
print(f"Output: {result.output}")
576+
if result.exception:
577+
print(f"Exception: {result.exception}")
578+
579+
496580
@mock.patch('subprocess.run')
497581
@mock.patch("yaml.dump")
498582
@mock.patch("os.path.exists", return_value=True)

test/unit_tests/validators/test_job_validator.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,119 @@ def test_validate_yaml_content_valid(self):
10861086
}
10871087
result = validate_yaml_content(mock_data)
10881088
self.assertTrue(result)
1089+
1090+
def test_validate_yaml_content_valid_required_instance_type_label(self):
1091+
# Test auto-population of the required label selector "beta.kubernetes.io/instance-type"
1092+
expected_label_selector = {
1093+
"required": {
1094+
"beta.kubernetes.io/instance-type": [
1095+
"ml.g5.xlarge"
1096+
]
1097+
}
1098+
}
1099+
1100+
mock_data = {
1101+
"cluster": {
1102+
"cluster_type": "k8s",
1103+
"instance_type": "ml.g5.xlarge",
1104+
"cluster_config": {
1105+
"scheduler": "SageMaker"
1106+
},
1107+
}
1108+
}
1109+
1110+
result = validate_yaml_content(mock_data)
1111+
self.assertTrue(result)
1112+
self.assertEqual(
1113+
mock_data["cluster"]["cluster_config"]["label_selector"], expected_label_selector
1114+
)
1115+
1116+
expected_label_selector = {
1117+
"required": {
1118+
"sagemaker.amazonaws.com/node-health-status": [
1119+
"Schedulable"
1120+
],
1121+
"beta.kubernetes.io/instance-type": [
1122+
"ml.g5.xlarge"
1123+
]
1124+
}
1125+
}
1126+
1127+
mock_data = {
1128+
"cluster": {
1129+
"cluster_type": "k8s",
1130+
"instance_type": "ml.g5.xlarge",
1131+
"cluster_config": {
1132+
"scheduler": "SageMaker",
1133+
"label_selector": {
1134+
"required": {
1135+
"sagemaker.amazonaws.com/node-health-status": [
1136+
"Schedulable"
1137+
]
1138+
}
1139+
}
1140+
},
1141+
}
1142+
}
1143+
1144+
result = validate_yaml_content(mock_data)
1145+
self.assertTrue(result)
1146+
self.assertEqual(
1147+
mock_data["cluster"]["cluster_config"]["label_selector"], expected_label_selector
1148+
)
1149+
1150+
expected_label_selector = {
1151+
"required": {
1152+
"beta.kubernetes.io/instance-type": [
1153+
"ml.g5.xlarge"
1154+
]
1155+
}
1156+
}
1157+
1158+
mock_data = {
1159+
"cluster": {
1160+
"cluster_type": "k8s",
1161+
"instance_type": "ml.g5.xlarge",
1162+
"cluster_config": {
1163+
"scheduler": "SageMaker",
1164+
"label_selector": {
1165+
"required": {
1166+
"beta.kubernetes.io/instance-type": [
1167+
"ml.g5.xlarge"
1168+
]
1169+
}
1170+
}
1171+
}
1172+
}
1173+
}
1174+
1175+
result = validate_yaml_content(mock_data)
1176+
self.assertTrue(result)
1177+
self.assertEqual(
1178+
mock_data["cluster"]["cluster_config"]["label_selector"], expected_label_selector
1179+
)
1180+
1181+
def test_validate_yaml_content_invalid_required_instance_type_label(self):
1182+
# Test respecting the user's label selector: an instance-type label that does not match instance_type is rejected
1183+
mock_data = {
1184+
"cluster": {
1185+
"cluster_type": "k8s",
1186+
"instance_type": "ml.g5.xlarge",
1187+
"cluster_config": {
1188+
"scheduler": "SageMaker",
1189+
"label_selector": {
1190+
"required": {
1191+
"beta.kubernetes.io/instance-type": [
1192+
"ml.g5.2xlarge"
1193+
]
1194+
}
1195+
}
1196+
}
1197+
}
1198+
}
1199+
1200+
result = validate_yaml_content(mock_data)
1201+
self.assertFalse(result)
10891202

10901203
def test_validate_yaml_content_error_no_cluster(
10911204
self,

0 commit comments

Comments
 (0)