Skip to content

Commit 5181c05

Browse files
committed
Respect user-provided preferred instance_type in label_selector
1 parent 2da78d8 commit 5181c05

File tree

4 files changed

+123
-2
lines changed

4 files changed

+123
-2
lines changed

src/hyperpod_cli/commands/job.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -674,8 +674,12 @@ def start_job(
674674

675675
label_selector = config["cluster"]["cluster_config"].setdefault("label_selector",{})
676676
required_labels = label_selector.setdefault("required", {})
677+
preferred_labels = label_selector.setdefault("preferred", {})
677678

678-
if not required_labels.get(KUBERNETES_INSTANCE_TYPE_LABEL_KEY):
679+
if (
680+
not required_labels.get(KUBERNETES_INSTANCE_TYPE_LABEL_KEY) and
681+
not preferred_labels.get(KUBERNETES_INSTANCE_TYPE_LABEL_KEY)
682+
):
679683
required_labels[KUBERNETES_INSTANCE_TYPE_LABEL_KEY] = (
680684
[str(instance_type)]
681685
)

src/hyperpod_cli/validators/job_validator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,12 @@ def validate_yaml_content(data):
187187

188188
label_selector = cluster_config_fields.setdefault("label_selector",{})
189189
required_labels = label_selector.setdefault("required", {})
190+
preferred_labels = label_selector.setdefault("preferred", {})
190191

191-
if not required_labels.get(KUBERNETES_INSTANCE_TYPE_LABEL_KEY):
192+
if (
193+
not required_labels.get(KUBERNETES_INSTANCE_TYPE_LABEL_KEY) and
194+
not preferred_labels.get(KUBERNETES_INSTANCE_TYPE_LABEL_KEY)
195+
):
192196
required_labels[KUBERNETES_INSTANCE_TYPE_LABEL_KEY] = (
193197
[str(instance_type)]
194198
)

test/unit_tests/test_job.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,86 @@ def capture_yaml_dump(config, *args, **kwargs):
575575
print(f"Output: {result.output}")
576576
if result.exception:
577577
print(f"Exception: {result.exception}")
578+
579+
@mock.patch('subprocess.run')
580+
@mock.patch("yaml.dump")
581+
@mock.patch("os.path.exists", return_value=True)
582+
@mock.patch("os.remove", return_value=None)
583+
@mock.patch("hyperpod_cli.utils.get_cluster_console_url")
584+
@mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
585+
@mock.patch("hyperpod_cli.commands.job.JobValidator")
586+
@mock.patch("boto3.Session")
587+
def test_start_job_label_selector_preferred_instance_type(
588+
self,
589+
mock_boto3,
590+
mock_validator_cls,
591+
mock_kubernetes_client,
592+
mock_get_console_link,
593+
mock_remove,
594+
mock_exists,
595+
mock_yaml_dump,
596+
mock_subprocess_run,
597+
):
598+
# Set up mocks
599+
mock_validator = mock_validator_cls.return_value
600+
mock_validator.validate_aws_credential.return_value = True
601+
mock_kubernetes_client.get_current_context_namespace.return_value = "kubeflow"
602+
mock_get_console_link.return_value = "test-console-link"
603+
mock_subprocess_run.return_value = subprocess.CompletedProcess(
604+
args=['some_command'],
605+
returncode=0,
606+
stdout='Command executed successfully',
607+
stderr=''
608+
)
609+
610+
expected_default_label_selector_config = {
611+
"preferred": {"beta.kubernetes.io/instance-type": ["ml.c5.xlarge"]},
612+
}
613+
614+
# Capture the yaml.dump calls to inspect the config
615+
configs_dumped = []
616+
def capture_yaml_dump(config, *args, **kwargs):
617+
configs_dumped.append(config)
618+
print(f"Dumped config: {config}")
619+
return None
620+
mock_yaml_dump.side_effect = capture_yaml_dump
621+
622+
# Run the command
623+
result = self.runner.invoke(
624+
start_job,
625+
[
626+
"--job-name", "test-job",
627+
"--instance-type", "ml.c5.xlarge",
628+
"--image", "pytorch:1.9.0-cuda11.1-cudnn8-runtime",
629+
"--node-count", "2",
630+
"--entry-script", "/opt/train/src/train.py",
631+
"--label_selector",
632+
'{"preferred": {"beta.kubernetes.io/instance-type": ["ml.c5.xlarge"]}}',
633+
],
634+
catch_exceptions=False
635+
)
636+
637+
# Verify the command executed successfully
638+
self.assertEqual(result.exit_code, 0)
639+
640+
# Get the config that was generated
641+
self.assertTrue(len(configs_dumped) > 0, "No config was generated")
642+
config = configs_dumped[0] # Get the first config that was dumped
643+
644+
# Verify label_selector configuration
645+
self.assertIn('cluster', config)
646+
self.assertIn('cluster_config', config['cluster'])
647+
self.assertIn('label_selector', config['cluster']['cluster_config'])
648+
649+
self.assertEqual(
650+
config['cluster']['cluster_config']['label_selector'],
651+
expected_default_label_selector_config
652+
)
653+
654+
print(f"Exit code: {result.exit_code}")
655+
print(f"Output: {result.output}")
656+
if result.exception:
657+
print(f"Exception: {result.exception}")
578658

579659
@mock.patch('subprocess.run')
580660
@mock.patch("yaml.dump")

test/unit_tests/validators/test_job_validator.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,39 @@ def test_validate_yaml_content_valid(self):
10861086
}
10871087
result = validate_yaml_content(mock_data)
10881088
self.assertTrue(result)
1089+
1090+
def test_validate_yaml_content_preferred_instance_type_label(self):
1091+
expected_label_selector = {
1092+
"preferred": {
1093+
"beta.kubernetes.io/instance-type": [
1094+
"ml.g5.xlarge"
1095+
]
1096+
}
1097+
}
1098+
1099+
# Respect user-provided label_selector
1100+
mock_data = {
1101+
"cluster": {
1102+
"cluster_type": "k8s",
1103+
"instance_type": "ml.g5.xlarge",
1104+
"cluster_config": {
1105+
"scheduler": "SageMaker",
1106+
"label_selector": {
1107+
"preferred": {
1108+
"beta.kubernetes.io/instance-type": [
1109+
"ml.g5.xlarge"
1110+
]
1111+
}
1112+
}
1113+
},
1114+
}
1115+
}
1116+
1117+
result = validate_yaml_content(mock_data)
1118+
self.assertTrue(result)
1119+
self.assertEqual(
1120+
mock_data["cluster"]["cluster_config"]["label_selector"], expected_label_selector
1121+
)
10891122

10901123
def test_validate_yaml_content_required_instance_type_label(self):
10911124
expected_label_selector = {

0 commit comments

Comments
 (0)