@@ -72,6 +72,45 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
     return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
 
 
+def _parse_slurm_version(version_str: str) -> Tuple[int, int]:
+    """
+    Parse Slurm version string (e.g., '24.05', '25.11.2') into (major, minor) tuple.
+    Raises ValueError if parsing fails.
+    """
+    parts = version_str.split(".")
+    if len(parts) < 2:
+        raise ValueError(
+            f"Invalid Slurm version string: {version_str}. Expected format: '24.05' or '25.11.2'"
+        )
+
+    try:
+        major = int(parts[0])
+        minor = int(parts[1])
+    except (ValueError, IndexError) as err:
+        raise ValueError(
+            f"Invalid Slurm version string: {version_str}. Expected format: '24.05' or '25.11.2'"
+        ) from err
+
+    return (major, minor)
+
+
+def _should_use_gpus_per_node_from_version(version_str: Optional[str]) -> bool:
+    """
+    Determine whether to use gpus-per-node based on the version string.
+    Returns True if version > 24.11, False otherwise or if the version cannot be parsed.
+    """
+    if not version_str:
+        return False
+
+    try:
+        major, minor = _parse_slurm_version(version_str)
+    except ValueError:
+        return False
+
+    # Use gpus-per-node if version > 24.11
+    return major > 24 or (major == 24 and minor > 11)
+
+
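A quick sanity check of the version gate above (a minimal sketch; assumes the helpers are importable from torchx/schedulers/slurm_scheduler.py, the file this diff applies to):

```python
from torchx.schedulers.slurm_scheduler import (
    _parse_slurm_version,
    _should_use_gpus_per_node_from_version,
)

assert _parse_slurm_version("25.11.2") == (25, 11)
# 24.11 is the cutoff: only strictly newer versions switch flags.
assert not _should_use_gpus_per_node_from_version("24.11")
assert _should_use_gpus_per_node_from_version("24.12")
assert _should_use_gpus_per_node_from_version("25.05")
# None and unparseable strings fall back to the old gpus-per-task path.
assert not _should_use_gpus_per_node_from_version(None)
assert not _should_use_gpus_per_node_from_version("not-a-version")
```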
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
@@ -81,6 +120,7 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
     "partition",
     "time",
     "constraint",
+    "qos",
 }
 
 log: logging.Logger = logging.getLogger(__name__)
@@ -106,6 +146,8 @@ def _apply_app_id_env(s: str) -> str:
         "mail-user": Optional[str],
         "mail-type": Optional[str],
         "job_dir": Optional[str],
+        "qos": Optional[str],
+        "slurm_version": Optional[str],
     },
     total=False,
 )
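For reference, a cfg that exercises the two new keys might look like this (the partition and QoS names are illustrative placeholders, not defaults):

```python
cfg: SlurmOpts = {
    "partition": "gpu",       # placeholder partition name
    "qos": "high",            # placeholder QoS, emitted as an sbatch option
    "slurm_version": "25.05", # > 24.11, so GPU roles get gpus-per-node
}
```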
@@ -126,7 +168,11 @@ class SlurmReplicaRequest:
126168
127169 @classmethod
128170 def from_role (
129- cls , name : str , role : Role , cfg : SlurmOpts , nomem : bool
171+ cls ,
172+ name : str ,
173+ role : Role ,
174+ cfg : SlurmOpts ,
175+ nomem : bool ,
130176 ) -> "SlurmReplicaRequest" :
131177 """
132178 ``from_role`` creates a SlurmReplicaRequest for the specific role and
@@ -149,7 +195,12 @@ def from_role(
         if not nomem and resource.memMB > 0:
             sbatch_opts.setdefault("mem", str(resource.memMB))
         if resource.gpu > 0:
-            sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
+            # Pick the GPU allocation flag based on the Slurm version in the config
+            slurm_version = cfg.get("slurm_version")
+            if _should_use_gpus_per_node_from_version(slurm_version):
+                sbatch_opts.setdefault("gpus-per-node", str(resource.gpu))
+            else:
+                sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
 
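To make the branch above concrete: for a role requesting GPUs, the resolved version picks which flag is emitted (a sketch; `resource_gpu = 8` is an assumed example value):

```python
from torchx.schedulers.slurm_scheduler import _should_use_gpus_per_node_from_version

sbatch_opts = {}
resource_gpu = 8  # assumed example value
if _should_use_gpus_per_node_from_version("25.05"):
    sbatch_opts.setdefault("gpus-per-node", str(resource_gpu))
else:
    sbatch_opts.setdefault("gpus-per-task", str(resource_gpu))
assert sbatch_opts == {"gpus-per-node": "8"}  # "24.05" would yield gpus-per-task
```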
         srun_opts = {
             "output": f"slurm-{macros.app_id}-{name}.out",
@@ -378,6 +429,19 @@ def _run_opts(self) -> runopts:
             iteration, jobs will be tracked in ``.torchxslurmjobdirs``.
             """,
         )
+        opts.add(
+            "qos",
+            type_=str,
+            help="Quality of Service (QoS) to assign to the job.",
+        )
+        opts.add(
+            "slurm_version",
+            type_=str,
+            help="""Slurm version (e.g., '24.05', '25.11'). If version > 24.11,
+            uses gpus-per-node instead of gpus-per-task for GPU allocation.
+            If unset, the version is auto-detected via ``sinfo --version``.
+            """,
+        )
         return opts
 
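With these registered, both options should surface like any other scheduler cfg key, e.g. `torchx run -s slurm -cfg qos=high,slurm_version=25.05 ...` on the CLI. A hedged programmatic check (the session name passed to create_scheduler is an arbitrary label):

```python
from torchx.schedulers.slurm_scheduler import create_scheduler

scheduler = create_scheduler("example")
opts = scheduler.run_opts()
assert opts.get("qos") is not None
assert opts.get("slurm_version") is not None
```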
     def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
@@ -401,6 +464,37 @@ def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
 
         return job_id
 
+    def _get_slurm_version(self) -> str:
+        """
+        _get_slurm_version returns the Slurm version string (e.g., "24.05").
+        Raises ValueError if the version cannot be determined.
+        """
+        try:
+            p = subprocess.run(
+                ["sinfo", "--version"],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+            )
+        except FileNotFoundError as e:
+            raise ValueError("Slurm is not available (sinfo command not found)") from e
+
+        if p.returncode != 0:
+            raise ValueError(
+                f"Failed to get Slurm version: {p.stderr.decode('utf-8').strip()}"
+            )
+
+        output = p.stdout.decode("utf-8").strip().lower()
+        if not output.startswith("slurm "):
+            raise ValueError(f"Unexpected sinfo --version output format: {output}")
+
+        # Strip the "slurm " prefix and keep major.minor (e.g., "24.05.4" -> "24.05")
+        version_full = output[len("slurm"):].strip()
+        version_parts = version_full.split(".")
+        if len(version_parts) < 2:
+            raise ValueError(f"Invalid Slurm version format: {version_full}")
+
+        return f"{version_parts[0]}.{version_parts[1]}"
+
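A hedged sketch of how `_get_slurm_version` normalizes sinfo output, mocking subprocess so it runs off-cluster (the create_scheduler factory and the session name mirror the module's conventions and are assumptions here):

```python
import subprocess
from unittest import mock

from torchx.schedulers.slurm_scheduler import create_scheduler

scheduler = create_scheduler("example")
fake = subprocess.CompletedProcess(
    args=["sinfo", "--version"], returncode=0, stdout=b"slurm 24.05.4\n", stderr=b""
)
with mock.patch("subprocess.run", return_value=fake):
    # The patch-level component is dropped: "slurm 24.05.4" -> "24.05"
    assert scheduler._get_slurm_version() == "24.05"
```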
     def _partition_memmb(self, partition: Optional[str]) -> Optional[int]:
         """
         _partition_memmb returns the memory allocation for the given partition
@@ -441,6 +535,16 @@ def _submit_dryrun(
         partition = cfg.get("partition")
         assert partition is None or isinstance(partition, str), "partition must be str"
 
+        # Resolve the Slurm version once per dryrun. Resolve lazily so that an
+        # explicitly configured version never shells out to sinfo, and fall
+        # back to None (the old gpus-per-task behavior) if detection fails.
+        resolved_cfg = cfg.copy()
+        if not resolved_cfg.get("slurm_version"):
+            try:
+                resolved_cfg["slurm_version"] = self._get_slurm_version()
+            except ValueError:
+                resolved_cfg["slurm_version"] = None
+
 # check if the partition has at least 1GB memory, if we're not sure,
 # default to using memory allocations
 memmb = self._partition_memmb(partition)
@@ -460,7 +560,7 @@ def _submit_dryrun(
             replicas[name] = SlurmReplicaRequest.from_role(
                 name,
                 replica_role,
-                cfg,
+                resolved_cfg,
                 nomem=nomem,
             )
         cmd = ["sbatch", "--parsable"]
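Net effect of the dryrun change: an explicitly configured slurm_version wins, otherwise sinfo is queried once, and detection failures degrade to the old gpus-per-task behavior. A tiny sketch of that precedence (resolve_version is a hypothetical helper mirroring the inline logic, not part of the module):

```python
from typing import Optional

def resolve_version(cfg_version: Optional[str], detected: Optional[str]) -> Optional[str]:
    # Explicit cfg wins; otherwise fall back to what sinfo reported (or None).
    return cfg_version if cfg_version else detected

assert resolve_version("24.05", "25.05") == "24.05"  # explicit cfg wins
assert resolve_version(None, "25.05") == "25.05"     # auto-detected via sinfo
assert resolve_version(None, None) is None           # unknown => gpus-per-task
```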