Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions torchx/schedulers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,11 @@ def __hash__(self) -> int:


T = TypeVar("T")
A = TypeVar("A")
D = TypeVar("D")


class Scheduler(abc.ABC, Generic[T]):
class Scheduler(abc.ABC, Generic[T, A, D]):
"""
An interface abstracting functionalities of a scheduler.
Implementers need only implement those methods annotated with
Expand Down Expand Up @@ -126,7 +128,7 @@ def close(self) -> None:

def submit(
self,
app: AppDef,
app: A,
cfg: T,
workspace: Optional[str] = None,
) -> str:
Expand All @@ -150,7 +152,7 @@ def submit(
return self.schedule(dryrun_info)

@abc.abstractmethod
def schedule(self, dryrun_info: AppDryRunInfo) -> str:
def schedule(self, dryrun_info: D) -> str:
"""
Same as ``submit`` except that it takes an ``AppDryRunInfo``.
Implementers are encouraged to implement this method rather than
Expand All @@ -166,7 +168,7 @@ def schedule(self, dryrun_info: AppDryRunInfo) -> str:

raise NotImplementedError()

def submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo:
def submit_dryrun(self, app: A, cfg: T) -> D:
"""
Rather than submitting the request to run the app, returns the
request object that would have been submitted to the underlying
Expand All @@ -179,14 +181,16 @@ def submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo:
resolved_cfg = self.run_opts().resolve(cfg)
# pyre-fixme: _submit_dryrun takes Generic type for resolved_cfg
dryrun_info = self._submit_dryrun(app, resolved_cfg)
for role in app.roles:
dryrun_info = role.pre_proc(self.backend, dryrun_info)
dryrun_info._app = app
dryrun_info._cfg = resolved_cfg

if isinstance(app, AppDef):
for role in app.roles:
dryrun_info = role.pre_proc(self.backend, dryrun_info)
dryrun_info._app = app
dryrun_info._cfg = resolved_cfg
return dryrun_info

@abc.abstractmethod
def _submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo:
def _submit_dryrun(self, app: A, cfg: T) -> D:
raise NotImplementedError()

def run_opts(self) -> runopts:
Expand Down Expand Up @@ -345,18 +349,19 @@ def _pre_build_validate(self, app: AppDef, scheduler: str, cfg: T) -> None:
"""
pass

def _validate(self, app: AppDef, scheduler: str, cfg: T) -> None:
def _validate(self, app: A, scheduler: str, cfg: T) -> None:
"""
Validates after workspace build whether application is consistent with the scheduler.

Raises error if application is not compatible with scheduler
"""
for role in app.roles:
if role.resource == NULL_RESOURCE:
raise ValueError(
f"No resource for role: {role.image}."
f" Did you forget to attach resource to the role"
)
if isinstance(app, AppDef):
for role in app.roles:
if role.resource == NULL_RESOURCE:
raise ValueError(
f"No resource for role: {role.image}."
f" Did you forget to attach resource to the role"
)


def filter_regex(regex: str, data: Iterable[str]) -> Iterable[str]:
Expand Down
4 changes: 3 additions & 1 deletion torchx/schedulers/aws_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,9 @@ class AWSBatchOpts(TypedDict, total=False):
execution_role_arn: Optional[str]


class AWSBatchScheduler(DockerWorkspaceMixin, Scheduler[AWSBatchOpts]):
class AWSBatchScheduler(
DockerWorkspaceMixin, Scheduler[AWSBatchOpts, AppDef, AppDryRunInfo[BatchJob]]
):
"""
AWSBatchScheduler is a TorchX scheduling interface to AWS Batch.

Expand Down
5 changes: 4 additions & 1 deletion torchx/schedulers/aws_sagemaker_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ def _merge_ordered(
return merged


class AWSSageMakerScheduler(DockerWorkspaceMixin, Scheduler[AWSSageMakerOpts]): # type: ignore[misc]
class AWSSageMakerScheduler(
DockerWorkspaceMixin,
Scheduler[AWSSageMakerOpts, AppDef, AppDryRunInfo[AWSSageMakerJob]],
):
"""
AWSSageMakerScheduler is a TorchX scheduling interface to AWS SageMaker.

