Skip to content

Commit e9007a7

Browse files
authored
Merge branch 'main' into feat/k8s_pod_overlay
2 parents c242520 + a01f925 commit e9007a7

File tree

2 files changed

+194
-20
lines changed

2 files changed

+194
-20
lines changed

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def app_to_resource(
474474
queue: str,
475475
service_account: Optional[str],
476476
priority_class: Optional[str] = None,
477-
) -> Dict[str, object]:
477+
) -> Dict[str, Any]:
478478
"""
479479
app_to_resource creates a volcano job kubernetes resource definition from
480480
the provided AppDef. The resource definition can be used to launch the
@@ -560,7 +560,7 @@ def app_to_resource(
560560
if priority_class is not None:
561561
job_spec["priorityClassName"] = priority_class
562562

563-
resource: Dict[str, object] = {
563+
resource: Dict[str, Any] = {
564564
"apiVersion": "batch.volcano.sh/v1alpha1",
565565
"kind": "Job",
566566
"metadata": {"name": f"{unique_app_id}"},
@@ -572,7 +572,7 @@ def app_to_resource(
572572
@dataclass
573573
class KubernetesJob:
574574
images_to_push: Dict[str, Tuple[str, str]]
575-
resource: Dict[str, object]
575+
resource: Dict[str, Any]
576576

577577
def __str__(self) -> str:
578578
return yaml.dump(sanitize_for_serialization(self.resource))
@@ -587,6 +587,7 @@ class KubernetesOpts(TypedDict, total=False):
587587
image_repo: Optional[str]
588588
service_account: Optional[str]
589589
priority_class: Optional[str]
590+
validate_spec: Optional[bool]
590591

591592

592593
class KubernetesScheduler(
@@ -775,6 +776,36 @@ def _submit_dryrun(
775776
), "priority_class must be a str"
776777

777778
resource = app_to_resource(app, queue, service_account, priority_class)
779+
780+
if cfg.get("validate_spec"):
781+
try:
782+
self._custom_objects_api().create_namespaced_custom_object(
783+
group="batch.volcano.sh",
784+
version="v1alpha1",
785+
namespace=cfg.get("namespace") or "default",
786+
plural="jobs",
787+
body=resource,
788+
dry_run="All",
789+
)
790+
except Exception as e:
791+
from kubernetes.client.rest import ApiException
792+
793+
if isinstance(e, ApiException):
794+
raise ValueError(f"Invalid job spec: {e.reason}") from e
795+
raise
796+
797+
job_name = resource["metadata"]["name"]
798+
for task in resource["spec"]["tasks"]:
799+
task_name = task["name"]
800+
replicas = task.get("replicas", 1)
801+
max_index = replicas - 1
802+
pod_name = f"{job_name}-{task_name}-{max_index}"
803+
if len(pod_name) > 63:
804+
raise ValueError(
805+
f"Pod name '{pod_name}' ({len(pod_name)} chars) exceeds 63 character limit. "
806+
f"Shorten app.name or role names"
807+
)
808+
778809
req = KubernetesJob(
779810
resource=resource,
780811
images_to_push=images_to_push,
@@ -819,6 +850,12 @@ def _run_opts(self) -> runopts:
819850
type_=str,
820851
help="The name of the PriorityClass to set on the job specs",
821852
)
853+
opts.add(
854+
"validate_spec",
855+
type_=bool,
856+
help="Validate job spec using Kubernetes API dry-run before submission",
857+
default=True,
858+
)
822859
return opts
823860

824861
def describe(self, app_id: str) -> Optional[DescribeAppResponse]:

torchx/schedulers/test/kubernetes_scheduler_test.py

Lines changed: 154 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import sys
1212
import unittest
1313
from datetime import datetime
14-
from typing import Any, Dict
14+
from typing import Any, cast, Dict
1515
from unittest.mock import MagicMock, patch
1616

1717
import torchx
@@ -111,10 +111,7 @@ def test_app_to_resource_resolved_macros(self) -> None:
111111
make_unique_ctx.return_value = unique_app_name
112112
resource = app_to_resource(app, "test_queue", service_account=None)
113113
actual_cmd = (
114-
# pyre-ignore [16]
115-
resource["spec"]["tasks"][0]["template"]
116-
.spec.containers[0]
117-
.command
114+
resource["spec"]["tasks"][0]["template"].spec.containers[0].command
118115
)
119116
expected_cmd = [
120117
"main",
@@ -135,7 +132,6 @@ def test_retry_policy_not_set(self) -> None:
135132
{"event": "PodEvicted", "action": "RestartJob"},
136133
{"event": "PodFailed", "action": "RestartJob"},
137134
],
138-
# pyre-ignore [16]
139135
resource["spec"]["tasks"][0]["policies"],
140136
)
141137
for role in app.roles:
@@ -251,7 +247,11 @@ def test_role_to_pod(self) -> None:
251247
want,
252248
)
253249

254-
def test_submit_dryrun(self) -> None:
250+
@patch(
251+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
252+
)
253+
def test_submit_dryrun(self, mock_api: MagicMock) -> None:
254+
mock_api.return_value.create_namespaced_custom_object.return_value = {}
255255
scheduler = create_scheduler("test")
256256
app = _test_app()
257257
cfg = KubernetesOpts({"queue": "testqueue"})
@@ -262,6 +262,9 @@ def test_submit_dryrun(self) -> None:
262262
info = scheduler.submit_dryrun(app, cfg)
263263

264264
resource = str(info.request)
265+
mock_api.return_value.create_namespaced_custom_object.assert_called_once()
266+
call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1]
267+
self.assertEqual(call_kwargs["dry_run"], "All")
265268

266269
print(resource)
267270

@@ -505,7 +508,11 @@ def test_instance_type(self) -> None:
505508
},
506509
)
507510

508-
def test_rank0_env(self) -> None:
511+
@patch(
512+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
513+
)
514+
def test_rank0_env(self, mock_api: MagicMock) -> None:
515+
mock_api.return_value.create_namespaced_custom_object.return_value = {}
509516
from kubernetes.client.models import V1EnvVar
510517

511518
scheduler = create_scheduler("test")
@@ -517,7 +524,7 @@ def test_rank0_env(self) -> None:
517524
make_unique_ctx.return_value = "app-name-42"
518525
info = scheduler.submit_dryrun(app, cfg)
519526

520-
tasks = info.request.resource["spec"]["tasks"] # pyre-ignore[16]
527+
tasks = info.request.resource["spec"]["tasks"]
521528
container0 = tasks[0]["template"].spec.containers[0]
522529
self.assertIn("TORCHX_RANK0_HOST", container0.command)
523530
self.assertIn(
@@ -528,8 +535,16 @@ def test_rank0_env(self) -> None:
528535
)
529536
container1 = tasks[1]["template"].spec.containers[0]
530537
self.assertIn("VC_TRAINERFOO_0_HOSTS", container1.command)
538+
mock_api.return_value.create_namespaced_custom_object.assert_called_once()
539+
call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1]
540+
self.assertEqual(call_kwargs["dry_run"], "All")
541+
self.assertEqual(call_kwargs["namespace"], "default")
531542

532-
def test_submit_dryrun_patch(self) -> None:
543+
@patch(
544+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
545+
)
546+
def test_submit_dryrun_patch(self, mock_api: MagicMock) -> None:
547+
mock_api.return_value.create_namespaced_custom_object.return_value = {}
533548
scheduler = create_scheduler("test")
534549
app = _test_app()
535550
app.roles[0].image = "sha256:testhash"
@@ -555,8 +570,15 @@ def test_submit_dryrun_patch(self) -> None:
555570
),
556571
},
557572
)
573+
mock_api.return_value.create_namespaced_custom_object.assert_called_once()
574+
call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1]
575+
self.assertEqual(call_kwargs["dry_run"], "All")
558576

559-
def test_submit_dryrun_service_account(self) -> None:
577+
@patch(
578+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
579+
)
580+
def test_submit_dryrun_service_account(self, mock_api: MagicMock) -> None:
581+
mock_api.return_value.create_namespaced_custom_object.return_value = {}
560582
scheduler = create_scheduler("test")
561583
self.assertIn("service_account", scheduler.run_opts()._opts)
562584
app = _test_app()
@@ -573,7 +595,17 @@ def test_submit_dryrun_service_account(self) -> None:
573595
info = scheduler.submit_dryrun(app, cfg)
574596
self.assertIn("service_account_name': None", str(info.request.resource))
575597

576-
def test_submit_dryrun_priority_class(self) -> None:
598+
self.assertEqual(
599+
mock_api.return_value.create_namespaced_custom_object.call_count, 2
600+
)
601+
call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1]
602+
self.assertEqual(call_kwargs["dry_run"], "All")
603+
604+
@patch(
605+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
606+
)
607+
def test_submit_dryrun_priority_class(self, mock_api: MagicMock) -> None:
608+
mock_api.return_value.create_namespaced_custom_object.return_value = {}
577609
scheduler = create_scheduler("test")
578610
self.assertIn("priority_class", scheduler.run_opts()._opts)
579611
app = _test_app()
@@ -591,6 +623,12 @@ def test_submit_dryrun_priority_class(self) -> None:
591623
info = scheduler.submit_dryrun(app, cfg)
592624
self.assertNotIn("'priorityClassName'", str(info.request.resource))
593625

626+
self.assertEqual(
627+
mock_api.return_value.create_namespaced_custom_object.call_count, 2
628+
)
629+
call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1]
630+
self.assertEqual(call_kwargs["dry_run"], "All")
631+
594632
@patch("kubernetes.client.CustomObjectsApi.create_namespaced_custom_object")
595633
def test_submit(self, create_namespaced_custom_object: MagicMock) -> None:
596634
create_namespaced_custom_object.return_value = {
@@ -624,7 +662,7 @@ def test_submit_job_name_conflict(
624662

625663
api_exc = ApiException(status=409, reason="Conflict")
626664
api_exc.body = '{"details":{"name": "test_job"}}'
627-
create_namespaced_custom_object.side_effect = api_exc
665+
create_namespaced_custom_object.side_effect = [{}, api_exc]
628666

629667
scheduler = create_scheduler("test")
630668
app = _test_app()
@@ -638,6 +676,14 @@ def test_submit_job_name_conflict(
638676
with self.assertRaises(ValueError):
639677
scheduler.schedule(info)
640678

679+
self.assertEqual(create_namespaced_custom_object.call_count, 2)
680+
# First call is spec validation
681+
first_call_kwargs = create_namespaced_custom_object.call_args_list[0][1]
682+
self.assertEqual(first_call_kwargs["dry_run"], "All")
683+
# Second call is actual schedule
684+
second_call_kwargs = create_namespaced_custom_object.call_args_list[1][1]
685+
self.assertNotIn("dry_run", second_call_kwargs)
686+
641687
@patch("kubernetes.client.CustomObjectsApi.get_namespaced_custom_object_status")
642688
def test_describe(self, get_namespaced_custom_object_status: MagicMock) -> None:
643689
get_namespaced_custom_object_status.return_value = {
@@ -752,6 +798,7 @@ def test_runopts(self) -> None:
752798
"image_repo",
753799
"service_account",
754800
"priority_class",
801+
"validate_spec",
755802
},
756803
)
757804

@@ -949,10 +996,7 @@ def test_min_replicas(self) -> None:
949996
app.roles[0].min_replicas = 2
950997

951998
resource = app_to_resource(app, "test_queue", service_account=None)
952-
min_available = [
953-
task["minAvailable"]
954-
for task in resource["spec"]["tasks"] # pyre-ignore[16]
955-
]
999+
min_available = [task["minAvailable"] for task in resource["spec"]["tasks"]]
9561000
self.assertEqual(min_available, [1, 1, 0])
9571001

9581002
def test_apply_pod_overlay(self) -> None:
@@ -1205,6 +1249,99 @@ def test_submit_dryrun_with_pod_overlay_invalid_type(self) -> None:
12051249
scheduler.submit_dryrun(app, cfg)
12061250

12071251
self.assertIn("must be a dict or resource URI", str(ctx.exception))
1252+
@patch(
1253+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
1254+
)
1255+
def test_validate_spec_invalid_name(self, mock_api: MagicMock) -> None:
1256+
from kubernetes.client.rest import ApiException
1257+
1258+
scheduler = create_scheduler("test")
1259+
app = _test_app()
1260+
app.name = "Invalid_Name"
1261+
1262+
mock_api_instance = MagicMock()
1263+
mock_api_instance.create_namespaced_custom_object.side_effect = ApiException(
1264+
status=422,
1265+
reason="Invalid",
1266+
)
1267+
mock_api.return_value = mock_api_instance
1268+
1269+
cfg = cast(KubernetesOpts, {"queue": "testqueue", "validate_spec": True})
1270+
1271+
with self.assertRaises(ValueError) as ctx:
1272+
scheduler.submit_dryrun(app, cfg)
1273+
1274+
self.assertIn("Invalid job spec", str(ctx.exception))
1275+
mock_api_instance.create_namespaced_custom_object.assert_called_once()
1276+
call_kwargs = mock_api_instance.create_namespaced_custom_object.call_args[1]
1277+
self.assertEqual(call_kwargs["dry_run"], "All")
1278+
1279+
def test_validate_spec_disabled(self) -> None:
1280+
scheduler = create_scheduler("test")
1281+
app = _test_app()
1282+
1283+
cfg = KubernetesOpts({"queue": "testqueue", "validate_spec": False})
1284+
1285+
with patch(
1286+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
1287+
) as mock_api:
1288+
mock_api_instance = MagicMock()
1289+
mock_api_instance.create_namespaced_custom_object.return_value = {}
1290+
mock_api.return_value = mock_api_instance
1291+
1292+
info = scheduler.submit_dryrun(app, cfg)
1293+
1294+
self.assertIsNotNone(info)
1295+
mock_api_instance.create_namespaced_custom_object.assert_not_called()
1296+
1297+
@patch(
1298+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
1299+
)
1300+
def test_validate_spec_invalid_task_name(self, mock_api: MagicMock) -> None:
1301+
from kubernetes.client.rest import ApiException
1302+
1303+
scheduler = create_scheduler("test")
1304+
app = _test_app()
1305+
app.roles[0].name = "Invalid-Task-Name"
1306+
1307+
mock_api_instance = MagicMock()
1308+
mock_api_instance.create_namespaced_custom_object.side_effect = ApiException(
1309+
status=422,
1310+
reason="Invalid",
1311+
)
1312+
mock_api.return_value = mock_api_instance
1313+
1314+
cfg = cast(KubernetesOpts, {"queue": "testqueue", "validate_spec": True})
1315+
1316+
with self.assertRaises(ValueError) as ctx:
1317+
scheduler.submit_dryrun(app, cfg)
1318+
1319+
self.assertIn("Invalid job spec", str(ctx.exception))
1320+
1321+
@patch(
1322+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
1323+
)
1324+
def test_validate_spec_long_pod_name(self, mock_api: MagicMock) -> None:
1325+
scheduler = create_scheduler("test")
1326+
app = _test_app()
1327+
app.name = "x" * 50
1328+
app.roles[0].name = "y" * 20
1329+
1330+
mock_api_instance = MagicMock()
1331+
mock_api_instance.create_namespaced_custom_object.return_value = {}
1332+
mock_api.return_value = mock_api_instance
1333+
1334+
cfg = cast(KubernetesOpts, {"queue": "testqueue", "validate_spec": True})
1335+
1336+
with patch(
1337+
"torchx.schedulers.kubernetes_scheduler.make_unique"
1338+
) as make_unique_ctx:
1339+
make_unique_ctx.return_value = "x" * 50
1340+
with self.assertRaises(ValueError) as ctx:
1341+
scheduler.submit_dryrun(app, cfg)
1342+
1343+
self.assertIn("Pod name", str(ctx.exception))
1344+
self.assertIn("exceeds 63 character limit", str(ctx.exception))
12081345

12091346

12101347
class KubernetesSchedulerNoImportTest(unittest.TestCase):

0 commit comments

Comments
 (0)