Skip to content

Commit c152888

Browse files
committed
Add support for heterogeneous job group indices in SlurmExecutor
- Introduce het_group_index parameter in ResourceRequest
- Add het_group_indices parameter to SlurmExecutor
- Implement validation and handling of heterogeneous job group indices
- Modify Slurm batch request generation to use custom or default group indices

Signed-off-by: Hemil Desai <hemild@nvidia.com>
1 parent 98b0a2f commit c152888

File tree

1 file changed

+41
-2
lines changed

1 file changed

+41
-2
lines changed

src/nemo_run/core/execution/slurm.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ class ResourceRequest:
293293
env_vars: dict[str, str] = field(default_factory=dict)
294294
srun_args: Optional[list[str]] = None
295295
job_details: SlurmJobDetails = field(default_factory=SlurmJobDetails)
296+
het_group_index: Optional[int] = None
296297

297298
account: str
298299
partition: Optional[str] = None
@@ -334,6 +335,7 @@ class ResourceRequest:
334335
monitor_group_job: bool = True
335336
monitor_group_job_wait_time: int = 60
336337
setup_lines: Optional[str] = None
338+
het_group_indices: Optional[list[int]] = None
337339

338340
#: Set by the executor; cannot be initialized
339341
job_name: str = field(init=False, default="nemo-job")
@@ -355,6 +357,21 @@ def merge(
355357

356358
main_executor = executors[0]
357359
main_executor.run_as_group = True
360+
361+
if main_executor.het_group_indices:
362+
assert (
363+
main_executor.heterogeneous
364+
), "heterogeneous must be True if het_group_indices is provided"
365+
assert (
366+
len(main_executor.het_group_indices) == num_tasks
367+
), "het_group_indices must be the same length as the number of tasks"
368+
assert all(
369+
x <= y
370+
for x, y in zip(
371+
main_executor.het_group_indices, main_executor.het_group_indices[1:]
372+
)
373+
), "het_group_indices must be equal or increasing than previous"
374+
358375
main_executor.resource_group = [
359376
cls.ResourceRequest(
360377
packager=copy.deepcopy(main_executor.packager),
@@ -367,10 +384,13 @@ def merge(
367384
gpus_per_task=main_executor.gpus_per_task,
368385
srun_args=main_executor.srun_args,
369386
job_details=copy.deepcopy(main_executor.job_details),
387+
het_group_index=main_executor.het_group_indices[0]
388+
if main_executor.het_group_indices
389+
else None,
370390
)
371391
]
372392

373-
for executor in executors[1:]:
393+
for i, executor in enumerate(executors[1:]):
374394
main_executor.resource_group.append(
375395
cls.ResourceRequest(
376396
packager=copy.deepcopy(executor.packager),
@@ -383,6 +403,9 @@ def merge(
383403
gpus_per_task=executor.gpus_per_task,
384404
srun_args=executor.srun_args,
385405
job_details=copy.deepcopy(executor.job_details),
406+
het_group_index=main_executor.het_group_indices[i + 1]
407+
if main_executor.het_group_indices
408+
else None,
386409
)
387410
)
388411

@@ -805,6 +828,17 @@ def materialize(self) -> str:
805828
assert len(self.jobs) == len(self.slurm_config.resource_group)
806829
for i in range(len(self.slurm_config.resource_group)):
807830
resource_req = self.slurm_config.resource_group[i]
831+
if resource_req.het_group_index:
832+
assert (
833+
self.slurm_config.resource_group[i - 1].het_group_index is not None
834+
), "het_group_index must be set for all requests in resource_group"
835+
if (
836+
i > 0
837+
and resource_req.het_group_index
838+
== self.slurm_config.resource_group[i - 1].het_group_index
839+
):
840+
continue
841+
808842
het_parameters = parameters.copy()
809843
het_parameters["output"] = parameters["output"].replace(
810844
original_job_name, self.jobs[i]
@@ -934,7 +968,12 @@ def get_container_flags(
934968
_srun_args.extend(self.slurm_config.srun_args or [])
935969

936970
if self.slurm_config.run_as_group and self.slurm_config.heterogeneous:
937-
het_group_flag = [f"--het-group={group_ind}"]
971+
het_group_index = (
972+
self.slurm_config.resource_group[group_ind].het_group_index
973+
if self.slurm_config.resource_group[group_ind].het_group_index is not None
974+
else group_ind
975+
)
976+
het_group_flag = [f"--het-group={het_group_index}"]
938977
else:
939978
het_group_flag = []
940979

0 commit comments

Comments
 (0)