
Commit 38265c4

Add ability to customize job details between subtasks inside of a task group for slurm (#110)
1 parent 27bccfa commit 38265c4
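This commit lets each entry in a Slurm resource group carry its own SlurmJobDetails, so subtasks in a task group can write their sbatch/srun logs to different names and folders. The new test below exercises this by assigning a custom SlurmJobDetails subclass per resource-group entry. As a rough usage sketch only (the import path, the account value, the folder names, and the assumption that account is the only required SlurmExecutor argument are illustrative, not taken from this commit), per-subtask log locations might be configured like this:

from pathlib import Path

# Assumed import path; both classes live in src/nemo_run/core/execution/slurm.py in this repo.
from nemo_run.core.execution.slurm import SlurmExecutor, SlurmJobDetails


class CustomJobDetails(SlurmJobDetails):
    # Route each subtask's srun output to its own log file (same pattern as the new test).
    @property
    def srun_stdout(self) -> Path:
        assert self.folder
        return Path(self.folder) / f"log_{self.job_name}.out"


# One executor per subtask; "your_account" and the folder paths are placeholder values.
preprocess = SlurmExecutor(account="your_account")
preprocess.job_details = CustomJobDetails(job_name="preprocess", folder=Path("/results/preprocess"))

train = SlurmExecutor(account="your_account")
train.job_details = CustomJobDetails(job_name="train", folder=Path("/results/train"))

# Per the change to SlurmExecutor.merge below, each subtask's ResourceRequest now
# deep-copies its executor's job_details, so the generated sbatch script routes the
# two srun outputs to /results/preprocess and /results/train respectively.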

File tree

2 files changed (+111, -95 lines)


src/nemo_run/core/execution/slurm.py

Lines changed: 77 additions & 95 deletions
@@ -286,6 +286,7 @@ class ResourceRequest:
         container_mounts: list[str] = field(default_factory=list)
         env_vars: dict[str, str] = field(default_factory=dict)
         srun_args: Optional[list[str]] = None
+        job_details: SlurmJobDetails = field(default_factory=SlurmJobDetails)

     account: str
     partition: Optional[str] = None
@@ -359,6 +360,7 @@ def merge(
                 gpus_per_node=main_executor.gpus_per_node,
                 gpus_per_task=main_executor.gpus_per_task,
                 srun_args=main_executor.srun_args,
+                job_details=copy.deepcopy(main_executor.job_details),
             )
         ]

@@ -374,6 +376,7 @@ def merge(
                     gpus_per_node=executor.gpus_per_node,
                     gpus_per_task=executor.gpus_per_task,
                     srun_args=executor.srun_args,
+                    job_details=copy.deepcopy(executor.job_details),
                 )
             )

@@ -720,10 +723,12 @@ def materialize(self) -> str:
         )
         # add necessary parameters
         original_job_name: str = self.jobs[0]  # type: ignore
-        if self.slurm_config.job_name_prefix is None:
-            job_name = f"{self.slurm_config.account}-{self.slurm_config.account.split('_')[-1]}.{original_job_name}"
-        else:
-            job_name = f"{self.slurm_config.job_name_prefix}{original_job_name}"
+        job_name_prefix = (
+            self.slurm_config.job_name_prefix
+            if self.slurm_config.job_name_prefix
+            else f"{self.slurm_config.account}-{self.slurm_config.account.split('_')[-1]}."
+        )
+        job_name = f"{job_name_prefix}{original_job_name}"
         slurm_job_dir = (
             self.slurm_config.tunnel.job_dir
             if self.slurm_config.tunnel
@@ -812,9 +817,9 @@ def materialize(self) -> str:
             if self.slurm_config.stderr_to_stdout
             else ["--error", noquote(job_details.srun_stderr)]
         )
-        memory_measure = None
+        memory_measure_out = None
         if self.slurm_config.memory_measure:
-            memory_measure = srun_stdout
+            memory_measure_out = srun_stdout

         def get_container_flags(
             base_mounts: list[str], src_job_dir: str, container_image: Optional[str]
@@ -833,113 +838,90 @@ def get_container_flags(
             return _container_flags

         for group_ind, command_group in enumerate(self.command_groups):
-            if self.slurm_config.heterogeneous:
+            if self.slurm_config.run_as_group and len(self.slurm_config.resource_group) == len(
+                self.command_groups
+            ):
                 resource_req = self.slurm_config.resource_group[group_ind]
+                if not resource_req.job_details.job_name:
+                    resource_req.job_details.job_name = f"{job_name_prefix}{self.jobs[group_ind]}"

+                if not resource_req.job_details.folder:
+                    resource_req.job_details.folder = os.path.join(
+                        slurm_job_dir, job_directory_name
+                    )
+
+                cmd_stdout = noquote(resource_req.job_details.srun_stdout)
+                cmd_stderr = (
+                    []
+                    if self.slurm_config.stderr_to_stdout
+                    else [
+                        "--error",
+                        noquote(resource_req.job_details.srun_stderr),
+                    ]
+                )
                 current_env_vars = []
                 for key, value in resource_req.env_vars.items():
                     current_env_vars.append(f"export {key.upper()}={value}")

                 group_env_vars.append(current_env_vars)

-                het_group = f"--het-group={group_ind}"
-                het_stdout = srun_stdout.replace(original_job_name, self.jobs[group_ind])
-                het_stderr = stderr_flags.copy()
-                if het_stderr:
-                    het_stderr[-1] = het_stderr[-1].replace(original_job_name, self.jobs[group_ind])
-
-                _group_srun_args = ["--wait=60", "--kill-on-bad-exit=1"]
-                _group_srun_args.extend(resource_req.srun_args or [])
-                srun_cmd = " ".join(
-                    list(
-                        map(
-                            lambda arg: arg if isinstance(arg, noquote) else shlex.quote(arg),
-                            [
-                                "srun",
-                                het_group,
-                                "--output",
-                                het_stdout,
-                                *het_stderr,
-                                *get_container_flags(
-                                    base_mounts=resource_req.container_mounts,
-                                    src_job_dir=os.path.join(slurm_job_dir, job_directory_name),
-                                    container_image=resource_req.container_image,
-                                ),
-                                *_group_srun_args,
-                            ],
-                        )
-                    )
+                _container_flags = get_container_flags(
+                    base_mounts=resource_req.container_mounts,
+                    src_job_dir=os.path.join(
+                        slurm_job_dir,
+                        job_directory_name,
+                    ),
+                    container_image=resource_req.container_image,
                 )
-
-                command = ";\n ".join(command_group)
-
-                srun_command = f"{srun_cmd} {command} & pids[{group_ind}]=$!"
-                if group_ind != len(self.slurm_config.resource_group) - 1:
-                    srun_command += f"\n\nsleep {self.slurm_config.wait_time_for_group_job}\n"
-                srun_commands.append(srun_command)
+                _srun_args = ["--wait=60", "--kill-on-bad-exit=1"]
+                _srun_args.extend(resource_req.srun_args or [])
             else:
                 cmd_stdout = srun_stdout.replace(original_job_name, self.jobs[group_ind])
                 cmd_stderr = stderr_flags.copy()
                 if cmd_stderr:
                     cmd_stderr[-1] = cmd_stderr[-1].replace(original_job_name, self.jobs[group_ind])
+                _container_flags = get_container_flags(
+                    base_mounts=self.slurm_config.container_mounts,
+                    src_job_dir=os.path.join(
+                        slurm_job_dir,
+                        job_directory_name,
+                    ),
+                    container_image=self.slurm_config.container_image,
+                )
+                _srun_args = ["--wait=60", "--kill-on-bad-exit=1"]
+                _srun_args.extend(self.slurm_config.srun_args or [])

-                if self.slurm_config.run_as_group and len(self.slurm_config.resource_group) == len(
-                    self.command_groups
-                ):
-                    resource_req = self.slurm_config.resource_group[group_ind]
-                    current_env_vars = []
-                    for key, value in resource_req.env_vars.items():
-                        current_env_vars.append(f"export {key.upper()}={value}")
-
-                    group_env_vars.append(current_env_vars)
-
-                    _container_flags = get_container_flags(
-                        base_mounts=resource_req.container_mounts,
-                        src_job_dir=os.path.join(
-                            slurm_job_dir,
-                            job_directory_name,
-                        ),
-                        container_image=resource_req.container_image,
-                    )
-                    _srun_args = ["--wait=60", "--kill-on-bad-exit=1"]
-                    _srun_args.extend(resource_req.srun_args or [])
-                else:
-                    _container_flags = get_container_flags(
-                        base_mounts=self.slurm_config.container_mounts,
-                        src_job_dir=os.path.join(
-                            slurm_job_dir,
-                            job_directory_name,
-                        ),
-                        container_image=self.slurm_config.container_image,
-                    )
-                    _srun_args = ["--wait=60", "--kill-on-bad-exit=1"]
-                    _srun_args.extend(self.slurm_config.srun_args or [])
-
-                srun_cmd = " ".join(
-                    list(
-                        map(
-                            lambda arg: arg if isinstance(arg, noquote) else shlex.quote(arg),
-                            [
-                                "srun",
-                                "--output",
-                                cmd_stdout,
-                                *cmd_stderr,
-                                *_container_flags,
-                                *_srun_args,
-                            ],
-                        )
+            if self.slurm_config.run_as_group and self.slurm_config.heterogeneous:
+                het_group_flag = [f"--het-group={group_ind}"]
+            else:
+                het_group_flag = []
+
+            srun_cmd = " ".join(
+                list(
+                    map(
+                        lambda arg: arg if isinstance(arg, noquote) else shlex.quote(arg),
+                        [
+                            "srun",
+                            *het_group_flag,
+                            "--output",
+                            cmd_stdout,
+                            *cmd_stderr,
+                            *_container_flags,
+                            *_srun_args,
+                        ],
                     )
                 )
-                command = " ".join(command_group)
+            )
+            command = " ".join(command_group)

-                if self.slurm_config.run_as_group:
-                    srun_command = f"{srun_cmd} {command} & pids[{group_ind}]=$!"
-                    if group_ind != len(self.command_groups) - 1:
-                        srun_command += f"\n\nsleep {self.slurm_config.wait_time_for_group_job}\n"
-                else:
-                    srun_command = f"{srun_cmd} {command}"
+            if self.slurm_config.run_as_group:
+                srun_command = f"{srun_cmd} {command} & pids[{group_ind}]=$!"
+                if group_ind != len(self.command_groups) - 1:
+                    srun_command += f"\n\nsleep {self.slurm_config.wait_time_for_group_job}\n"
+            else:
+                srun_command = f"{srun_cmd} {command}"

-                srun_commands.append(srun_command)
+            srun_commands.append(srun_command)

         vars_to_fill = {
             "sbatch_command": sbatch_cmd,
@@ -948,7 +930,7 @@ def get_container_flags(
             "env_vars": env_vars,
             "head_node_ip_var": SlurmExecutor.HEAD_NODE_IP_VAR,
             "setup_lines": self.slurm_config.setup_lines,
-            "memory_measure": memory_measure,
+            "memory_measure": memory_measure_out,
             "srun_commands": srun_commands,
             "group_env_vars": group_env_vars,
             "heterogeneous": self.slurm_config.heterogeneous,

test/core/execution/test_slurm.py

Lines changed: 34 additions & 0 deletions
@@ -624,6 +624,40 @@ def test_group_resource_req_batch_request_materialize(
         expected = Path(artifact).read_text()
         assert sbatch_script.strip() == expected.strip()

+    def test_group_resource_req_request_custom_job_details(
+        self,
+        group_resource_req_slurm_request_with_artifact: tuple[SlurmBatchRequest, str],
+    ):
+        class CustomJobDetails(SlurmJobDetails):
+            @property
+            def stdout(self) -> Path:
+                assert self.folder
+                return Path(self.folder / "sbatch_job.out")
+
+            @property
+            def srun_stdout(self) -> Path:
+                assert self.folder
+                return Path(self.folder / f"log_{self.job_name}.out")
+
+        group_resource_req_slurm_request, _ = group_resource_req_slurm_request_with_artifact
+        group_resource_req_slurm_request.slurm_config.job_details = CustomJobDetails(
+            job_name="custom_sample_job", folder=Path("/custom_folder")
+        )
+        group_resource_req_slurm_request.slurm_config.resource_group[0].job_details = copy.deepcopy(
+            group_resource_req_slurm_request.slurm_config.job_details
+        )
+        group_resource_req_slurm_request.slurm_config.resource_group[
+            1
+        ].job_details = CustomJobDetails(
+            job_name="custom_sample_job_2", folder=Path("/custom_folder_2")
+        )
+
+        sbatch_script = group_resource_req_slurm_request.materialize()
+        assert "#SBATCH --job-name=custom_sample_job" in sbatch_script
+        assert "srun --output /custom_folder/log_custom_sample_job.out" in sbatch_script
+        assert "srun --output /custom_folder_2/log_custom_sample_job_2.out" in sbatch_script
+        assert "#SBATCH --output=/custom_folder/sbatch_job.out" in sbatch_script
+
     def test_ft_slurm_request_materialize(
         self, ft_slurm_request_with_artifact: tuple[SlurmBatchRequest, str]
     ):