Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions torchx/schedulers/kubernetes_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def app_to_resource(
queue: str,
service_account: Optional[str],
priority_class: Optional[str] = None,
) -> Dict[str, object]:
) -> Dict[str, Any]:
"""
app_to_resource creates a volcano job kubernetes resource definition from
the provided AppDef. The resource definition can be used to launch the
Expand Down Expand Up @@ -444,7 +444,7 @@ def app_to_resource(
if priority_class is not None:
job_spec["priorityClassName"] = priority_class

resource: Dict[str, object] = {
resource: Dict[str, Any] = {
"apiVersion": "batch.volcano.sh/v1alpha1",
"kind": "Job",
"metadata": {"name": f"{unique_app_id}"},
Expand All @@ -456,7 +456,7 @@ def app_to_resource(
@dataclass
class KubernetesJob:
images_to_push: Dict[str, Tuple[str, str]]
resource: Dict[str, object]
resource: Dict[str, Any]

def __str__(self) -> str:
return yaml.dump(sanitize_for_serialization(self.resource))
Expand All @@ -471,6 +471,7 @@ class KubernetesOpts(TypedDict, total=False):
image_repo: Optional[str]
service_account: Optional[str]
priority_class: Optional[str]
validate_spec: Optional[bool]


class KubernetesScheduler(
Expand Down Expand Up @@ -659,6 +660,36 @@ def _submit_dryrun(
), "priority_class must be a str"

resource = app_to_resource(app, queue, service_account, priority_class)

if cfg.get("validate_spec"):
try:
self._custom_objects_api().create_namespaced_custom_object(
group="batch.volcano.sh",
version="v1alpha1",
namespace=cfg.get("namespace") or "default",
plural="jobs",
body=resource,
dry_run="All",
)
except Exception as e:
from kubernetes.client.rest import ApiException

if isinstance(e, ApiException):
raise ValueError(f"Invalid job spec: {e.reason}") from e
raise

job_name = resource["metadata"]["name"]
for task in resource["spec"]["tasks"]:
task_name = task["name"]
replicas = task.get("replicas", 1)
max_index = replicas - 1
pod_name = f"{job_name}-{task_name}-{max_index}"
if len(pod_name) > 63:
raise ValueError(
f"Pod name '{pod_name}' ({len(pod_name)} chars) exceeds 63 character limit. "
f"Shorten app.name or role names"
)

req = KubernetesJob(
resource=resource,
images_to_push=images_to_push,
Expand Down Expand Up @@ -703,6 +734,12 @@ def _run_opts(self) -> runopts:
type_=str,
help="The name of the PriorityClass to set on the job specs",
)
opts.add(
"validate_spec",
type_=bool,
help="Validate job spec using Kubernetes API dry-run before submission",
default=True,
)
return opts

def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
Expand Down
172 changes: 155 additions & 17 deletions torchx/schedulers/test/kubernetes_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import sys
import unittest
from datetime import datetime
from typing import Any, Dict
from typing import Any, cast, Dict
from unittest.mock import MagicMock, patch

import torchx
Expand Down Expand Up @@ -111,10 +111,7 @@ def test_app_to_resource_resolved_macros(self) -> None:
make_unique_ctx.return_value = unique_app_name
resource = app_to_resource(app, "test_queue", service_account=None)
actual_cmd = (
# pyre-ignore [16]
resource["spec"]["tasks"][0]["template"]
.spec.containers[0]
.command
resource["spec"]["tasks"][0]["template"].spec.containers[0].command
)
expected_cmd = [
"main",
Expand All @@ -135,7 +132,6 @@ def test_retry_policy_not_set(self) -> None:
{"event": "PodEvicted", "action": "RestartJob"},
{"event": "PodFailed", "action": "RestartJob"},
],
# pyre-ignore [16]
resource["spec"]["tasks"][0]["policies"],
)
for role in app.roles:
Expand Down Expand Up @@ -251,7 +247,11 @@ def test_role_to_pod(self) -> None:
want,
)

def test_submit_dryrun(self) -> None:
@patch(
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
)
def test_submit_dryrun(self, mock_api: MagicMock) -> None:
mock_api.return_value.create_namespaced_custom_object.return_value = {}
scheduler = create_scheduler("test")
app = _test_app()
cfg = KubernetesOpts({"queue": "testqueue"})
Expand All @@ -262,6 +262,9 @@ def test_submit_dryrun(self) -> None:
info = scheduler.submit_dryrun(app, cfg)

resource = str(info.request)
mock_api.return_value.create_namespaced_custom_object.assert_called_once()
call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1]
self.assertEqual(call_kwargs["dry_run"], "All")

print(resource)

Expand Down Expand Up @@ -505,7 +508,11 @@ def test_instance_type(self) -> None:
},
)

def test_rank0_env(self) -> None:
@patch(
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
)
def test_rank0_env(self, mock_api: MagicMock) -> None:
mock_api.return_value.create_namespaced_custom_object.return_value = {}
from kubernetes.client.models import V1EnvVar

scheduler = create_scheduler("test")
Expand All @@ -517,7 +524,7 @@ def test_rank0_env(self) -> None:
make_unique_ctx.return_value = "app-name-42"
info = scheduler.submit_dryrun(app, cfg)

tasks = info.request.resource["spec"]["tasks"] # pyre-ignore[16]
tasks = info.request.resource["spec"]["tasks"]
container0 = tasks[0]["template"].spec.containers[0]
self.assertIn("TORCHX_RANK0_HOST", container0.command)
self.assertIn(
Expand All @@ -528,8 +535,16 @@ def test_rank0_env(self) -> None:
)
container1 = tasks[1]["template"].spec.containers[0]
self.assertIn("VC_TRAINERFOO_0_HOSTS", container1.command)
mock_api.return_value.create_namespaced_custom_object.assert_called_once()
call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1]
self.assertEqual(call_kwargs["dry_run"], "All")
self.assertEqual(call_kwargs["namespace"], "default")

def test_submit_dryrun_patch(self) -> None:
@patch(
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
)
def test_submit_dryrun_patch(self, mock_api: MagicMock) -> None:
mock_api.return_value.create_namespaced_custom_object.return_value = {}
scheduler = create_scheduler("test")
app = _test_app()
app.roles[0].image = "sha256:testhash"
Expand All @@ -555,8 +570,15 @@ def test_submit_dryrun_patch(self) -> None:
),
},
)
mock_api.return_value.create_namespaced_custom_object.assert_called_once()
call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1]
self.assertEqual(call_kwargs["dry_run"], "All")

def test_submit_dryrun_service_account(self) -> None:
@patch(
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
)
def test_submit_dryrun_service_account(self, mock_api: MagicMock) -> None:
mock_api.return_value.create_namespaced_custom_object.return_value = {}
scheduler = create_scheduler("test")
self.assertIn("service_account", scheduler.run_opts()._opts)
app = _test_app()
Expand All @@ -573,7 +595,17 @@ def test_submit_dryrun_service_account(self) -> None:
info = scheduler.submit_dryrun(app, cfg)
self.assertIn("service_account_name': None", str(info.request.resource))

def test_submit_dryrun_priority_class(self) -> None:
self.assertEqual(
mock_api.return_value.create_namespaced_custom_object.call_count, 2
)
call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1]
self.assertEqual(call_kwargs["dry_run"], "All")

@patch(
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
)
def test_submit_dryrun_priority_class(self, mock_api: MagicMock) -> None:
mock_api.return_value.create_namespaced_custom_object.return_value = {}
scheduler = create_scheduler("test")
self.assertIn("priority_class", scheduler.run_opts()._opts)
app = _test_app()
Expand All @@ -591,6 +623,12 @@ def test_submit_dryrun_priority_class(self) -> None:
info = scheduler.submit_dryrun(app, cfg)
self.assertNotIn("'priorityClassName'", str(info.request.resource))

self.assertEqual(
mock_api.return_value.create_namespaced_custom_object.call_count, 2
)
call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1]
self.assertEqual(call_kwargs["dry_run"], "All")

@patch("kubernetes.client.CustomObjectsApi.create_namespaced_custom_object")
def test_submit(self, create_namespaced_custom_object: MagicMock) -> None:
create_namespaced_custom_object.return_value = {
Expand Down Expand Up @@ -624,7 +662,7 @@ def test_submit_job_name_conflict(

api_exc = ApiException(status=409, reason="Conflict")
api_exc.body = '{"details":{"name": "test_job"}}'
create_namespaced_custom_object.side_effect = api_exc
create_namespaced_custom_object.side_effect = [{}, api_exc]

scheduler = create_scheduler("test")
app = _test_app()
Expand All @@ -638,6 +676,14 @@ def test_submit_job_name_conflict(
with self.assertRaises(ValueError):
scheduler.schedule(info)

self.assertEqual(create_namespaced_custom_object.call_count, 2)
# First call is spec validation
first_call_kwargs = create_namespaced_custom_object.call_args_list[0][1]
self.assertEqual(first_call_kwargs["dry_run"], "All")
# Second call is actual schedule
second_call_kwargs = create_namespaced_custom_object.call_args_list[1][1]
self.assertNotIn("dry_run", second_call_kwargs)

@patch("kubernetes.client.CustomObjectsApi.get_namespaced_custom_object_status")
def test_describe(self, get_namespaced_custom_object_status: MagicMock) -> None:
get_namespaced_custom_object_status.return_value = {
Expand Down Expand Up @@ -752,6 +798,7 @@ def test_runopts(self) -> None:
"image_repo",
"service_account",
"priority_class",
"validate_spec",
},
)

Expand Down Expand Up @@ -949,12 +996,103 @@ def test_min_replicas(self) -> None:
app.roles[0].min_replicas = 2

resource = app_to_resource(app, "test_queue", service_account=None)
min_available = [
task["minAvailable"]
for task in resource["spec"]["tasks"] # pyre-ignore[16]
]
min_available = [task["minAvailable"] for task in resource["spec"]["tasks"]]
self.assertEqual(min_available, [1, 1, 0])

@patch(
    "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
)
def test_validate_spec_invalid_name(self, mock_api: MagicMock) -> None:
    """
    An ``ApiException`` raised by the dry-run create call must surface from
    ``submit_dryrun`` as a ``ValueError`` mentioning "Invalid job spec".
    """
    from kubernetes.client.rest import ApiException

    # Simulate the API rejecting the job during the dry-run create.
    api_client = MagicMock()
    dry_run_create = api_client.create_namespaced_custom_object
    dry_run_create.side_effect = ApiException(status=422, reason="Invalid")
    mock_api.return_value = api_client

    scheduler = create_scheduler("test")
    app = _test_app()
    app.name = "Invalid_Name"
    cfg = cast(KubernetesOpts, {"queue": "testqueue", "validate_spec": True})

    with self.assertRaises(ValueError) as raised:
        scheduler.submit_dryrun(app, cfg)

    self.assertIn("Invalid job spec", str(raised.exception))
    # Exactly one API call is expected, and it must be a dry-run request.
    dry_run_create.assert_called_once()
    _, call_kwargs = dry_run_create.call_args
    self.assertEqual(call_kwargs["dry_run"], "All")

def test_validate_spec_disabled(self) -> None:
    """
    With ``validate_spec`` set to ``False``, ``submit_dryrun`` must still
    return a request but must not issue the dry-run create call.
    """
    cfg = KubernetesOpts({"queue": "testqueue", "validate_spec": False})
    app = _test_app()
    scheduler = create_scheduler("test")

    api_target = (
        "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
    )
    with patch(api_target) as mock_api:
        api_client = MagicMock()
        api_client.create_namespaced_custom_object.return_value = {}
        mock_api.return_value = api_client

        info = scheduler.submit_dryrun(app, cfg)

        self.assertIsNotNone(info)
        # Validation is off, so the create endpoint must never be hit.
        api_client.create_namespaced_custom_object.assert_not_called()

@patch(
    "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
)
def test_validate_spec_invalid_task_name(self, mock_api: MagicMock) -> None:
    """
    An ``ApiException`` raised by the dry-run create call for a job whose
    role (task) name is invalid must surface from ``submit_dryrun`` as a
    ``ValueError`` mentioning "Invalid job spec".
    """
    from kubernetes.client.rest import ApiException

    scheduler = create_scheduler("test")
    app = _test_app()
    app.roles[0].name = "Invalid-Task-Name"

    # Simulate the API rejecting the job during the dry-run create.
    mock_api_instance = MagicMock()
    mock_api_instance.create_namespaced_custom_object.side_effect = ApiException(
        status=422,
        reason="Invalid",
    )
    mock_api.return_value = mock_api_instance

    cfg = cast(KubernetesOpts, {"queue": "testqueue", "validate_spec": True})

    with self.assertRaises(ValueError) as ctx:
        scheduler.submit_dryrun(app, cfg)

    self.assertIn("Invalid job spec", str(ctx.exception))
    # Mirror test_validate_spec_invalid_name: the error must come from a
    # single dry-run request, not from some other code path.
    mock_api_instance.create_namespaced_custom_object.assert_called_once()
    call_kwargs = mock_api_instance.create_namespaced_custom_object.call_args[1]
    self.assertEqual(call_kwargs["dry_run"], "All")

@patch(
    "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
)
def test_validate_spec_long_pod_name(self, mock_api: MagicMock) -> None:
    """
    Even when the dry-run create call succeeds, ``submit_dryrun`` must raise
    a ``ValueError`` if a generated pod name exceeds 63 characters.
    """
    # The dry-run call itself succeeds; only the length check should fail.
    api_client = MagicMock()
    api_client.create_namespaced_custom_object.return_value = {}
    mock_api.return_value = api_client

    scheduler = create_scheduler("test")
    app = _test_app()
    app.name = "x" * 50
    app.roles[0].name = "y" * 20
    cfg = cast(KubernetesOpts, {"queue": "testqueue", "validate_spec": True})

    # Pin the unique job id so the resulting pod name length is predictable.
    unique_patch = patch(
        "torchx.schedulers.kubernetes_scheduler.make_unique",
        return_value="x" * 50,
    )
    with unique_patch, self.assertRaises(ValueError) as raised:
        scheduler.submit_dryrun(app, cfg)

    message = str(raised.exception)
    self.assertIn("Pod name", message)
    self.assertIn("exceeds 63 character limit", message)


class KubernetesSchedulerNoImportTest(unittest.TestCase):
"""
Expand Down
Loading