Skip to content

Commit e9891d2

Browse files
authored
Merge branch 'main' into feat/extend_schedulers_list
2 parents 36e77f7 + 7abb9a2 commit e9891d2

File tree

7 files changed

+575
-169
lines changed

7 files changed

+575
-169
lines changed

setup.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@
1414

1515

1616
def get_version():
    """Return the package version string declared in ``torchx/_version.py``.

    The version file is read as text (it is not imported) and the value
    assigned to ``BASE_VERSION`` is extracted with a regular expression.

    Returns:
        The string between the quotes of the ``BASE_VERSION = "..."`` line.

    Raises:
        RuntimeError: if the file contains no ``BASE_VERSION`` assignment.
        OSError: if the version file cannot be opened.
    """
    # get version string from _version.py
    version_file = os.path.join(os.path.dirname(__file__), "torchx/_version.py")
    version_regex = r"BASE_VERSION = ['\"]([^'\"]*)['\"]"
    # Explicit encoding so the result does not depend on the build machine's locale.
    with open(version_file, "r", encoding="utf-8") as f:
        match = re.search(version_regex, f.read(), re.M)
    if match is None:
        # Fail with an actionable message instead of an AttributeError on None.
        raise RuntimeError(f"Unable to find BASE_VERSION string in {version_file}")
    return match.group(1)

