Skip to content

Commit b9cb87e

Browse files
Add provision to provide labels and annotations for the pytorchjob an… (kubeflow#2612)
Signed-off-by: Abhijeet Dhumal <abhijeetdhumal652@gmail.com>
1 parent f58e893 commit b9cb87e

File tree

3 files changed: 88 additions and 5 deletions.

sdk/python/kubeflow/training/api/training_client.py

Lines changed: 30 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -96,6 +96,8 @@ def train(
9696
self,
9797
name: str,
9898
namespace: Optional[str] = None,
99+
labels: Optional[Dict[str, str]] = None,
100+
annotations: Optional[Dict[str, str]] = None,
99101
num_workers: int = 1,
100102
num_procs_per_worker: int = 1,
101103
resources_per_worker: Union[dict, client.V1ResourceRequirements, None] = None,
@@ -133,6 +135,10 @@ def train(
133135
name: Name of the PyTorchJob.
134136
namespace: Namespace for the PyTorchJob. By default namespace is taken from
135137
`TrainingClient` object.
138+
labels: Labels for the PyTorchJob. You can specify a dictionary as a mapping object
139+
representing the labels.
140+
annotations: Annotations for the PyTorchJob. You can specify a dictionary as a
141+
mapping object representing the annotations.
136142
num_workers: Number of PyTorchJob workers.
137143
num_procs_per_worker: Number of processes per PyTorchJob worker for `torchrun` CLI. You
138144
should use this parameter if you want to use more than 1 GPU per PyTorchJob worker.
@@ -327,6 +333,8 @@ def train(
327333
job = utils.get_pytorchjob_template(
328334
name=name,
329335
namespace=namespace,
336+
labels=labels,
337+
annotations=annotations,
330338
master_pod_template_spec=master_pod_template_spec,
331339
worker_pod_template_spec=worker_pod_template_spec,
332340
num_workers=num_workers,
@@ -340,6 +348,8 @@ def create_job(
340348
job: Optional[constants.JOB_MODELS_TYPE] = None,
341349
name: Optional[str] = None,
342350
namespace: Optional[str] = None,
351+
labels: Optional[Dict[str, str]] = None,
352+
annotations: Optional[Dict[str, str]] = None,
343353
job_kind: Optional[str] = None,
344354
base_image: Optional[str] = None,
345355
train_func: Optional[Callable] = None,
@@ -370,6 +380,10 @@ def create_job(
370380
name: Name for the Job. It must be set if `job` parameter is omitted.
371381
namespace: Namespace for the Job. By default namespace is taken from
372382
`TrainingClient` object.
383+
labels: Labels for the Job. You can specify a dictionary as a mapping
384+
object representing the labels.
385+
annotations: Annotations for the Job. You can specify a dictionary as
386+
a mapping object representing the annotations.
373387
job_kind: Kind for the Job (e.g. `TFJob` or `PyTorchJob`). It must be set if
374388
`job` parameter is omitted. By default Job kind is taken from
375389
`TrainingClient` object.
@@ -429,16 +443,25 @@ def create_job(
429443
RuntimeError: Failed to create Job.
430444
"""
431445

432-
# When Job is set, only namespace arg is allowed.
446+
# When Job is set, only namespace, labels, and annotations args are allowed.
433447
if job is not None:
434448
for key, value in locals().items():
435449
if (
436450
key
437-
not in ["self", "job", "namespace", "pip_index_url", "num_workers"]
451+
not in [
452+
"self",
453+
"job",
454+
"namespace",
455+
"pip_index_url",
456+
"num_workers",
457+
"labels",
458+
"annotations",
459+
]
438460
and value is not None
439461
):
440462
raise ValueError(
441-
"If `job` is set only `namespace` argument is allowed. "
463+
"If `job` is set only `namespace`, `labels`, and `annotations` "
464+
"arguments are allowed. "
442465
f"Argument `{key}` must be None."
443466
)
444467

@@ -525,6 +548,8 @@ def create_job(
525548
job = utils.get_tfjob_template(
526549
name=name,
527550
namespace=namespace,
551+
labels=labels,
552+
annotations=annotations,
528553
pod_template_spec=pod_template_spec,
529554
num_workers=num_workers,
530555
num_chief_replicas=num_chief_replicas,
@@ -539,6 +564,8 @@ def create_job(
539564
job = utils.get_pytorchjob_template(
540565
name=name,
541566
namespace=namespace,
567+
labels=labels,
568+
annotations=annotations,
542569
worker_pod_template_spec=pod_template_spec,
543570
num_workers=num_workers,
544571
num_procs_per_worker=num_procs_per_worker,

sdk/python/kubeflow/training/api/training_client_test.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ def create_job(
143143
command=None,
144144
args=None,
145145
num_workers=2,
146+
labels=None,
147+
annotations=None,
146148
env_vars=None,
147149
volumes=None,
148150
volume_mounts=None,
@@ -195,10 +197,16 @@ def create_job(
195197
),
196198
)
197199

200+
meta_kwargs = {"name": TEST_NAME, "namespace": TEST_NAME}
201+
if labels is not None:
202+
meta_kwargs["labels"] = labels
203+
if annotations is not None:
204+
meta_kwargs["annotations"] = annotations
205+
198206
pytorchjob = KubeflowOrgV1PyTorchJob(
199207
api_version=constants.API_VERSION,
200208
kind=constants.PYTORCHJOB_KIND,
201-
metadata=V1ObjectMeta(name=TEST_NAME, namespace=TEST_NAME),
209+
metadata=V1ObjectMeta(**meta_kwargs),
202210
spec=KubeflowOrgV1PyTorchJobSpec(
203211
run_policy=KubeflowOrgV1RunPolicy(clean_pod_policy=None),
204212
pytorch_replica_specs=pytorch_replica_specs,
@@ -570,6 +578,38 @@ def __init__(self):
570578
ValueError,
571579
None,
572580
),
581+
(
582+
"valid flow with labels",
583+
{
584+
"name": TEST_NAME,
585+
"namespace": TEST_NAME,
586+
"base_image": TEST_IMAGE,
587+
"num_workers": 1,
588+
"labels": {"upstream": "kubeflow", "component": "training"},
589+
},
590+
SUCCESS,
591+
create_job(
592+
num_workers=1,
593+
labels={"upstream": "kubeflow", "component": "training"},
594+
annotations=None,
595+
),
596+
),
597+
(
598+
"valid flow with annotations",
599+
{
600+
"name": TEST_NAME,
601+
"namespace": TEST_NAME,
602+
"base_image": TEST_IMAGE,
603+
"num_workers": 1,
604+
"annotations": {"purpose": "unit-test", "env": "test"},
605+
},
606+
SUCCESS,
607+
create_job(
608+
num_workers=1,
609+
labels=None,
610+
annotations={"purpose": "unit-test", "env": "test"},
611+
),
612+
),
573613
]
574614

575615
test_data_get_job_pods = [

sdk/python/kubeflow/training/utils/utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,8 @@ def get_tfjob_template(
273273
namespace: str,
274274
pod_template_spec: models.V1PodTemplateSpec,
275275
num_workers: int,
276+
labels: Optional[Dict[str, str]] = None,
277+
annotations: Optional[Dict[str, str]] = None,
276278
num_chief_replicas: Optional[int] = None,
277279
num_ps_replicas: Optional[int] = None,
278280
):
@@ -297,6 +299,12 @@ def get_tfjob_template(
297299
)
298300
)
299301

302+
if labels is not None:
303+
tfjob.metadata.labels = labels
304+
305+
if annotations is not None:
306+
tfjob.metadata.annotations = annotations
307+
300308
if num_ps_replicas is not None:
301309
tfjob.spec.tf_replica_specs[constants.REPLICA_TYPE_PS] = (
302310
models.KubeflowOrgV1ReplicaSpec(
@@ -321,6 +329,8 @@ def get_pytorchjob_template(
321329
namespace: str,
322330
num_workers: int,
323331
worker_pod_template_spec: Optional[models.V1PodTemplateSpec],
332+
labels: Optional[Dict[str, str]] = None,
333+
annotations: Optional[Dict[str, str]] = None,
324334
master_pod_template_spec: Optional[models.V1PodTemplateSpec] = None,
325335
num_procs_per_worker: Optional[Union[int, str]] = None,
326336
):
@@ -336,7 +346,13 @@ def get_pytorchjob_template(
336346
),
337347
)
338348

339-
if num_procs_per_worker:
349+
if labels is not None:
350+
pytorchjob.metadata.labels = labels
351+
352+
if annotations is not None:
353+
pytorchjob.metadata.annotations = annotations
354+
355+
if num_procs_per_worker is not None:
340356
pytorchjob.spec.nproc_per_node = str(num_procs_per_worker)
341357

342358
# Create Master replica if that is set.

0 commit comments

Comments
 (0)