
Commit e324226

Version based GPU configuration and QoS addition
Summary: Slurm 24.11.0rc1 and later do not support GRES per task, so we need to pass `gpus-per-node` to sbatch to ensure failure-free allocation. See https://github.com/SchedMD/slurm/blob/master/CHANGELOG/slurm-24.11.md

Changes here:
1. Introduced a Slurm-version-based GPU request configuration.
2. Introduced an optional QoS parameter which can be used to control job priority.

See the usage sketch below.

Differential Revision: D78778304
1 parent 4adf7f6 commit e324226
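
A minimal usage sketch of the two new run options. This is illustrative only: it assumes this module's `create_scheduler` entry point and the standard `submit_dryrun` API from torchx, uses made-up resource numbers and paths, and needs a host where `sinfo` is available (the dryrun probes the partition and, with this change, the Slurm version).

    from torchx import specs
    from torchx.schedulers.slurm_scheduler import create_scheduler

    # Hypothetical single-role app; image path and resource numbers are made up.
    app = specs.AppDef(
        name="trainer",
        roles=[
            specs.Role(
                name="trainer",
                image="/fsx/code",
                entrypoint="train.py",
                num_replicas=2,
                resource=specs.Resource(cpu=16, gpu=8, memMB=64_000),
            )
        ],
    )

    scheduler = create_scheduler(session_name="demo")

    # "qos" is forwarded to sbatch as a job-level option; a "slurm_version"
    # newer than 24.11 switches the per-replica GPU request to gpus-per-node.
    cfg = {"partition": "gpu", "qos": "high", "slurm_version": "25.05"}
    dryrun_info = scheduler.submit_dryrun(app, cfg)
    print(dryrun_info.request)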

File tree

2 files changed: +392, -11 lines

torchx/schedulers/slurm_scheduler.py

Lines changed: 103 additions & 3 deletions
@@ -72,6 +72,45 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
     return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)


+def _parse_slurm_version(version_str: str) -> Tuple[int, int]:
+    """
+    Parse Slurm version string (e.g., '24.05', '25.11.2') into (major, minor) tuple.
+    Raises ValueError if parsing fails.
+    """
+    parts = version_str.split(".")
+    if len(parts) < 2:
+        raise ValueError(
+            f"Invalid Slurm version string: {version_str}. Expected format: '24.05' or '25.11.2'"
+        )
+
+    try:
+        major = int(parts[0])
+        minor = int(parts[1])
+    except (ValueError, IndexError) as err:
+        raise ValueError(
+            f"Invalid Slurm version string: {version_str}. Expected format: '24.05' or '25.11.2'"
+        ) from err
+
+    return (major, minor)
+
+
+def _should_use_gpus_per_node_from_version(version_str: Optional[str]) -> bool:
+    """
+    Determine whether to use gpus-per-node based on version string.
+    Returns True if version > 24.11, False otherwise or if version cannot be parsed.
+    """
+    if not version_str:
+        return False
+
+    try:
+        major, minor = _parse_slurm_version(version_str)
+    except ValueError:
+        return False
+
+    # Use gpus-per-node if version > 24.11
+    return major > 24 or (major == 24 and minor > 11)
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
@@ -81,6 +120,7 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
     "partition",
     "time",
     "constraint",
+    "qos",
 }

 log: logging.Logger = logging.getLogger(__name__)
@@ -106,6 +146,8 @@ def _apply_app_id_env(s: str) -> str:
         "mail-user": Optional[str],
         "mail-type": Optional[str],
         "job_dir": Optional[str],
+        "qos": Optional[str],
+        "slurm_version": Optional[str],
     },
     total=False,
 )
@@ -126,7 +168,11 @@ class SlurmReplicaRequest:

     @classmethod
     def from_role(
-        cls, name: str, role: Role, cfg: SlurmOpts, nomem: bool
+        cls,
+        name: str,
+        role: Role,
+        cfg: SlurmOpts,
+        nomem: bool,
     ) -> "SlurmReplicaRequest":
         """
         ``from_role`` creates a SlurmReplicaRequest for the specific role and
@@ -149,7 +195,12 @@ def from_role(
         if not nomem and resource.memMB > 0:
             sbatch_opts.setdefault("mem", str(resource.memMB))
         if resource.gpu > 0:
-            sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
+            # Use smart GPU allocation based on Slurm version from config
+            slurm_version = cfg.get("slurm_version")
+            if _should_use_gpus_per_node_from_version(slurm_version):
+                sbatch_opts.setdefault("gpus-per-node", str(resource.gpu))
+            else:
+                sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))

         srun_opts = {
             "output": f"slurm-{macros.app_id}-{name}.out",
@@ -378,6 +429,18 @@ def _run_opts(self) -> runopts:
             iteration, jobs will be tracked in ``.torchxslurmjobdirs``.
             """,
         )
+        opts.add(
+            "qos",
+            type_=str,
+            help="Quality of Service (QoS) to assign to the job.",
+        )
+        opts.add(
+            "slurm_version",
+            type_=str,
+            help="""Slurm version (e.g., '24.05', '25.11'). If version > 24.11,
+            uses gpus-per-node instead of gpus-per-task for GPU allocation.
+            """,
+        )
         return opts

     def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
@@ -401,6 +464,37 @@ def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:

         return job_id

+    def _get_slurm_version(self) -> str:
+        """
+        _get_slurm_version returns the Slurm version string (e.g., "24.05").
+        Raises ValueError if version cannot be determined.
+        """
+        try:
+            p = subprocess.run(
+                ["sinfo", "--version"],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+            )
+        except FileNotFoundError as e:
+            raise ValueError("Slurm is not available (sinfo command not found)") from e
+
+        if p.returncode != 0:
+            raise ValueError(
+                f"Failed to get Slurm version: {p.stderr.decode('utf-8').strip()}"
+            )
+
+        output = p.stdout.decode("utf-8").strip().lower()
+        if not output.startswith("slurm "):
+            raise ValueError(f"Unexpected sinfo --version output format: {output}")
+
+        # Remove "slurm " prefix and extract version (e.g., "24.05.4" -> "24.05")
+        version_full = output.replace("slurm", "").strip()
+        version_parts = version_full.split(".")
+        if len(version_parts) < 2:
+            raise ValueError(f"Invalid Slurm version format: {version_full}")
+
+        return f"{version_parts[0]}.{version_parts[1]}"
+
     def _partition_memmb(self, partition: Optional[str]) -> Optional[int]:
         """
         _partition_memmb returns the memory allocation for the given partition
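
The string handling above reduces a typical `sinfo --version` line to a major.minor pair, e.g.:

    # sinfo --version typically prints something like "slurm 24.05.4"
    output = "slurm 24.05.4".strip().lower()
    assert output.startswith("slurm ")
    version_full = output.replace("slurm", "").strip()   # "24.05.4"
    version_parts = version_full.split(".")
    print(f"{version_parts[0]}.{version_parts[1]}")      # -> 24.05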
@@ -441,6 +535,12 @@ def _submit_dryrun(
         partition = cfg.get("partition")
         assert partition is None or isinstance(partition, str), "partition must be str"

+        # Create a new config with the resolved slurm version
+        resolved_cfg = cfg.copy()
+        resolved_cfg["slurm_version"] = cfg.get(
+            "slurm_version", self._get_slurm_version()
+        )
+
         # check if the partition has at least 1GB memory, if we're not sure,
         # default to using memory allocations
         memmb = self._partition_memmb(partition)
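
The resolution order here is plain dict.get semantics: an explicit slurm_version in the run config wins; otherwise the value reported by sinfo --version is used. A small sketch, with "25.05" standing in for the auto-detected value:

    # Explicit value in the run config wins over the auto-detected one:
    cfg = {"slurm_version": "24.05"}
    assert cfg.get("slurm_version", "25.05") == "24.05"

    # No value configured: fall back to what sinfo --version reported:
    cfg = {}
    assert cfg.get("slurm_version", "25.05") == "25.05"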
@@ -460,7 +560,7 @@ def _submit_dryrun(
             replicas[name] = SlurmReplicaRequest.from_role(
                 name,
                 replica_role,
-                cfg,
+                resolved_cfg,
                 nomem=nomem,
             )
         cmd = ["sbatch", "--parsable"]
