Skip to content

Commit 2696940

Browse files
authored
change: autofill required instance-type label (#68)
* change: autofill required instance-type label
* Add validation for CLI input
* Respect user provided label_selector preferred instance_type
* safer assignment of label
* fix assignment in validator
* use node.* prefix instead of beta.* prefix in label
* Update README.md
1 parent e97a318 commit 2696940

File tree

5 files changed

+290
-2
lines changed

5 files changed

+290
-2
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -147,7 +147,7 @@ hyperpod start-job --job-name <job-name> [--namespace <namespace>] [--job-kind <
147147
* `script-args` (list[string]) - Optional. The list of arguments for entry scripts.
148148
* `environment` (dict[string, string]) - Optional. The environment variables (key-value pairs) to set in the containers.
149149
* `node-count` (int) - Required. The number of nodes (instances) to launch the jobs on.
150-
* `instance-type` (string) - Required. The instance type to launch the job on. Note that the instance types you can use are the available instances within your SageMaker quotas for instances prefixed with `ml`.
150+
* `instance-type` (string) - Required. The instance type to launch the job on. Note that the instance types you can use are the available instances within your SageMaker quotas for instances prefixed with `ml`. If `node.kubernetes.io/instance-type` is provided via the `label-selector` it will take precedence for node selection.
151151
* `tasks-per-node` (int) - Optional. The number of devices to use per instance.
152152
* `label-selector` (dict[string, list[string]]) - Optional. A dictionary of labels and their values that will override the predefined node selection rules based on the SageMaker HyperPod `node-health-status` label and values. If users provide this field, the CLI will launch the job with this customized label selection.
153153
* `deep-health-check-passed-nodes-only` (bool) - Optional. If set to `true`, the job will be launched only on nodes that have the `deep-health-check-status` label with the value `passed`.

src/hyperpod_cli/commands/job.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
HYPERPOD_KUBERNETES_JOB_PREFIX,
3333
HYPERPOD_MAX_RETRY_ANNOTATION_KEY,
3434
HYPERPOD_NAMESPACE_PREFIX,
35+
INSTANCE_TYPE_LABEL,
3536
KUEUE_JOB_UID_LABEL_KEY,
3637
KUEUE_QUEUE_NAME_LABEL_KEY,
3738
KUEUE_WORKLOAD_PRIORITY_CLASS_LABEL_KEY,
@@ -661,7 +662,7 @@ def start_job(
661662
config["cluster"]["cluster_config"]["volumes"] = volume_mount
662663

663664
if label_selector is not None:
664-
config["cluster"]["cluster_config"]["label_selector"] = label_selector
665+
config["cluster"]["cluster_config"]["label_selector"] = json.loads(label_selector)
665666
elif deep_health_check_passed_nodes_only:
666667
config["cluster"]["cluster_config"]["label_selector"] = (
667668
DEEP_HEALTH_CHECK_PASSED_ONLY_NODE_AFFINITY_DICT
@@ -671,6 +672,20 @@ def start_job(
671672
NODE_AFFINITY_DICT
672673
)
673674

675+
label_selector = config["cluster"]["cluster_config"].setdefault("label_selector",{})
676+
required_labels = label_selector.get("required", {})
677+
preferred_labels = label_selector.get("preferred", {})
678+
679+
if (
680+
not required_labels.get(INSTANCE_TYPE_LABEL) and
681+
not preferred_labels.get(INSTANCE_TYPE_LABEL)
682+
):
683+
if "required" not in label_selector:
684+
label_selector["required"] = {}
685+
label_selector["required"][INSTANCE_TYPE_LABEL] = (
686+
[str(instance_type)]
687+
)
688+
674689
if auto_resume:
675690
# Set max_retry default to 1
676691
if max_retry is None:

src/hyperpod_cli/validators/job_validator.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
KUEUE_QUEUE_NAME_LABEL_KEY,
2525
HYPERPOD_AUTO_RESUME_ANNOTATION_KEY,
2626
HYPERPOD_MAX_RETRY_ANNOTATION_KEY,
27+
INSTANCE_TYPE_LABEL,
2728
SchedulerType
2829
)
2930
from hyperpod_cli.constants.hyperpod_instance_types import (
@@ -184,6 +185,20 @@ def validate_yaml_content(data):
184185
if custom_labels is not None:
185186
queue_name = custom_labels.get(KUEUE_QUEUE_NAME_LABEL_KEY, None)
186187

188+
label_selector = cluster_config_fields.setdefault("label_selector",{})
189+
required_labels = label_selector.get("required", {})
190+
preferred_labels = label_selector.get("preferred", {})
191+
192+
if (
193+
not required_labels.get(INSTANCE_TYPE_LABEL) and
194+
not preferred_labels.get(INSTANCE_TYPE_LABEL)
195+
):
196+
if "required" not in label_selector:
197+
label_selector["required"] = {}
198+
label_selector["required"][INSTANCE_TYPE_LABEL] = (
199+
[str(instance_type)]
200+
)
201+
187202
auto_resume = False
188203
max_retry = None
189204
if annotations is not None:

test/unit_tests/test_job.py

Lines changed: 163 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -493,6 +493,169 @@ def test_start_job_with_cli_args(
493493
print(f"Exception: {result.exception}")
494494
self.assertEqual(result.exit_code, 0)
495495

496+
@mock.patch('subprocess.run')
497+
@mock.patch("yaml.dump")
498+
@mock.patch("os.path.exists", return_value=True)
499+
@mock.patch("os.remove", return_value=None)
500+
@mock.patch("hyperpod_cli.utils.get_cluster_console_url")
501+
@mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
502+
@mock.patch("hyperpod_cli.commands.job.JobValidator")
503+
@mock.patch("boto3.Session")
504+
def test_start_job_default_label_selector_config(
505+
self,
506+
mock_boto3,
507+
mock_validator_cls,
508+
mock_kubernetes_client,
509+
mock_get_console_link,
510+
mock_remove,
511+
mock_exists,
512+
mock_yaml_dump,
513+
mock_subprocess_run,
514+
):
515+
# Setup mocks
516+
mock_validator = mock_validator_cls.return_value
517+
mock_validator.validate_aws_credential.return_value = True
518+
mock_kubernetes_client.get_current_context_namespace.return_value = "kubeflow"
519+
mock_get_console_link.return_value = "test-console-link"
520+
mock_subprocess_run.return_value = subprocess.CompletedProcess(
521+
args=['some_command'],
522+
returncode=0,
523+
stdout='Command executed successfully',
524+
stderr=''
525+
)
526+
527+
expected_default_label_selector_config = {
528+
"required": {
529+
"sagemaker.amazonaws.com/node-health-status": ["Schedulable"],
530+
"node.kubernetes.io/instance-type": ["ml.c5.xlarge"]
531+
},
532+
"preferred": {"sagemaker.amazonaws.com/deep-health-check-status": ["Passed"]},
533+
"weights": [100],
534+
}
535+
536+
# Capture the yaml.dump calls to inspect the config
537+
configs_dumped = []
538+
def capture_yaml_dump(config, *args, **kwargs):
539+
configs_dumped.append(config)
540+
print(f"Dumped config: {config}")
541+
return None
542+
mock_yaml_dump.side_effect = capture_yaml_dump
543+
544+
# Run the command
545+
result = self.runner.invoke(
546+
start_job,
547+
[
548+
"--job-name", "test-job",
549+
"--instance-type", "ml.c5.xlarge",
550+
"--image", "pytorch:1.9.0-cuda11.1-cudnn8-runtime",
551+
"--node-count", "2",
552+
"--entry-script", "/opt/train/src/train.py",
553+
],
554+
catch_exceptions=False
555+
)
556+
557+
# Verify the command executed successfully
558+
self.assertEqual(result.exit_code, 0)
559+
560+
# Get the config that was generated
561+
self.assertTrue(len(configs_dumped) > 0, "No config was generated")
562+
config = configs_dumped[0] # Get the first config that was dumped
563+
564+
# Verify label_selector configuration
565+
self.assertIn('cluster', config)
566+
self.assertIn('cluster_config', config['cluster'])
567+
self.assertIn('label_selector', config['cluster']['cluster_config'])
568+
569+
self.assertEqual(
570+
config['cluster']['cluster_config']['label_selector'],
571+
expected_default_label_selector_config
572+
)
573+
574+
print(f"Exit code: {result.exit_code}")
575+
print(f"Output: {result.output}")
576+
if result.exception:
577+
print(f"Exception: {result.exception}")
578+
579+
@mock.patch('subprocess.run')
580+
@mock.patch("yaml.dump")
581+
@mock.patch("os.path.exists", return_value=True)
582+
@mock.patch("os.remove", return_value=None)
583+
@mock.patch("hyperpod_cli.utils.get_cluster_console_url")
584+
@mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
585+
@mock.patch("hyperpod_cli.commands.job.JobValidator")
586+
@mock.patch("boto3.Session")
587+
def test_start_job_label_selector_preferred_instance_type(
588+
self,
589+
mock_boto3,
590+
mock_validator_cls,
591+
mock_kubernetes_client,
592+
mock_get_console_link,
593+
mock_remove,
594+
mock_exists,
595+
mock_yaml_dump,
596+
mock_subprocess_run,
597+
):
598+
# Setup mocks
599+
mock_validator = mock_validator_cls.return_value
600+
mock_validator.validate_aws_credential.return_value = True
601+
mock_kubernetes_client.get_current_context_namespace.return_value = "kubeflow"
602+
mock_get_console_link.return_value = "test-console-link"
603+
mock_subprocess_run.return_value = subprocess.CompletedProcess(
604+
args=['some_command'],
605+
returncode=0,
606+
stdout='Command executed successfully',
607+
stderr=''
608+
)
609+
610+
expected_default_label_selector_config = {
611+
"preferred": {"node.kubernetes.io/instance-type": ["ml.c5.xlarge"]},
612+
}
613+
614+
# Capture the yaml.dump calls to inspect the config
615+
configs_dumped = []
616+
def capture_yaml_dump(config, *args, **kwargs):
617+
configs_dumped.append(config)
618+
print(f"Dumped config: {config}")
619+
return None
620+
mock_yaml_dump.side_effect = capture_yaml_dump
621+
622+
# Run the command
623+
result = self.runner.invoke(
624+
start_job,
625+
[
626+
"--job-name", "test-job",
627+
"--instance-type", "ml.c5.xlarge",
628+
"--image", "pytorch:1.9.0-cuda11.1-cudnn8-runtime",
629+
"--node-count", "2",
630+
"--entry-script", "/opt/train/src/train.py",
631+
"--label-selector",
632+
'{"preferred": {"node.kubernetes.io/instance-type": ["ml.c5.xlarge"]}}',
633+
],
634+
catch_exceptions=False
635+
)
636+
637+
# Verify the command executed successfully
638+
self.assertEqual(result.exit_code, 0)
639+
640+
# Get the config that was generated
641+
self.assertTrue(len(configs_dumped) > 0, "No config was generated")
642+
config = configs_dumped[0] # Get the first config that was dumped
643+
644+
# Verify label_selector configuration
645+
self.assertIn('cluster', config)
646+
self.assertIn('cluster_config', config['cluster'])
647+
self.assertIn('label_selector', config['cluster']['cluster_config'])
648+
649+
self.assertEqual(
650+
config['cluster']['cluster_config']['label_selector'],
651+
expected_default_label_selector_config
652+
)
653+
654+
print(f"Exit code: {result.exit_code}")
655+
print(f"Output: {result.output}")
656+
if result.exception:
657+
print(f"Exception: {result.exception}")
658+
496659
@mock.patch('subprocess.run')
497660
@mock.patch("yaml.dump")
498661
@mock.patch("os.path.exists", return_value=True)

test/unit_tests/validators/test_job_validator.py

Lines changed: 95 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1086,6 +1086,101 @@ def test_validate_yaml_content_valid(self):
10861086
}
10871087
result = validate_yaml_content(mock_data)
10881088
self.assertTrue(result)
1089+
1090+
def test_validate_yaml_content_preferred_instance_type_label(self):
1091+
expected_label_selector = {
1092+
"preferred": {
1093+
"node.kubernetes.io/instance-type": [
1094+
"ml.g5.xlarge"
1095+
]
1096+
}
1097+
}
1098+
1099+
# Respect user provided label_selector
1100+
mock_data = {
1101+
"cluster": {
1102+
"cluster_type": "k8s",
1103+
"instance_type": "ml.g5.xlarge",
1104+
"cluster_config": {
1105+
"scheduler": "SageMaker",
1106+
"label_selector": {
1107+
"preferred": {
1108+
"node.kubernetes.io/instance-type": [
1109+
"ml.g5.xlarge"
1110+
]
1111+
}
1112+
}
1113+
},
1114+
}
1115+
}
1116+
1117+
result = validate_yaml_content(mock_data)
1118+
self.assertTrue(result)
1119+
self.assertEqual(
1120+
mock_data["cluster"]["cluster_config"]["label_selector"], expected_label_selector
1121+
)
1122+
1123+
def test_validate_yaml_content_required_instance_type_label(self):
1124+
expected_label_selector = {
1125+
"required": {
1126+
"node.kubernetes.io/instance-type": [
1127+
"ml.g5.xlarge"
1128+
]
1129+
}
1130+
}
1131+
1132+
# User does not provide label_selector
1133+
mock_data = {
1134+
"cluster": {
1135+
"cluster_type": "k8s",
1136+
"instance_type": "ml.g5.xlarge",
1137+
"cluster_config": {
1138+
"scheduler": "SageMaker"
1139+
},
1140+
}
1141+
}
1142+
1143+
result = validate_yaml_content(mock_data)
1144+
self.assertTrue(result)
1145+
self.assertEqual(
1146+
mock_data["cluster"]["cluster_config"]["label_selector"], expected_label_selector
1147+
)
1148+
1149+
expected_label_selector = {
1150+
"required": {
1151+
"sagemaker.amazonaws.com/node-health-status": [
1152+
"Schedulable"
1153+
],
1154+
"node.kubernetes.io/instance-type": [
1155+
"ml.g5.xlarge"
1156+
]
1157+
}
1158+
}
1159+
1160+
# User provides label_selector without instance_type
1161+
mock_data = {
1162+
"cluster": {
1163+
"cluster_type": "k8s",
1164+
"instance_type": "ml.g5.xlarge",
1165+
"cluster_config": {
1166+
"scheduler": "SageMaker",
1167+
"label_selector": {
1168+
"required": {
1169+
"sagemaker.amazonaws.com/node-health-status": [
1170+
"Schedulable"
1171+
]
1172+
}
1173+
}
1174+
},
1175+
}
1176+
}
1177+
1178+
result = validate_yaml_content(mock_data)
1179+
self.assertTrue(result)
1180+
self.assertEqual(
1181+
mock_data["cluster"]["cluster_config"]["label_selector"], expected_label_selector
1182+
)
1183+
10891184

10901185
def test_validate_yaml_content_error_no_cluster(
10911186
self,

0 commit comments

Comments (0)