@@ -293,6 +293,7 @@ class ResourceRequest:
293293 env_vars : dict [str , str ] = field (default_factory = dict )
294294 srun_args : Optional [list [str ]] = None
295295 job_details : SlurmJobDetails = field (default_factory = SlurmJobDetails )
296+ het_group_index : Optional [int ] = None
296297
297298 account : str
298299 partition : Optional [str ] = None
@@ -334,6 +335,7 @@ class ResourceRequest:
334335 monitor_group_job : bool = True
335336 monitor_group_job_wait_time : int = 60
336337 setup_lines : Optional [str ] = None
338+ het_group_indices : Optional [list [int ]] = None
337339
338340 #: Set by the executor; cannot be initialized
339341 job_name : str = field (init = False , default = "nemo-job" )
@@ -355,6 +357,21 @@ def merge(
355357
356358 main_executor = executors [0 ]
357359 main_executor .run_as_group = True
360+
361+ if main_executor .het_group_indices :
362+ assert (
363+ main_executor .heterogeneous
364+ ), "heterogeneous must be True if het_group_indices is provided"
365+ assert (
366+ len (main_executor .het_group_indices ) == num_tasks
367+ ), "het_group_indices must be the same length as the number of tasks"
368+ assert all (
369+ x <= y
370+ for x , y in zip (
371+ main_executor .het_group_indices , main_executor .het_group_indices [1 :]
372+ )
373+ ), "het_group_indices must be equal or increasing than previous"
374+
358375 main_executor .resource_group = [
359376 cls .ResourceRequest (
360377 packager = copy .deepcopy (main_executor .packager ),
@@ -367,10 +384,13 @@ def merge(
367384 gpus_per_task = main_executor .gpus_per_task ,
368385 srun_args = main_executor .srun_args ,
369386 job_details = copy .deepcopy (main_executor .job_details ),
387+ het_group_index = main_executor .het_group_indices [0 ]
388+ if main_executor .het_group_indices
389+ else None ,
370390 )
371391 ]
372392
373- for executor in executors [1 :]:
393+ for i , executor in enumerate ( executors [1 :]) :
374394 main_executor .resource_group .append (
375395 cls .ResourceRequest (
376396 packager = copy .deepcopy (executor .packager ),
@@ -383,6 +403,9 @@ def merge(
383403 gpus_per_task = executor .gpus_per_task ,
384404 srun_args = executor .srun_args ,
385405 job_details = copy .deepcopy (executor .job_details ),
406+ het_group_index = main_executor .het_group_indices [i + 1 ]
407+ if main_executor .het_group_indices
408+ else None ,
386409 )
387410 )
388411
@@ -803,8 +826,25 @@ def materialize(self) -> str:
803826 sbatch_flags = []
804827 if self .slurm_config .heterogeneous :
805828 assert len (self .jobs ) == len (self .slurm_config .resource_group )
829+ final_group_index = len (self .slurm_config .resource_group ) - 1
830+ if self .slurm_config .het_group_indices :
831+ final_group_index = self .slurm_config .het_group_indices .index (
832+ max (self .slurm_config .het_group_indices )
833+ )
834+
806835 for i in range (len (self .slurm_config .resource_group )):
807836 resource_req = self .slurm_config .resource_group [i ]
837+ if resource_req .het_group_index :
838+ assert (
839+ self .slurm_config .resource_group [i - 1 ].het_group_index is not None
840+ ), "het_group_index must be set for all requests in resource_group"
841+ if (
842+ i > 0
843+ and resource_req .het_group_index
844+ == self .slurm_config .resource_group [i - 1 ].het_group_index
845+ ):
846+ continue
847+
808848 het_parameters = parameters .copy ()
809849 het_parameters ["output" ] = parameters ["output" ].replace (
810850 original_job_name , self .jobs [i ]
@@ -824,7 +864,7 @@ def materialize(self) -> str:
824864 )
825865 for k in sorted (parameters ):
826866 sbatch_flags .append (_as_sbatch_flag (k , het_parameters [k ]))
827- if i != len ( self . slurm_config . resource_group ) - 1 :
867+ if i != final_group_index :
828868 sbatch_flags .append ("#SBATCH hetjob" )
829869 else :
830870 for k in sorted (parameters ):
@@ -934,7 +974,12 @@ def get_container_flags(
934974 _srun_args .extend (self .slurm_config .srun_args or [])
935975
936976 if self .slurm_config .run_as_group and self .slurm_config .heterogeneous :
937- het_group_flag = [f"--het-group={ group_ind } " ]
977+ het_group_index = (
978+ self .slurm_config .resource_group [group_ind ].het_group_index
979+ if self .slurm_config .resource_group [group_ind ].het_group_index is not None
980+ else group_ind
981+ )
982+ het_group_flag = [f"--het-group={ het_group_index } " ]
938983 else :
939984 het_group_flag = []
940985
0 commit comments