Expand Down
4 changes: 3 additions & 1 deletion torchx/schedulers/docker_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ class DockerOpts(TypedDict, total=False):
privileged: bool


class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
class DockerScheduler(
DockerWorkspaceMixin, Scheduler[DockerOpts, AppDef, AppDryRunInfo[DockerJob]]
):
"""
DockerScheduler is a TorchX scheduling interface to Docker.
Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/gcp_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class GCPBatchOpts(TypedDict, total=False):
location: Optional[str]


class GCPBatchScheduler(Scheduler[GCPBatchOpts]):
class GCPBatchScheduler(Scheduler[GCPBatchOpts, AppDef, AppDryRunInfo[GCPBatchJob]]):
"""
GCPBatchScheduler is a TorchX scheduling interface to GCP Batch.

Expand Down
4 changes: 3 additions & 1 deletion torchx/schedulers/kubernetes_mcad_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,9 @@ class KubernetesMCADOpts(TypedDict, total=False):
network: Optional[str]


class KubernetesMCADScheduler(DockerWorkspaceMixin, Scheduler[KubernetesMCADOpts]):
class KubernetesMCADScheduler(
DockerWorkspaceMixin, Scheduler[KubernetesMCADOpts, AppDef, AppDryRunInfo]
):
"""
KubernetesMCADScheduler is a TorchX scheduling interface to Kubernetes.
Expand Down
5 changes: 4 additions & 1 deletion torchx/schedulers/kubernetes_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,10 @@ class KubernetesOpts(TypedDict, total=False):
priority_class: Optional[str]


class KubernetesScheduler(DockerWorkspaceMixin, Scheduler[KubernetesOpts]):
class KubernetesScheduler(
DockerWorkspaceMixin,
Scheduler[KubernetesOpts, AppDef, AppDryRunInfo[KubernetesJob]],
):
"""
KubernetesScheduler is a TorchX scheduling interface to Kubernetes.
Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/local_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def _register_termination_signals() -> None:
signal.signal(signal.SIGINT, _terminate_process_handler)


class LocalScheduler(Scheduler[LocalOpts]):
class LocalScheduler(Scheduler[LocalOpts, AppDef, AppDryRunInfo[PopenRequest]]):
"""
Schedules on localhost. Containers are modeled as processes and
certain properties of the container that are either not relevant
Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/lsf_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def __repr__(self) -> str:
{self.materialize()}"""


class LsfScheduler(Scheduler[LsfOpts]):
class LsfScheduler(Scheduler[LsfOpts, AppDef, AppDryRunInfo]):
"""
**Example: hello_world**

Expand Down
4 changes: 3 additions & 1 deletion torchx/schedulers/ray_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ class RayJob:
requirements: Optional[str] = None
actors: List[RayActor] = field(default_factory=list)

class RayScheduler(TmpDirWorkspaceMixin, Scheduler[RayOpts]):
class RayScheduler(
TmpDirWorkspaceMixin, Scheduler[RayOpts, AppDef, AppDryRunInfo[RayJob]]
):
"""
RayScheduler is a TorchX scheduling interface to Ray. The job def
workers will be launched as Ray actors
Expand Down
4 changes: 3 additions & 1 deletion torchx/schedulers/slurm_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ def __repr__(self) -> str:
{self.materialize()}"""


class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
class SlurmScheduler(
DirWorkspaceMixin, Scheduler[SlurmOpts, AppDef, AppDryRunInfo[SlurmBatchRequest]]
):
"""
SlurmScheduler is a TorchX scheduling interface to slurm. TorchX expects
that slurm CLI tools are locally installed and job accounting is enabled.
Expand Down
8 changes: 5 additions & 3 deletions torchx/schedulers/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
from torchx.workspace.api import WorkspaceMixin

T = TypeVar("T")
A = TypeVar("A")
D = TypeVar("D")


class SchedulerTest(unittest.TestCase):
class MockScheduler(Scheduler[T], WorkspaceMixin[None]):
class MockScheduler(Scheduler[T, A, D], WorkspaceMixin[None]):
def __init__(self, session_name: str) -> None:
super().__init__("mock", session_name)

