Skip to content

Commit d8db1fc

Browse files
committed
safer assignment of label
1 parent 5181c05 commit d8db1fc

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

src/hyperpod_cli/commands/job.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -673,14 +673,16 @@ def start_job(
673673
)
674674

675675
label_selector = config["cluster"]["cluster_config"].setdefault("label_selector",{})
676-
required_labels = label_selector.setdefault("required", {})
677-
preferred_labels = label_selector.setdefault("preferred", {})
676+
required_labels = label_selector.get("required", {})
677+
preferred_labels = label_selector.get("preferred", {})
678678

679679
if (
680680
not required_labels.get(KUBERNETES_INSTANCE_TYPE_LABEL_KEY) and
681681
not preferred_labels.get(KUBERNETES_INSTANCE_TYPE_LABEL_KEY)
682682
):
683-
required_labels[KUBERNETES_INSTANCE_TYPE_LABEL_KEY] = (
683+
if "required" not in label_selector:
684+
label_selector["required"] = {}
685+
label_selector["required"][KUBERNETES_INSTANCE_TYPE_LABEL_KEY] = (
684686
[str(instance_type)]
685687
)
686688

src/hyperpod_cli/validators/job_validator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,13 +186,15 @@ def validate_yaml_content(data):
186186
queue_name = custom_labels.get(KUEUE_QUEUE_NAME_LABEL_KEY, None)
187187

188188
label_selector = cluster_config_fields.setdefault("label_selector",{})
189-
required_labels = label_selector.setdefault("required", {})
190-
preferred_labels = label_selector.setdefault("preferred", {})
189+
required_labels = label_selector.get("required", {})
190+
preferred_labels = label_selector.get("preferred", {})
191191

192192
if (
193193
not required_labels.get(KUBERNETES_INSTANCE_TYPE_LABEL_KEY) and
194194
not preferred_labels.get(KUBERNETES_INSTANCE_TYPE_LABEL_KEY)
195195
):
196+
if "required" not in label_selector:
197+
label_selector["required"] = {}
196198
required_labels[KUBERNETES_INSTANCE_TYPE_LABEL_KEY] = (
197199
[str(instance_type)]
198200
)

test/unit_tests/test_job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ def capture_yaml_dump(config, *args, **kwargs):
628628
"--image", "pytorch:1.9.0-cuda11.1-cudnn8-runtime",
629629
"--node-count", "2",
630630
"--entry-script", "/opt/train/src/train.py",
631-
"--label_selector",
631+
"--label-selector",
632632
'{"preferred": {"beta.kubernetes.io/instance-type": ["ml.c5.xlarge"]}}',
633633
],
634634
catch_exceptions=False

0 commit comments

Comments
 (0)