diff --git a/torchx/schedulers/aws_batch_scheduler.py b/torchx/schedulers/aws_batch_scheduler.py index 76e285539..eab0490cb 100644 --- a/torchx/schedulers/aws_batch_scheduler.py +++ b/torchx/schedulers/aws_batch_scheduler.py @@ -255,7 +255,7 @@ def _role_to_node_properties( container["jobRoleArn"] = job_role_arn if execution_role_arn: container["executionRoleArn"] = execution_role_arn - if role.num_replicas > 1: + if role.num_replicas > 0: instance_type = instance_type_from_resource(role.resource) if instance_type is not None: container["instanceType"] = instance_type diff --git a/torchx/schedulers/test/aws_batch_scheduler_test.py b/torchx/schedulers/test/aws_batch_scheduler_test.py index f8773d081..c2a5f65f6 100644 --- a/torchx/schedulers/test/aws_batch_scheduler_test.py +++ b/torchx/schedulers/test/aws_batch_scheduler_test.py @@ -195,7 +195,7 @@ def test_submit_dryrun_instance_type_multinode(self) -> None: node_groups[0]["container"]["instanceType"], ) - def test_submit_dryrun_no_instance_type_singlenode(self) -> None: + def test_submit_dryrun_instance_type_singlenode(self) -> None: cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True}) resource = specs.named_resources_aws.aws_p3dn_24xlarge() app = _test_app(num_replicas=1, resource=resource) @@ -203,7 +203,7 @@ def test_submit_dryrun_no_instance_type_singlenode(self) -> None: # pyre-ignore[16] node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] self.assertEqual(1, len(node_groups)) - self.assertTrue("instanceType" not in node_groups[0]["container"]) + self.assertTrue("instanceType" in node_groups[0]["container"]) def test_submit_dryrun_no_instance_type_non_aws(self) -> None: cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True})