Expand Down Expand Up @@ -151,7 +153,7 @@ def test_invalid_dryrun_cfg(self) -> None:

def test_role_preproc_called(self) -> None:
scheduler_mock = SchedulerTest.MockScheduler("test_session")
app_mock = MagicMock()
app_mock = AppDef(name="test")
app_mock.roles = [MagicMock()]

cfg = {"foo": "bar"}
Expand All @@ -161,7 +163,7 @@ def test_role_preproc_called(self) -> None:

def test_validate(self) -> None:
scheduler_mock = SchedulerTest.MockScheduler("test_session")
app_mock = MagicMock()
app_mock = AppDef(name="test")
app_mock.roles = [MagicMock()]
app_mock.roles[0].resource = NULL_RESOURCE

Expand Down
6 changes: 6 additions & 0 deletions torchx/schedulers/test/aws_batch_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def test_submit_dryrun_tags(self, _) -> None:
def test_submit_dryrun_job_role_arn(self) -> None:
cfg = AWSBatchOpts({"queue": "ignored_in_test", "job_role_arn": "fizzbuzz"})
info = create_scheduler("test").submit_dryrun(_test_app(), cfg)
# pyre-ignore[16]
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
self.assertEqual(1, len(node_groups))
self.assertEqual(cfg["job_role_arn"], node_groups[0]["container"]["jobRoleArn"])
Expand All @@ -166,6 +167,7 @@ def test_submit_dryrun_execution_role_arn(self) -> None:
{"queue": "ignored_in_test", "execution_role_arn": "veryexecutive"}
)
info = create_scheduler("test").submit_dryrun(_test_app(), cfg)
# pyre-ignore[16]
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
self.assertEqual(1, len(node_groups))
self.assertEqual(
Expand All @@ -175,6 +177,7 @@ def test_submit_dryrun_execution_role_arn(self) -> None:
def test_submit_dryrun_privileged(self) -> None:
cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True})
info = create_scheduler("test").submit_dryrun(_test_app(), cfg)
# pyre-ignore[16]
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
self.assertEqual(1, len(node_groups))
self.assertTrue(node_groups[0]["container"]["privileged"])
Expand All @@ -184,6 +187,7 @@ def test_submit_dryrun_instance_type_multinode(self) -> None:
resource = specs.named_resources_aws.aws_p3dn_24xlarge()
app = _test_app(num_replicas=2, resource=resource)
info = create_scheduler("test").submit_dryrun(app, cfg)
# pyre-ignore[16]
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
self.assertEqual(1, len(node_groups))
self.assertEqual(
Expand All @@ -196,6 +200,7 @@ def test_submit_dryrun_no_instance_type_singlenode(self) -> None:
resource = specs.named_resources_aws.aws_p3dn_24xlarge()
app = _test_app(num_replicas=1, resource=resource)
info = create_scheduler("test").submit_dryrun(app, cfg)
# pyre-ignore[16]
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
self.assertEqual(1, len(node_groups))
self.assertTrue("instanceType" not in node_groups[0]["container"])
Expand All @@ -205,6 +210,7 @@ def test_submit_dryrun_no_instance_type_non_aws(self) -> None:
resource = specs.named_resources_aws.aws_p3dn_24xlarge()
app = _test_app(num_replicas=2)
info = create_scheduler("test").submit_dryrun(app, cfg)
# pyre-ignore[16]
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
self.assertEqual(1, len(node_groups))
self.assertTrue("instanceType" not in node_groups[0]["container"])
Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/test/kubernetes_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def test_rank0_env(self) -> None:
make_unique_ctx.return_value = "app-name-42"
info = scheduler.submit_dryrun(app, cfg)

tasks = info.request.resource["spec"]["tasks"]
tasks = info.request.resource["spec"]["tasks"] # pyre-ignore[16]
container0 = tasks[0]["template"].spec.containers[0]
self.assertIn("TORCHX_RANK0_HOST", container0.command)
self.assertIn(
Expand Down
Loading