diff --git a/torchx/schedulers/slurm_scheduler.py b/torchx/schedulers/slurm_scheduler.py
index 7d1e1833d..f74dd2448 100644
--- a/torchx/schedulers/slurm_scheduler.py
+++ b/torchx/schedulers/slurm_scheduler.py
@@ -18,6 +18,7 @@
 import shlex
 import subprocess
 import tempfile
+import warnings
 from dataclasses import dataclass
 from datetime import datetime
 from subprocess import CalledProcessError, PIPE
@@ -72,6 +73,55 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
     return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)


+def version() -> Tuple[int, int]:
+    """
+    Uses ``sinfo --version`` to get the slurm version. If the command fails, it
+    assumes the version is ``slurm 24.05.8``.
+
+    Returns:
+        Tuple[int, int]: the slurm version as a tuple of ints
+        ``(major, minor)``.
+    """
+
+    cmd = ["sinfo", "--version"]
+    try:
+        out = subprocess.check_output(cmd, stderr=PIPE, encoding="utf-8")
+    except (CalledProcessError, FileNotFoundError):
+        out = "slurm 24.05.8"
+        warnings.warn(
+            f"Error running: `{' '.join(cmd)}` to get SLURM version. Are you running outside the "
+            "cluster's login or head node? This typically happens when running in `--dryrun`"
+            " mode. Assuming version is `slurm 24.05.8`.",
+            RuntimeWarning,
+            stacklevel=2,
+        )
+
+    # sinfo --version returns output in the form "slurm 24.05.8"
+    _, version_literal = out.split(" ", maxsplit=2)
+    major, minor = [int(v) for v in version_literal.split(".")][:2]
+
+    return (major, minor)
+
+
+def _should_use_gpus_per_node_from_version() -> bool:
+    """
+    Determine whether to use gpus-per-node based on the automatically detected slurm version.
+
+    Change Reference: https://fburl.com/sqwqzxn6
+    > select/linear - Reject jobs asking for GRES per job|socket|task or cpus|mem per GRES.
+
+    Returns:
+        ``True`` if the slurm version is >= 24.11, ``False`` otherwise.
+    """
+
+    slurm_24_11_0 = (24, 11)
+    slurm_version = version()
+
+    return slurm_version[0] > slurm_24_11_0[0] or (  # Major version is greater
+        slurm_version[0] == slurm_24_11_0[0] and slurm_version[1] >= slurm_24_11_0[1]
+    )  # Major version is equal and minor version is greater or equal
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
@@ -81,6 +131,7 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
     "partition",
     "time",
     "constraint",
+    "qos",
 }

 log: logging.Logger = logging.getLogger(__name__)
@@ -106,6 +157,7 @@ def _apply_app_id_env(s: str) -> str:
         "mail-user": Optional[str],
         "mail-type": Optional[str],
         "job_dir": Optional[str],
+        "qos": Optional[str],
     },
     total=False,
 )
@@ -126,7 +178,11 @@ class SlurmReplicaRequest:

     @classmethod
     def from_role(
-        cls, name: str, role: Role, cfg: SlurmOpts, nomem: bool
+        cls,
+        name: str,
+        role: Role,
+        cfg: SlurmOpts,
+        nomem: bool,
     ) -> "SlurmReplicaRequest":
         """
         ``from_role`` creates a SlurmReplicaRequest for the specific role and
@@ -149,7 +205,11 @@ def from_role(
         if not nomem and resource.memMB > 0:
             sbatch_opts.setdefault("mem", str(resource.memMB))
         if resource.gpu > 0:
-            sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
+            # Use smart GPU allocation based on the automatically detected Slurm version
+            if _should_use_gpus_per_node_from_version():
+                sbatch_opts.setdefault("gpus-per-node", str(resource.gpu))
+            else:
+                sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))

         srun_opts = {
             "output": f"slurm-{macros.app_id}-{name}.out",
@@ -378,6 +438,11 @@ def _run_opts(self) -> runopts:
         iteration, jobs will be tracked in ``.torchxslurmjobdirs``.
""", ) + opts.add( + "qos", + type_=str, + help="Quality of Service (QoS) to assign to the job.", + ) return opts def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str: diff --git a/torchx/schedulers/test/slurm_scheduler_test.py b/torchx/schedulers/test/slurm_scheduler_test.py index 23be9d674..480f02bc8 100644 --- a/torchx/schedulers/test/slurm_scheduler_test.py +++ b/torchx/schedulers/test/slurm_scheduler_test.py @@ -30,6 +30,11 @@ ) from torchx.specs import AppState +# Constants for version mocking to help with Pyre type inference +SLURM_VERSION_24_5 = (24, 5) +SLURM_VERSION_25_0 = (25, 0) + + DESCRIBE_SQUEUE = "torchx.schedulers.slurm_scheduler.SlurmScheduler._describe_squeue" DESCRIBE_SACCT = "torchx.schedulers.slurm_scheduler.SlurmScheduler._describe_sacct" @@ -105,7 +110,11 @@ def test_create_scheduler(self) -> None: scheduler = create_scheduler("foo") self.assertIsInstance(scheduler, SlurmScheduler) - def test_replica_request(self) -> None: + @patch( + "torchx.schedulers.slurm_scheduler.version", + return_value=SLURM_VERSION_24_5, + ) + def test_replica_request(self, mock_version: MagicMock) -> None: role = simple_role() sbatch, srun = SlurmReplicaRequest.from_role( "role-0", role, cfg={}, nomem=False @@ -135,7 +144,11 @@ def test_replica_request(self) -> None: ], ) - def test_replica_request_nomem(self) -> None: + @patch( + "torchx.schedulers.slurm_scheduler.version", + return_value=SLURM_VERSION_24_5, + ) + def test_replica_request_nomem(self, mock_version: MagicMock) -> None: sbatch, srun = SlurmReplicaRequest.from_role( "role-name", simple_role(), @@ -153,7 +166,11 @@ def test_replica_request_nomem(self) -> None: ], ) - def test_replica_request_constraint(self) -> None: + @patch( + "torchx.schedulers.slurm_scheduler.version", + return_value=SLURM_VERSION_24_5, + ) + def test_replica_request_constraint(self, mock_version: MagicMock) -> None: sbatch, srun = SlurmReplicaRequest.from_role( "role-name", simple_role(), @@ -208,7 +225,11 @@ def test_replica_request_run_config(self) -> None: sbatch, ) - def test_dryrun_multi_role(self) -> None: + @patch( + "torchx.schedulers.slurm_scheduler.version", + return_value=SLURM_VERSION_24_5, + ) + def test_dryrun_multi_role(self, mock_version: MagicMock) -> None: scheduler = create_scheduler("foo") app = simple_app() info = scheduler.submit_dryrun(app, cfg={}) @@ -262,8 +283,17 @@ def test_dryrun_multi_role(self) -> None: "torchx.schedulers.slurm_scheduler.SlurmScheduler._partition_memmb", return_value=2048, ) + @patch( + "torchx.schedulers.slurm_scheduler.version", + return_value=SLURM_VERSION_24_5, + ) @patch("subprocess.run") - def test_run_multi_role(self, run: MagicMock, partition_memmb: MagicMock) -> None: + def test_run_multi_role( + self, + run: MagicMock, + mock_version: MagicMock, + partition_memmb: MagicMock, + ) -> None: run.return_value.stdout = b"1234" scheduler = create_scheduler("foo") app = specs.AppDef( @@ -422,9 +452,10 @@ def test_describe_sacct_running( self.assertEqual(out.state, specs.AppState.RUNNING) def test_describe_squeue(self) -> None: - with importlib.resources.path( - __package__, "slurm-squeue-output.json" - ) as path, open(path) as fp: + with ( + importlib.resources.path(__package__, "slurm-squeue-output.json") as path, + open(path) as fp, + ): mock_output = fp.read() with patch("subprocess.check_output", return_value=mock_output): @@ -615,8 +646,12 @@ def test_log_iter(self, _1: MagicMock, _2: MagicMock) -> None: with self.assertRaises(ValueError): scheduler.log_iter("54", "echo", 1, 

+    @patch(
+        "torchx.schedulers.slurm_scheduler.version",
+        return_value=SLURM_VERSION_24_5,
+    )
     @patch("subprocess.run")
-    def test_dryrun_nomem(self, run: MagicMock) -> None:
+    def test_dryrun_nomem(self, run: MagicMock, mock_version: MagicMock) -> None:
         run.return_value.returncode = 0
         scheduler = create_scheduler("foo")

@@ -634,7 +669,11 @@ def test_dryrun_nomem(self, run: MagicMock) -> None:
         info = scheduler.submit_dryrun(app, cfg={})
         self.assertIn("mem", info.request.replicas["foo-0"].sbatch_opts)

-    def test_dryrun_comment(self) -> None:
+    @patch(
+        "torchx.schedulers.slurm_scheduler.version",
+        return_value=SLURM_VERSION_24_5,
+    )
+    def test_dryrun_comment(self, mock_version: MagicMock) -> None:
         scheduler = create_scheduler("foo")
         app = simple_app()
         info = scheduler.submit_dryrun(
@@ -648,7 +687,11 @@ def test_dryrun_comment(self) -> None:
             info.request.cmd,
         )

-    def test_dryrun_mail(self) -> None:
+    @patch(
+        "torchx.schedulers.slurm_scheduler.version",
+        return_value=SLURM_VERSION_24_5,
+    )
+    def test_dryrun_mail(self, mock_version: MagicMock) -> None:
         scheduler = create_scheduler("foo")
         app = simple_app()
         info = scheduler.submit_dryrun(
@@ -671,9 +714,16 @@ def test_dryrun_mail(self) -> None:
         "torchx.schedulers.slurm_scheduler.SlurmScheduler._partition_memmb",
         return_value=2048,
     )
+    @patch(
+        "torchx.schedulers.slurm_scheduler.version",
+        return_value=SLURM_VERSION_24_5,
+    )
     @patch("subprocess.run")
     def test_run_workspace_job_dir(
-        self, run: MagicMock, partition_memmb: MagicMock
+        self,
+        run: MagicMock,
+        mock_version: MagicMock,
+        partition_memmb: MagicMock,
     ) -> None:
         with tmp_cwd():
             run.return_value.stdout = b"1234"
@@ -748,13 +798,21 @@ def _run_req(
         )
         return os.WEXITSTATUS(os.system("bash test.sh"))

-    def test_sbatch(self) -> None:
+    @patch(
+        "torchx.schedulers.slurm_scheduler.version",
+        return_value=SLURM_VERSION_24_5,
+    )
+    def test_sbatch(self, mock_version: MagicMock) -> None:
         scheduler = create_scheduler("foo")
         app = simple_app()
         info = scheduler.submit_dryrun(app, cfg={})
         self.assertEqual(self._run_req(info.request, srun_exit=0, scontrol_exit=1), 0)

-    def test_sbatch_requeue(self) -> None:
+    @patch(
+        "torchx.schedulers.slurm_scheduler.version",
+        return_value=SLURM_VERSION_24_5,
+    )
+    def test_sbatch_requeue(self, mock_version: MagicMock) -> None:
         scheduler = create_scheduler("foo")
         app = simple_app()
         info = scheduler.submit_dryrun(app, cfg={})
@@ -768,3 +826,223 @@ def test_sbatch_requeue(self) -> None:
         )
         os.environ["SLURM_RESTART_COUNT"] = "3"
         self.assertEqual(self._run_req(info.request, srun_exit=1, scontrol_exit=123), 1)
+
+    @patch(
+        "torchx.schedulers.slurm_scheduler.version",
+        return_value=SLURM_VERSION_24_5,
+    )
+    def test_replica_request_qos(self, mock_version: MagicMock) -> None:
+        sbatch, srun = SlurmReplicaRequest.from_role(
+            "role-name",
+            simple_role(),
+            cfg={"qos": "high"},
+            nomem=False,
+        ).materialize()
+        self.assertIn(
+            "--qos=high",
+            sbatch,
+        )
+
+    @patch(
+        "torchx.schedulers.slurm_scheduler.version",
+        return_value=SLURM_VERSION_24_5,
+    )
+    def test_dryrun_qos(self, mock_version: MagicMock) -> None:
+        scheduler = create_scheduler("foo")
+        app = simple_app()
+        info = scheduler.submit_dryrun(
+            app,
+            cfg={
+                "qos": "high",
+            },
+        )
+        # QoS should be in the sbatch options for each replica
+        for replica in info.request.replicas.values():
+            self.assertIn("qos", replica.sbatch_opts)
+            self.assertEqual(replica.sbatch_opts["qos"], "high")
+
+    def test_should_use_gpus_per_node_from_version(self) -> None:
+        from torchx.schedulers.slurm_scheduler import (
+            _should_use_gpus_per_node_from_version,
+        )
+
+        # Test versions >= 24.11 (should use gpus-per-node)
+        with patch(
+            "torchx.schedulers.slurm_scheduler.version",
+            return_value=SLURM_VERSION_25_0,
+        ):
+            self.assertTrue(_should_use_gpus_per_node_from_version())
+
+        slurm_version_24_12 = (24, 12)
+        with patch(
+            "torchx.schedulers.slurm_scheduler.version",
+            return_value=slurm_version_24_12,
+        ):
+            self.assertTrue(_should_use_gpus_per_node_from_version())
+
+        slurm_version_25_11 = (25, 11)
+        with patch(
+            "torchx.schedulers.slurm_scheduler.version",
+            return_value=slurm_version_25_11,
+        ):
+            self.assertTrue(_should_use_gpus_per_node_from_version())
+
+        slurm_version_24_11 = (24, 11)
+        with patch(
+            "torchx.schedulers.slurm_scheduler.version",
+            return_value=slurm_version_24_11,
+        ):
+            self.assertTrue(_should_use_gpus_per_node_from_version())
+
+        # Test versions < 24.11 (should use gpus-per-task)
+        with patch(
+            "torchx.schedulers.slurm_scheduler.version",
+            return_value=SLURM_VERSION_24_5,
+        ):
+            self.assertFalse(_should_use_gpus_per_node_from_version())
+
+        slurm_version_23_15 = (23, 15)
+        with patch(
+            "torchx.schedulers.slurm_scheduler.version",
+            return_value=slurm_version_23_15,
+        ):
+            self.assertFalse(_should_use_gpus_per_node_from_version())
+
+    def test_smart_gpu_allocation_with_version_config(self) -> None:
+        role = simple_role()
+
+        # Test gpus-per-node allocation (newer Slurm version)
+        with patch(
+            "torchx.schedulers.slurm_scheduler.version",
+            return_value=SLURM_VERSION_25_0,
+        ):
+            sbatch, srun = SlurmReplicaRequest.from_role(
+                "role-name",
+                role,
+                cfg={},
+                nomem=False,
+            ).materialize()
+            self.assertIn("--gpus-per-node=3", sbatch)
+            self.assertNotIn("--gpus-per-task=3", sbatch)
+
+        # Test gpus-per-task allocation (older Slurm version)
+        with patch(
+            "torchx.schedulers.slurm_scheduler.version",
+            return_value=SLURM_VERSION_24_5,
+        ):
+            sbatch, srun = SlurmReplicaRequest.from_role(
+                "role-name",
+                role,
+                cfg={},
+                nomem=False,
+            ).materialize()
+            self.assertIn("--gpus-per-task=3", sbatch)
+            self.assertNotIn("--gpus-per-node=3", sbatch)
+
+    def test_dryrun_smart_gpu_allocation_with_auto_detection(self) -> None:
+        scheduler = create_scheduler("foo")
+        app = mem_app()  # This app has GPU resources
+
+        # Test gpus-per-node allocation (newer Slurm version)
+        with patch(
+            "torchx.schedulers.slurm_scheduler.version",
+            return_value=SLURM_VERSION_25_0,
+        ):
+            info = scheduler.submit_dryrun(app, cfg={})
+            for replica in info.request.replicas.values():
+                self.assertIn("gpus-per-node", replica.sbatch_opts)
+                self.assertNotIn("gpus-per-task", replica.sbatch_opts)
+                self.assertEqual(replica.sbatch_opts["gpus-per-node"], "3")
+
+        # Test gpus-per-task allocation (older Slurm version)
+        with patch(
+            "torchx.schedulers.slurm_scheduler.version",
+            return_value=SLURM_VERSION_24_5,
+        ):
+            info = scheduler.submit_dryrun(app, cfg={})
+            for replica in info.request.replicas.values():
+                self.assertIn("gpus-per-task", replica.sbatch_opts)
+                self.assertNotIn("gpus-per-node", replica.sbatch_opts)
+                self.assertEqual(replica.sbatch_opts["gpus-per-task"], "3")
+
+    def test_qos_run_opts(self) -> None:
+        scheduler = create_scheduler("foo")
+        run_opts = scheduler.run_opts()
+        qos_opt = run_opts.get("qos")
+        self.assertIsNotNone(qos_opt)
+        self.assertEqual(qos_opt.opt_type, str)
+        self.assertIn("Quality of Service", qos_opt.help)
+
+    @patch(
+        "torchx.schedulers.slurm_scheduler.version",
+        return_value=SLURM_VERSION_24_5,
+    )
+    def test_replica_request_qos_and_constraint(self, mock_version: MagicMock) -> None:
+        # Test that QoS and constraint can be used together
+        sbatch, srun = SlurmReplicaRequest.from_role(
+            "role-name",
+            simple_role(),
+            cfg={"qos": "high", "constraint": "gpu"},
+            nomem=False,
+        ).materialize()
+        self.assertIn("--qos=high", sbatch)
+        self.assertIn("--constraint=gpu", sbatch)
+
+    @patch("subprocess.check_output")
+    def test_version(self, check_output: MagicMock) -> None:
+        from torchx.schedulers.slurm_scheduler import version
+
+        # Test successful version parsing
+        check_output.return_value = "slurm 24.05.4"
+        ver = version()
+        self.assertEqual(ver, (24, 5))
+
+        # Test newer version
+        check_output.return_value = "slurm 25.11.2"
+        ver = version()
+        self.assertEqual(ver, (25, 11))
+
+        # Test command failure - should fall back to the default slurm version 24.05.8
+        check_output.side_effect = subprocess.CalledProcessError(
+            returncode=1, cmd=["sinfo", "--version"], stderr="Command failed"
+        )
+        ver = version()
+        self.assertEqual(ver, (24, 5))
+
+    def test_no_gpu_resources(self) -> None:
+        # Test that GPU allocation logic doesn't interfere when no GPUs are requested
+        role = specs.Role(
+            name="no_gpu",
+            image="/some/path",
+            entrypoint="echo",
+            args=["hello"],
+            resource=specs.Resource(cpu=2, memMB=1024, gpu=0),  # No GPUs
+        )
+
+        # Test with newer Slurm version - should not add any GPU options
+        with patch(
+            "torchx.schedulers.slurm_scheduler.version",
+            return_value=SLURM_VERSION_25_0,
+        ):
+            sbatch, srun = SlurmReplicaRequest.from_role(
+                "role-name",
+                role,
+                cfg={},
+                nomem=False,
+            ).materialize()
+            self.assertNotIn("--gpus-per-node", " ".join(sbatch))
+            self.assertNotIn("--gpus-per-task", " ".join(sbatch))
+
+        # Test with older Slurm version - should not add any GPU options
+        with patch(
+            "torchx.schedulers.slurm_scheduler.version",
+            return_value=SLURM_VERSION_24_5,
+        ):
+            sbatch, srun = SlurmReplicaRequest.from_role(
+                "role-name",
+                role,
+                cfg={},
+                nomem=False,
+            ).materialize()
+            self.assertNotIn("--gpus-per-node", " ".join(sbatch))
+            self.assertNotIn("--gpus-per-task", " ".join(sbatch))
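
Usage note (not part of the patch): a minimal sketch of how the new ``qos`` option and the version-gated GPU flag interact when submitting through the Python API, mirroring the dryrun tests above. ``simple_app()`` stands in for any AppDef with GPU resources and is assumed to come from the test module; ``create_scheduler``, ``submit_dryrun``, ``request.replicas``, and ``sbatch_opts`` are used exactly as in the tests.

    # Illustrative sketch only, not part of the patch.
    from unittest.mock import patch

    from torchx.schedulers.slurm_scheduler import create_scheduler

    scheduler = create_scheduler("demo")
    app = simple_app()  # assumed test helper; any specs.AppDef that requests GPUs works

    # Pretend the cluster reports Slurm 25.0 so the >=24.11 branch picks gpus-per-node.
    with patch("torchx.schedulers.slurm_scheduler.version", return_value=(25, 0)):
        info = scheduler.submit_dryrun(app, cfg={"qos": "high"})

    for name, replica in info.request.replicas.items():
        # Expect "qos" == "high" and "gpus-per-node" (not "gpus-per-task") per replica.
        print(name, replica.sbatch_opts.get("qos"), replica.sbatch_opts.get("gpus-per-node"))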