Commit 5585ec6

Add support for heterogeneous job group indices in SlurmExecutor (#158)
* Add support for heterogeneous job group indices in SlurmExecutor
  - Introduce het_group_index parameter in ResourceRequest
  - Add het_group_indices parameter to SlurmExecutor
  - Implement validation and handling of heterogeneous job group indices
  - Modify Slurm batch request generation to use custom or default group indices

  Signed-off-by: Hemil Desai <[email protected]>

* fix

  Signed-off-by: Hemil Desai <[email protected]>

---------
Signed-off-by: Hemil Desai <[email protected]>
1 parent 98b0a2f commit 5585ec6
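
For orientation, here is a minimal usage sketch of the new knobs. It is an illustration inferred only from this diff: the import path follows the changed file, and `heterogeneous`, `het_group_indices`, `account`, and `partition` appear in the diff, but the exact constructor signature, default values, and the return value (if any) of `merge()` are assumptions and may differ from the library.

```python
# Hypothetical usage sketch based on this diff (not taken from library docs).
from nemo_run.core.execution.slurm import SlurmExecutor  # path from the changed file

# Three tasks: two trainers that should share Slurm het group 0, one server in group 1.
main = SlurmExecutor(
    account="my_account",          # assumed value
    partition="gpu",               # assumed value
    heterogeneous=True,            # required when het_group_indices is provided
    het_group_indices=[0, 0, 1],   # one entry per task, non-decreasing
)
trainer = SlurmExecutor(account="my_account", partition="gpu")
server = SlurmExecutor(account="my_account", partition="cpu")

# Per the merge() hunks below, this populates main.resource_group and copies each
# task's het_group_index from het_group_indices, so tasks 0 and 1 share het group 0.
SlurmExecutor.merge([main, trainer, server], num_tasks=3)
```
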

File tree

1 file changed (+48, -3 lines)


src/nemo_run/core/execution/slurm.py

Lines changed: 48 additions & 3 deletions
```diff
@@ -293,6 +293,7 @@ class ResourceRequest:
         env_vars: dict[str, str] = field(default_factory=dict)
         srun_args: Optional[list[str]] = None
         job_details: SlurmJobDetails = field(default_factory=SlurmJobDetails)
+        het_group_index: Optional[int] = None

     account: str
     partition: Optional[str] = None
```

```diff
@@ -334,6 +335,7 @@ class ResourceRequest:
     monitor_group_job: bool = True
     monitor_group_job_wait_time: int = 60
     setup_lines: Optional[str] = None
+    het_group_indices: Optional[list[int]] = None

     #: Set by the executor; cannot be initialized
     job_name: str = field(init=False, default="nemo-job")
```

```diff
@@ -355,6 +357,21 @@ def merge(

         main_executor = executors[0]
         main_executor.run_as_group = True
+
+        if main_executor.het_group_indices:
+            assert (
+                main_executor.heterogeneous
+            ), "heterogeneous must be True if het_group_indices is provided"
+            assert (
+                len(main_executor.het_group_indices) == num_tasks
+            ), "het_group_indices must be the same length as the number of tasks"
+            assert all(
+                x <= y
+                for x, y in zip(
+                    main_executor.het_group_indices, main_executor.het_group_indices[1:]
+                )
+            ), "het_group_indices must be equal or increasing than previous"
+
         main_executor.resource_group = [
             cls.ResourceRequest(
                 packager=copy.deepcopy(main_executor.packager),
```

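The pairwise `zip` check above only enforces that the sequence never decreases; repeated indices are allowed because consecutive tasks may share a het group. A standalone illustration of that predicate (hypothetical helper, not part of the library):

```python
# Standalone illustration of the non-decreasing check used above.
def is_non_decreasing(indices: list[int]) -> bool:
    return all(x <= y for x, y in zip(indices, indices[1:]))

assert is_non_decreasing([0, 0, 1, 2])   # repeats are fine: those tasks share a group
assert not is_non_decreasing([0, 2, 1])  # decreasing indices are rejected
```
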
```diff
@@ -367,10 +384,13 @@ def merge(
                 gpus_per_task=main_executor.gpus_per_task,
                 srun_args=main_executor.srun_args,
                 job_details=copy.deepcopy(main_executor.job_details),
+                het_group_index=main_executor.het_group_indices[0]
+                if main_executor.het_group_indices
+                else None,
             )
         ]

-        for executor in executors[1:]:
+        for i, executor in enumerate(executors[1:]):
             main_executor.resource_group.append(
                 cls.ResourceRequest(
                     packager=copy.deepcopy(executor.packager),
```

```diff
@@ -383,6 +403,9 @@ def merge(
                     gpus_per_task=executor.gpus_per_task,
                     srun_args=executor.srun_args,
                     job_details=copy.deepcopy(executor.job_details),
+                    het_group_index=main_executor.het_group_indices[i + 1]
+                    if main_executor.het_group_indices
+                    else None,
                 )
             )

```

```diff
@@ -803,8 +826,25 @@ def materialize(self) -> str:
         sbatch_flags = []
         if self.slurm_config.heterogeneous:
             assert len(self.jobs) == len(self.slurm_config.resource_group)
+            final_group_index = len(self.slurm_config.resource_group) - 1
+            if self.slurm_config.het_group_indices:
+                final_group_index = self.slurm_config.het_group_indices.index(
+                    max(self.slurm_config.het_group_indices)
+                )
+
             for i in range(len(self.slurm_config.resource_group)):
                 resource_req = self.slurm_config.resource_group[i]
+                if resource_req.het_group_index:
+                    assert (
+                        self.slurm_config.resource_group[i - 1].het_group_index is not None
+                    ), "het_group_index must be set for all requests in resource_group"
+                if (
+                    i > 0
+                    and resource_req.het_group_index
+                    == self.slurm_config.resource_group[i - 1].het_group_index
+                ):
+                    continue
+
                 het_parameters = parameters.copy()
                 het_parameters["output"] = parameters["output"].replace(
                     original_job_name, self.jobs[i]
```

```diff
@@ -824,7 +864,7 @@ def materialize(self) -> str:
                 )
                 for k in sorted(parameters):
                     sbatch_flags.append(_as_sbatch_flag(k, het_parameters[k]))
-                if i != len(self.slurm_config.resource_group) - 1:
+                if i != final_group_index:
                     sbatch_flags.append("#SBATCH hetjob")
         else:
             for k in sorted(parameters):
```

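Taken together, the two materialize() hunks emit one block of #SBATCH flags per distinct het group (consecutive duplicate indices are skipped) and place the `#SBATCH hetjob` separator after every group except the one holding the largest index. A standalone sketch of that layout logic, detached from the surrounding class:

```python
# Hypothetical helper mirroring the loop above; not part of the library.
def het_group_layout(het_group_indices: list[int]) -> list[str]:
    final_group_index = het_group_indices.index(max(het_group_indices))
    lines = []
    for i, group in enumerate(het_group_indices):
        if i > 0 and group == het_group_indices[i - 1]:
            continue  # same het group as the previous task: no new #SBATCH block
        lines.append(f"# sbatch flags for het group {group}")
        if i != final_group_index:
            lines.append("#SBATCH hetjob")
    return lines

print(het_group_layout([0, 0, 1]))
# ['# sbatch flags for het group 0', '#SBATCH hetjob', '# sbatch flags for het group 1']
```
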
```diff
@@ -934,7 +974,12 @@ def get_container_flags(
         _srun_args.extend(self.slurm_config.srun_args or [])

         if self.slurm_config.run_as_group and self.slurm_config.heterogeneous:
-            het_group_flag = [f"--het-group={group_ind}"]
+            het_group_index = (
+                self.slurm_config.resource_group[group_ind].het_group_index
+                if self.slurm_config.resource_group[group_ind].het_group_index is not None
+                else group_ind
+            )
+            het_group_flag = [f"--het-group={het_group_index}"]
         else:
             het_group_flag = []

```

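At srun time, each task's `--het-group` flag now prefers the custom `het_group_index` from its ResourceRequest and falls back to the task's position in resource_group when no custom index was set, so with `het_group_indices=[0, 0, 1]` tasks 0 and 1 both launch with `--het-group=0` and task 2 with `--het-group=1`. A minimal sketch of that selection (hypothetical helper, not the library's code):

```python
# Hypothetical helper mirroring the flag selection above; not part of the library.
from typing import Optional

def het_group_flag(het_group_index: Optional[int], group_ind: int) -> str:
    chosen = het_group_index if het_group_index is not None else group_ind
    return f"--het-group={chosen}"

assert het_group_flag(0, 1) == "--het-group=0"     # custom index wins
assert het_group_flag(None, 2) == "--het-group=2"  # falls back to the task's position
```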