torchx/_version.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
BASE_VERSION = "0.8.0dev0"

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 157 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,76 @@
2727
See the
2828
`Volcano Quickstart <https://github.com/volcano-sh/volcano>`_
2929
for more information.
30+
31+
Pod Overlay
32+
===========
33+
34+
You can overlay arbitrary Kubernetes Pod fields on generated pods by setting
35+
the ``kubernetes`` metadata on your role. The value can be:
36+
37+
- A dict with the overlay structure
38+
- A resource URI pointing to a YAML file (e.g. ``file://``, ``s3://``, ``gs://``)
39+
40+
Merge semantics:

- **dict**: recursive merge (upsert)
42+
- **list**: append by default, replace if tuple (Python) or ``!!python/tuple`` tag (YAML)
43+
- **primitives**: replace
44+
45+
.. code:: python
46+
47+
from torchx.specs import Role
48+
49+
# Dict overlay - lists append, tuples replace
50+
role = Role(
51+
name="trainer",
52+
image="my-image:latest",
53+
entrypoint="train.py",
54+
metadata={
55+
"kubernetes": {
56+
"spec": {
57+
"nodeSelector": {"gpu": "true"},
58+
"tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}], # appends
59+
"volumes": ({"name": "my-volume", "emptyDir": {}},) # replaces
60+
}
61+
}
62+
}
63+
)
64+
65+
# File URI overlay
66+
role = Role(
67+
name="trainer",
68+
image="my-image:latest",
69+
entrypoint="train.py",
70+
metadata={
71+
"kubernetes": "file:///path/to/pod_overlay.yaml"
72+
}
73+
)
74+
75+
CLI usage with builtin components:
76+
77+
.. code:: bash
78+
79+
$ torchx run --scheduler kubernetes dist.ddp \\
80+
--metadata kubernetes=file:///path/to/pod_overlay.yaml \\
81+
--script train.py
82+
83+
Example ``pod_overlay.yaml``:
84+
85+
.. code:: yaml
86+
87+
spec:
88+
nodeSelector:
89+
node.kubernetes.io/instance-type: p4d.24xlarge
90+
tolerations:
91+
- key: nvidia.com/gpu
92+
operator: Exists
93+
effect: NoSchedule
94+
volumes: !!python/tuple
95+
- name: my-volume
96+
emptyDir: {}
97+
98+
The overlay is deep-merged with the generated pod, preserving existing fields
99+
and adding or overriding specified ones.
30100
"""
31101

32102
import json
@@ -45,6 +115,7 @@
45115
Tuple,
46116
TYPE_CHECKING,
47117
TypedDict,
118+
Union,
48119
)
49120

50121
import torchx
@@ -97,6 +168,40 @@
97168
RESERVED_MILLICPU = 100
98169
RESERVED_MEMMB = 1024
99170

171+
172+
def _apply_pod_overlay(pod: "V1Pod", overlay: Dict[str, Any]) -> None:
    """Apply an overlay dict to a V1Pod object in place, merging nested fields.

    The pod is serialized to a plain dict, the overlay is merged into that
    dict, and the result is deserialized back into a ``V1Pod`` whose ``spec``
    and ``metadata`` replace those of ``pod``.

    Merge semantics:

    - dict: upsert (recursive merge)
    - list: append by default, replace if tuple
    - primitives: replace

    Args:
        pod: the pod to update; its ``spec`` and ``metadata`` attributes are
            reassigned.
        overlay: mapping using the serialized (camelCase) Pod field names.
            It is not modified; values copied into the pod are detached from it.
    """
    from kubernetes import client

    def as_plain(value: Any) -> Any:
        # Copy containers so the pod never aliases the caller's overlay, and
        # turn tuples (the "replace" marker) into lists at every depth.
        if isinstance(value, dict):
            return {k: as_plain(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [as_plain(v) for v in value]
        return value

    def deep_merge(base: Dict[str, Any], patch: Dict[str, Any]) -> None:
        for key, value in patch.items():
            current = base.get(key)
            if isinstance(value, dict) and isinstance(current, dict):
                deep_merge(current, value)
            elif isinstance(value, list) and isinstance(current, list):
                current.extend(as_plain(value))
            else:
                # Tuples, primitives, and any value with no compatible
                # existing counterpart replace whatever is there.
                base[key] = as_plain(value)

    api = client.ApiClient()
    pod_dict = api.sanitize_for_serialization(pod)

    deep_merge(pod_dict, overlay)

    # NOTE(review): this calls the name-mangled private ``__deserialize`` of
    # ApiClient; the public ``deserialize`` appears to expect a REST response
    # object rather than a dict. Confirm it still exists when upgrading the
    # kubernetes client.
    merged_pod = api._ApiClient__deserialize(pod_dict, "V1Pod")
    pod.spec = merged_pod.spec
    pod.metadata = merged_pod.metadata
203+
204+
100205
RETRY_POLICIES: Mapping[str, Iterable[Mapping[str, str]]] = {
101206
RetryPolicy.REPLICA: [],
102207
RetryPolicy.APPLICATION: [
@@ -369,7 +474,7 @@ def app_to_resource(
369474
queue: str,
370475
service_account: Optional[str],
371476
priority_class: Optional[str] = None,
372-
) -> Dict[str, object]:
477+
) -> Dict[str, Any]:
373478
"""
374479
app_to_resource creates a volcano job kubernetes resource definition from
375480
the provided AppDef. The resource definition can be used to launch the
@@ -402,6 +507,17 @@ def app_to_resource(
402507
replica_role.env["TORCHX_IMAGE"] = replica_role.image
403508

404509
pod = role_to_pod(name, replica_role, service_account)
510+
if k8s_metadata := role.metadata.get("kubernetes"):
511+
if isinstance(k8s_metadata, str):
512+
import fsspec
513+
514+
with fsspec.open(k8s_metadata, "r") as f:
515+
k8s_metadata = yaml.unsafe_load(f)
516+
elif not isinstance(k8s_metadata, dict):
517+
raise ValueError(
518+
f"metadata['kubernetes'] must be a dict or resource URI, got {type(k8s_metadata)}"
519+
)
520+
_apply_pod_overlay(pod, k8s_metadata)
405521
pod.metadata.labels.update(
406522
pod_labels(
407523
app=app,
@@ -444,7 +560,7 @@ def app_to_resource(
444560
if priority_class is not None:
445561
job_spec["priorityClassName"] = priority_class
446562

447-
resource: Dict[str, object] = {
563+
resource: Dict[str, Any] = {
448564
"apiVersion": "batch.volcano.sh/v1alpha1",
449565
"kind": "Job",
450566
"metadata": {"name": f"{unique_app_id}"},
@@ -456,7 +572,7 @@ def app_to_resource(
456572
@dataclass
457573
class KubernetesJob:
458574
images_to_push: Dict[str, Tuple[str, str]]
459-
resource: Dict[str, object]
575+
resource: Dict[str, Any]
460576

461577
def __str__(self) -> str:
462578
return yaml.dump(sanitize_for_serialization(self.resource))
@@ -471,6 +587,7 @@ class KubernetesOpts(TypedDict, total=False):
471587
image_repo: Optional[str]
472588
service_account: Optional[str]
473589
priority_class: Optional[str]
590+
validate_spec: Optional[bool]
474591

475592

476593
class KubernetesScheduler(
@@ -636,7 +753,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[KubernetesJob]) -> str:
636753
else:
637754
raise
638755

639-
return f'{namespace}:{resp["metadata"]["name"]}'
756+
return f"{namespace}:{resp['metadata']['name']}"
640757

641758
def _submit_dryrun(
642759
self, app: AppDef, cfg: KubernetesOpts
@@ -659,6 +776,36 @@ def _submit_dryrun(
659776
), "priority_class must be a str"
660777

661778
resource = app_to_resource(app, queue, service_account, priority_class)
779+
780+
if cfg.get("validate_spec"):
781+
try:
782+
self._custom_objects_api().create_namespaced_custom_object(
783+
group="batch.volcano.sh",
784+
version="v1alpha1",
785+
namespace=cfg.get("namespace") or "default",
786+
plural="jobs",
787+
body=resource,
788+
dry_run="All",
789+
)
790+
except Exception as e:
791+
from kubernetes.client.rest import ApiException
792+
793+
if isinstance(e, ApiException):
794+
raise ValueError(f"Invalid job spec: {e.reason}") from e
795+
raise
796+
797+
job_name = resource["metadata"]["name"]
798+
for task in resource["spec"]["tasks"]:
799+
task_name = task["name"]
800+
replicas = task.get("replicas", 1)
801+
max_index = replicas - 1
802+
pod_name = f"{job_name}-{task_name}-{max_index}"
803+
if len(pod_name) > 63:
804+
raise ValueError(
805+
f"Pod name '{pod_name}' ({len(pod_name)} chars) exceeds 63 character limit. "
806+
f"Shorten app.name or role names"
807+
)
808+
662809
req = KubernetesJob(
663810
resource=resource,
664811
images_to_push=images_to_push,
@@ -703,6 +850,12 @@ def _run_opts(self) -> runopts:
703850
type_=str,
704851
help="The name of the PriorityClass to set on the job specs",
705852
)
853+
opts.add(
854+
"validate_spec",
855+
type_=bool,
856+
help="Validate job spec using Kubernetes API dry-run before submission",
857+
default=True,
858+
)
706859
return opts
707860

708861
def describe(self, app_id: str) -> Optional[DescribeAppResponse]:

0 commit comments

Comments
 (0)