@@ -293,6 +293,7 @@ class ResourceRequest:
293293 env_vars : dict [str , str ] = field (default_factory = dict )
294294 srun_args : Optional [list [str ]] = None
295295 job_details : SlurmJobDetails = field (default_factory = SlurmJobDetails )
296+ het_group_index : Optional [int ] = None
296297
297298 account : str
298299 partition : Optional [str ] = None
@@ -334,6 +335,7 @@ class ResourceRequest:
334335 monitor_group_job : bool = True
335336 monitor_group_job_wait_time : int = 60
336337 setup_lines : Optional [str ] = None
338+ het_group_indices : Optional [list [int ]] = None
337339
338340 #: Set by the executor; cannot be initialized
339341 job_name : str = field (init = False , default = "nemo-job" )
@@ -355,6 +357,21 @@ def merge(
355357
356358 main_executor = executors [0 ]
357359 main_executor .run_as_group = True
360+
361+ if main_executor .het_group_indices :
362+ assert (
363+ main_executor .heterogeneous
364+ ), "heterogeneous must be True if het_group_indices is provided"
365+ assert (
366+ len (main_executor .het_group_indices ) == num_tasks
367+ ), "het_group_indices must be the same length as the number of tasks"
368+ assert all (
369+ x <= y
370+ for x , y in zip (
371+ main_executor .het_group_indices , main_executor .het_group_indices [1 :]
372+ )
373+ ), "het_group_indices must be equal or increasing than previous"
374+
358375 main_executor .resource_group = [
359376 cls .ResourceRequest (
360377 packager = copy .deepcopy (main_executor .packager ),
@@ -367,10 +384,13 @@ def merge(
367384 gpus_per_task = main_executor .gpus_per_task ,
368385 srun_args = main_executor .srun_args ,
369386 job_details = copy .deepcopy (main_executor .job_details ),
387+ het_group_index = main_executor .het_group_indices [0 ]
388+ if main_executor .het_group_indices
389+ else None ,
370390 )
371391 ]
372392
373- for executor in executors [1 :]:
393+ for i , executor in enumerate ( executors [1 :]) :
374394 main_executor .resource_group .append (
375395 cls .ResourceRequest (
376396 packager = copy .deepcopy (executor .packager ),
@@ -383,6 +403,9 @@ def merge(
383403 gpus_per_task = executor .gpus_per_task ,
384404 srun_args = executor .srun_args ,
385405 job_details = copy .deepcopy (executor .job_details ),
406+ het_group_index = main_executor .het_group_indices [i + 1 ]
407+ if main_executor .het_group_indices
408+ else None ,
386409 )
387410 )
388411
@@ -805,6 +828,17 @@ def materialize(self) -> str:
805828 assert len (self .jobs ) == len (self .slurm_config .resource_group )
806829 for i in range (len (self .slurm_config .resource_group )):
807830 resource_req = self .slurm_config .resource_group [i ]
831+ if resource_req .het_group_index :
832+ assert (
833+ self .slurm_config .resource_group [i - 1 ].het_group_index is not None
834+ ), "het_group_index must be set for all requests in resource_group"
835+ if (
836+ i > 0
837+ and resource_req .het_group_index
838+ == self .slurm_config .resource_group [i - 1 ].het_group_index
839+ ):
840+ continue
841+
808842 het_parameters = parameters .copy ()
809843 het_parameters ["output" ] = parameters ["output" ].replace (
810844 original_job_name , self .jobs [i ]
@@ -934,7 +968,12 @@ def get_container_flags(
934968 _srun_args .extend (self .slurm_config .srun_args or [])
935969
936970 if self .slurm_config .run_as_group and self .slurm_config .heterogeneous :
937- het_group_flag = [f"--het-group={ group_ind } " ]
971+ het_group_index = (
972+ self .slurm_config .resource_group [group_ind ].het_group_index
973+ if self .slurm_config .resource_group [group_ind ].het_group_index is not None
974+ else group_ind
975+ )
976+ het_group_flag = [f"--het-group={ het_group_index } " ]
938977 else :
939978 het_group_flag = []
940979
0 commit comments