diff --git a/torchx/schedulers/api.py b/torchx/schedulers/api.py index 6a9e1bf70..bc9865014 100644 --- a/torchx/schedulers/api.py +++ b/torchx/schedulers/api.py @@ -94,9 +94,11 @@ def __hash__(self) -> int: T = TypeVar("T") +A = TypeVar("A") +D = TypeVar("D") -class Scheduler(abc.ABC, Generic[T]): +class Scheduler(abc.ABC, Generic[T, A, D]): """ An interface abstracting functionalities of a scheduler. Implementers need only implement those methods annotated with @@ -126,7 +128,7 @@ def close(self) -> None: def submit( self, - app: AppDef, + app: A, cfg: T, workspace: Optional[str] = None, ) -> str: @@ -150,7 +152,7 @@ def submit( return self.schedule(dryrun_info) @abc.abstractmethod - def schedule(self, dryrun_info: AppDryRunInfo) -> str: + def schedule(self, dryrun_info: D) -> str: """ Same as ``submit`` except that it takes an ``AppDryRunInfo``. Implementers are encouraged to implement this method rather than @@ -166,7 +168,7 @@ def schedule(self, dryrun_info: AppDryRunInfo) -> str: raise NotImplementedError() - def submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo: + def submit_dryrun(self, app: A, cfg: T) -> D: """ Rather than submitting the request to run the app, returns the request object that would have been submitted to the underlying @@ -179,14 +181,16 @@ def submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo: resolved_cfg = self.run_opts().resolve(cfg) # pyre-fixme: _submit_dryrun takes Generic type for resolved_cfg dryrun_info = self._submit_dryrun(app, resolved_cfg) - for role in app.roles: - dryrun_info = role.pre_proc(self.backend, dryrun_info) - dryrun_info._app = app - dryrun_info._cfg = resolved_cfg + + if isinstance(app, AppDef): + for role in app.roles: + dryrun_info = role.pre_proc(self.backend, dryrun_info) + dryrun_info._app = app + dryrun_info._cfg = resolved_cfg return dryrun_info @abc.abstractmethod - def _submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo: + def _submit_dryrun(self, app: A, cfg: T) -> D: raise NotImplementedError() def run_opts(self) -> runopts: @@ -345,18 +349,19 @@ def _pre_build_validate(self, app: AppDef, scheduler: str, cfg: T) -> None: """ pass - def _validate(self, app: AppDef, scheduler: str, cfg: T) -> None: + def _validate(self, app: A, scheduler: str, cfg: T) -> None: """ Validates after workspace build whether application is consistent with the scheduler. Raises error if application is not compatible with scheduler """ - for role in app.roles: - if role.resource == NULL_RESOURCE: - raise ValueError( - f"No resource for role: {role.image}." - f" Did you forget to attach resource to the role" - ) + if isinstance(app, AppDef): + for role in app.roles: + if role.resource == NULL_RESOURCE: + raise ValueError( + f"No resource for role: {role.image}." + f" Did you forget to attach resource to the role" + ) def filter_regex(regex: str, data: Iterable[str]) -> Iterable[str]: diff --git a/torchx/schedulers/aws_batch_scheduler.py b/torchx/schedulers/aws_batch_scheduler.py index 9577e90fb..ecd5ce2c1 100644 --- a/torchx/schedulers/aws_batch_scheduler.py +++ b/torchx/schedulers/aws_batch_scheduler.py @@ -363,7 +363,9 @@ class AWSBatchOpts(TypedDict, total=False): execution_role_arn: Optional[str] -class AWSBatchScheduler(DockerWorkspaceMixin, Scheduler[AWSBatchOpts]): +class AWSBatchScheduler( + DockerWorkspaceMixin, Scheduler[AWSBatchOpts, AppDef, AppDryRunInfo[BatchJob]] +): """ AWSBatchScheduler is a TorchX scheduling interface to AWS Batch. diff --git a/torchx/schedulers/aws_sagemaker_scheduler.py b/torchx/schedulers/aws_sagemaker_scheduler.py index 1b6e0cbe1..a67509520 100644 --- a/torchx/schedulers/aws_sagemaker_scheduler.py +++ b/torchx/schedulers/aws_sagemaker_scheduler.py @@ -156,7 +156,10 @@ def _merge_ordered( return merged -class AWSSageMakerScheduler(DockerWorkspaceMixin, Scheduler[AWSSageMakerOpts]): # type: ignore[misc] +class AWSSageMakerScheduler( + DockerWorkspaceMixin, + Scheduler[AWSSageMakerOpts, AppDef, AppDryRunInfo[AWSSageMakerJob]], +): """ AWSSageMakerScheduler is a TorchX scheduling interface to AWS SageMaker. diff --git a/torchx/schedulers/docker_scheduler.py b/torchx/schedulers/docker_scheduler.py index 59e524e65..454f43f92 100644 --- a/torchx/schedulers/docker_scheduler.py +++ b/torchx/schedulers/docker_scheduler.py @@ -128,7 +128,9 @@ class DockerOpts(TypedDict, total=False): privileged: bool -class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]): +class DockerScheduler( + DockerWorkspaceMixin, Scheduler[DockerOpts, AppDef, AppDryRunInfo[DockerJob]] +): """ DockerScheduler is a TorchX scheduling interface to Docker. diff --git a/torchx/schedulers/gcp_batch_scheduler.py b/torchx/schedulers/gcp_batch_scheduler.py index f4d0ef09c..a8fdc99f9 100644 --- a/torchx/schedulers/gcp_batch_scheduler.py +++ b/torchx/schedulers/gcp_batch_scheduler.py @@ -104,7 +104,7 @@ class GCPBatchOpts(TypedDict, total=False): location: Optional[str] -class GCPBatchScheduler(Scheduler[GCPBatchOpts]): +class GCPBatchScheduler(Scheduler[GCPBatchOpts, AppDef, AppDryRunInfo[GCPBatchJob]]): """ GCPBatchScheduler is a TorchX scheduling interface to GCP Batch. diff --git a/torchx/schedulers/kubernetes_mcad_scheduler.py b/torchx/schedulers/kubernetes_mcad_scheduler.py index 467c6363e..53f1b5deb 100644 --- a/torchx/schedulers/kubernetes_mcad_scheduler.py +++ b/torchx/schedulers/kubernetes_mcad_scheduler.py @@ -796,7 +796,9 @@ class KubernetesMCADOpts(TypedDict, total=False): network: Optional[str] -class KubernetesMCADScheduler(DockerWorkspaceMixin, Scheduler[KubernetesMCADOpts]): +class KubernetesMCADScheduler( + DockerWorkspaceMixin, Scheduler[KubernetesMCADOpts, AppDef, AppDryRunInfo] +): """ KubernetesMCADScheduler is a TorchX scheduling interface to Kubernetes. diff --git a/torchx/schedulers/kubernetes_scheduler.py b/torchx/schedulers/kubernetes_scheduler.py index 97d57b8cc..699e0d500 100644 --- a/torchx/schedulers/kubernetes_scheduler.py +++ b/torchx/schedulers/kubernetes_scheduler.py @@ -472,7 +472,10 @@ class KubernetesOpts(TypedDict, total=False): priority_class: Optional[str] -class KubernetesScheduler(DockerWorkspaceMixin, Scheduler[KubernetesOpts]): +class KubernetesScheduler( + DockerWorkspaceMixin, + Scheduler[KubernetesOpts, AppDef, AppDryRunInfo[KubernetesJob]], +): """ KubernetesScheduler is a TorchX scheduling interface to Kubernetes. diff --git a/torchx/schedulers/local_scheduler.py b/torchx/schedulers/local_scheduler.py index 47a487dd6..9250ee72a 100644 --- a/torchx/schedulers/local_scheduler.py +++ b/torchx/schedulers/local_scheduler.py @@ -529,7 +529,7 @@ def _register_termination_signals() -> None: signal.signal(signal.SIGINT, _terminate_process_handler) -class LocalScheduler(Scheduler[LocalOpts]): +class LocalScheduler(Scheduler[LocalOpts, AppDef, AppDryRunInfo[PopenRequest]]): """ Schedules on localhost. Containers are modeled as processes and certain properties of the container that are either not relevant diff --git a/torchx/schedulers/lsf_scheduler.py b/torchx/schedulers/lsf_scheduler.py index fdc915431..0ff8b905c 100644 --- a/torchx/schedulers/lsf_scheduler.py +++ b/torchx/schedulers/lsf_scheduler.py @@ -395,7 +395,7 @@ def __repr__(self) -> str: {self.materialize()}""" -class LsfScheduler(Scheduler[LsfOpts]): +class LsfScheduler(Scheduler[LsfOpts, AppDef, AppDryRunInfo]): """ **Example: hello_world** diff --git a/torchx/schedulers/ray_scheduler.py b/torchx/schedulers/ray_scheduler.py index 47767a406..af7e1be76 100644 --- a/torchx/schedulers/ray_scheduler.py +++ b/torchx/schedulers/ray_scheduler.py @@ -114,7 +114,9 @@ class RayJob: requirements: Optional[str] = None actors: List[RayActor] = field(default_factory=list) - class RayScheduler(TmpDirWorkspaceMixin, Scheduler[RayOpts]): + class RayScheduler( + TmpDirWorkspaceMixin, Scheduler[RayOpts, AppDef, AppDryRunInfo[RayJob]] + ): """ RayScheduler is a TorchX scheduling interface to Ray. The job def workers will be launched as Ray actors diff --git a/torchx/schedulers/slurm_scheduler.py b/torchx/schedulers/slurm_scheduler.py index e89b2b063..fb9c76982 100644 --- a/torchx/schedulers/slurm_scheduler.py +++ b/torchx/schedulers/slurm_scheduler.py @@ -259,7 +259,9 @@ def __repr__(self) -> str: {self.materialize()}""" -class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]): +class SlurmScheduler( + DirWorkspaceMixin, Scheduler[SlurmOpts, AppDef, AppDryRunInfo[SlurmBatchRequest]] +): """ SlurmScheduler is a TorchX scheduling interface to slurm. TorchX expects that slurm CLI tools are locally installed and job accounting is enabled. diff --git a/torchx/schedulers/test/api_test.py b/torchx/schedulers/test/api_test.py index c45767c56..1f65dd6b5 100644 --- a/torchx/schedulers/test/api_test.py +++ b/torchx/schedulers/test/api_test.py @@ -35,10 +35,12 @@ from torchx.workspace.api import WorkspaceMixin T = TypeVar("T") +A = TypeVar("A") +D = TypeVar("D") class SchedulerTest(unittest.TestCase): - class MockScheduler(Scheduler[T], WorkspaceMixin[None]): + class MockScheduler(Scheduler[T, A, D], WorkspaceMixin[None]): def __init__(self, session_name: str) -> None: super().__init__("mock", session_name) @@ -151,7 +153,7 @@ def test_invalid_dryrun_cfg(self) -> None: def test_role_preproc_called(self) -> None: scheduler_mock = SchedulerTest.MockScheduler("test_session") - app_mock = MagicMock() + app_mock = AppDef(name="test") app_mock.roles = [MagicMock()] cfg = {"foo": "bar"} @@ -161,7 +163,7 @@ def test_role_preproc_called(self) -> None: def test_validate(self) -> None: scheduler_mock = SchedulerTest.MockScheduler("test_session") - app_mock = MagicMock() + app_mock = AppDef(name="test") app_mock.roles = [MagicMock()] app_mock.roles[0].resource = NULL_RESOURCE diff --git a/torchx/schedulers/test/aws_batch_scheduler_test.py b/torchx/schedulers/test/aws_batch_scheduler_test.py index 13bc2f126..f8773d081 100644 --- a/torchx/schedulers/test/aws_batch_scheduler_test.py +++ b/torchx/schedulers/test/aws_batch_scheduler_test.py @@ -157,6 +157,7 @@ def test_submit_dryrun_tags(self, _) -> None: def test_submit_dryrun_job_role_arn(self) -> None: cfg = AWSBatchOpts({"queue": "ignored_in_test", "job_role_arn": "fizzbuzz"}) info = create_scheduler("test").submit_dryrun(_test_app(), cfg) + # pyre-ignore[16] node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] self.assertEqual(1, len(node_groups)) self.assertEqual(cfg["job_role_arn"], node_groups[0]["container"]["jobRoleArn"]) @@ -166,6 +167,7 @@ def test_submit_dryrun_execution_role_arn(self) -> None: {"queue": "ignored_in_test", "execution_role_arn": "veryexecutive"} ) info = create_scheduler("test").submit_dryrun(_test_app(), cfg) + # pyre-ignore[16] node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] self.assertEqual(1, len(node_groups)) self.assertEqual( @@ -175,6 +177,7 @@ def test_submit_dryrun_execution_role_arn(self) -> None: def test_submit_dryrun_privileged(self) -> None: cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True}) info = create_scheduler("test").submit_dryrun(_test_app(), cfg) + # pyre-ignore[16] node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] self.assertEqual(1, len(node_groups)) self.assertTrue(node_groups[0]["container"]["privileged"]) @@ -184,6 +187,7 @@ def test_submit_dryrun_instance_type_multinode(self) -> None: resource = specs.named_resources_aws.aws_p3dn_24xlarge() app = _test_app(num_replicas=2, resource=resource) info = create_scheduler("test").submit_dryrun(app, cfg) + # pyre-ignore[16] node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] self.assertEqual(1, len(node_groups)) self.assertEqual( @@ -196,6 +200,7 @@ def test_submit_dryrun_no_instance_type_singlenode(self) -> None: resource = specs.named_resources_aws.aws_p3dn_24xlarge() app = _test_app(num_replicas=1, resource=resource) info = create_scheduler("test").submit_dryrun(app, cfg) + # pyre-ignore[16] node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] self.assertEqual(1, len(node_groups)) self.assertTrue("instanceType" not in node_groups[0]["container"]) @@ -205,6 +210,7 @@ def test_submit_dryrun_no_instance_type_non_aws(self) -> None: resource = specs.named_resources_aws.aws_p3dn_24xlarge() app = _test_app(num_replicas=2) info = create_scheduler("test").submit_dryrun(app, cfg) + # pyre-ignore[16] node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] self.assertEqual(1, len(node_groups)) self.assertTrue("instanceType" not in node_groups[0]["container"]) diff --git a/torchx/schedulers/test/kubernetes_scheduler_test.py b/torchx/schedulers/test/kubernetes_scheduler_test.py index 6adc5b51b..9492ca962 100644 --- a/torchx/schedulers/test/kubernetes_scheduler_test.py +++ b/torchx/schedulers/test/kubernetes_scheduler_test.py @@ -515,7 +515,7 @@ def test_rank0_env(self) -> None: make_unique_ctx.return_value = "app-name-42" info = scheduler.submit_dryrun(app, cfg) - tasks = info.request.resource["spec"]["tasks"] + tasks = info.request.resource["spec"]["tasks"] # pyre-ignore[16] container0 = tasks[0]["template"].spec.containers[0] self.assertIn("TORCHX_RANK0_HOST", container0.command) self.assertIn(