
Commit 6b01546

Add option to specify --container-env for srun (#293)
Signed-off-by: Hemil Desai <[email protected]>
1 parent 2ca3b41 commit 6b01546
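
A minimal usage sketch of the new option follows. SlurmExecutor and its container_env field come from this commit; the import path is assumed from the file path in this diff, and the account, image, mount, and variable names (as well as the exact constructor keywords) are illustrative placeholders, not part of the change itself. The listed variable names are passed to srun as a single comma-separated --container-env flag.

# Illustrative sketch only: SlurmExecutor and container_env are from this commit;
# the account, image, mounts, and variable names below are assumptions.
from nemo_run.core.execution.slurm import SlurmExecutor  # path assumed from the diff

executor = SlurmExecutor(
    account="your_account",                        # placeholder account
    container_image="nvcr.io/nvidia/nemo:latest",  # placeholder image
    container_mounts=["/lustre/data:/data"],       # placeholder mount
    # New in this commit: the names are joined with commas and emitted as
    # "--container-env WANDB_API_KEY,HF_TOKEN" on the generated srun line.
    container_env=["WANDB_API_KEY", "HF_TOKEN"],
)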

File tree

2 files changed: 187 additions, 2 deletions


nemo_run/core/execution/slurm.py

Lines changed: 13 additions & 2 deletions
@@ -294,6 +294,7 @@ class ResourceRequest:
 gpus_per_node: Optional[int] = None
 gpus_per_task: Optional[int] = None
 container_mounts: list[str] = field(default_factory=list)
+container_env: Optional[list[str]] = None
 env_vars: dict[str, str] = field(default_factory=dict)
 srun_args: Optional[list[str]] = None
 job_details: SlurmJobDetails = field(default_factory=SlurmJobDetails)
@@ -323,6 +324,7 @@ class ResourceRequest:
 open_mode: str = "append"
 container_image: Optional[str] = None
 container_mounts: list[str] = field(default_factory=list)
+container_env: Optional[list[str]] = None
 additional_parameters: Optional[dict[str, Any]] = None
 srun_args: Optional[list[str]] = None
 heterogeneous: bool = False
@@ -385,6 +387,7 @@ def merge(
 ntasks_per_node=main_executor.ntasks_per_node,
 container_image=copy.deepcopy(main_executor.container_image),
 container_mounts=copy.deepcopy(main_executor.container_mounts),
+container_env=copy.deepcopy(main_executor.container_env),
 env_vars=copy.deepcopy(main_executor.env_vars),
 gpus_per_node=main_executor.gpus_per_node,
 gpus_per_task=main_executor.gpus_per_task,
@@ -404,6 +407,7 @@ def merge(
 ntasks_per_node=executor.ntasks_per_node,
 container_image=copy.deepcopy(executor.container_image),
 container_mounts=copy.deepcopy(executor.container_mounts),
+container_env=copy.deepcopy(executor.container_env),
 env_vars=copy.deepcopy(executor.env_vars),
 gpus_per_node=executor.gpus_per_node,
 gpus_per_task=executor.gpus_per_task,
@@ -860,7 +864,7 @@ def materialize(self) -> str:

 for i in range(len(self.executor.resource_group)):
     resource_req = self.executor.resource_group[i]
-    if resource_req.het_group_index:
+    if resource_req.het_group_index is not None:
         assert self.executor.resource_group[i - 1].het_group_index is not None, (
             "het_group_index must be set for all requests in resource_group"
         )
@@ -924,7 +928,10 @@ def materialize(self) -> str:
 memory_measure_out = srun_stdout

 def get_container_flags(
-    base_mounts: list[str], src_job_dir: str, container_image: Optional[str]
+    base_mounts: list[str],
+    src_job_dir: str,
+    container_image: Optional[str],
+    container_env: Optional[list[str]] = None,
 ) -> list[str]:
     _container_flags = ["--container-image", container_image] if container_image else []

@@ -940,6 +947,8 @@ def get_container_flags(
     "--container-workdir",
     f"/{RUNDIR_NAME}/code",
 ]
+if container_env:
+    _container_flags += ["--container-env", ",".join(container_env)]

 return _container_flags

@@ -978,6 +987,7 @@ def get_container_flags(
         job_directory_name,
     ),
     container_image=resource_req.container_image,
+    container_env=resource_req.container_env,
 )
 _srun_args = ["--wait=60", "--kill-on-bad-exit=1"]
 _srun_args.extend(resource_req.srun_args or [])
@@ -993,6 +1003,7 @@ def get_container_flags(
         job_directory_name,
     ),
     container_image=self.executor.container_image,
+    container_env=self.executor.container_env,
 )
 _srun_args = ["--wait=60", "--kill-on-bad-exit=1"]
 _srun_args.extend(self.executor.srun_args or [])
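
Two details in the hunks above are worth calling out. First, the het_group_index check switches from a truthiness test to an explicit `is not None`, because a heterogeneous group index of 0 is falsy in Python and would otherwise be skipped. Second, get_container_flags only emits the new flag when container_env is non-empty, joining the names with commas. The standalone sketch below mirrors that flag construction; it is an illustration, not the module's actual helper.

from typing import Optional

# Standalone illustration of the logic added to get_container_flags above;
# not the real helper, just the same comma-join behaviour.
def build_container_env_flags(container_env: Optional[list[str]] = None) -> list[str]:
    flags: list[str] = []
    if container_env:  # None or an empty list emits no flag at all
        flags += ["--container-env", ",".join(container_env)]
    return flags

print(build_container_env_flags(["VAR1", "VAR2", "VAR3"]))
# -> ['--container-env', 'VAR1,VAR2,VAR3'], rendered in the sbatch script as
#    "--container-env VAR1,VAR2,VAR3" (matching the assertion in the new test)
print(build_container_env_flags(None))
# -> []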

test/core/execution/test_slurm_templates.py

Lines changed: 174 additions & 0 deletions
@@ -682,3 +682,177 @@ def srun_stdout(self):
         sbatch_script = het_request.materialize()
         for i in range(len(het_request.jobs)):
             assert f"#SBATCH --job-name={custom_name}-{i}" in sbatch_script
+
+    def test_exclusive_parameter_boolean(self, dummy_slurm_request_with_artifact):
+        dummy_slurm_request, _ = dummy_slurm_request_with_artifact
+
+        # Test exclusive=True
+        dummy_slurm_request.executor.exclusive = True
+        sbatch_script = dummy_slurm_request.materialize()
+        assert "#SBATCH --exclusive" in sbatch_script
+
+        # Test exclusive=None (should not appear)
+        dummy_slurm_request.executor.exclusive = None
+        sbatch_script = dummy_slurm_request.materialize()
+        assert "#SBATCH --exclusive" not in sbatch_script
+
+    def test_exclusive_parameter_string(self, dummy_slurm_request_with_artifact):
+        dummy_slurm_request, _ = dummy_slurm_request_with_artifact
+
+        # Test exclusive="user"
+        dummy_slurm_request.executor.exclusive = "user"
+        sbatch_script = dummy_slurm_request.materialize()
+        assert "#SBATCH --exclusive=user" in sbatch_script
+
+    def test_segment_parameter(self, dummy_slurm_request_with_artifact):
+        dummy_slurm_request, _ = dummy_slurm_request_with_artifact
+        dummy_slurm_request.executor.segment = 1
+        sbatch_script = dummy_slurm_request.materialize()
+        assert "#SBATCH --segment=1" in sbatch_script
+
+    def test_network_parameter(self, dummy_slurm_request_with_artifact):
+        dummy_slurm_request, _ = dummy_slurm_request_with_artifact
+        dummy_slurm_request.executor.network = "ib"
+        sbatch_script = dummy_slurm_request.materialize()
+        assert "#SBATCH --network=ib" in sbatch_script
+
+    def test_setup_lines_included(self, dummy_slurm_request_with_artifact):
+        dummy_slurm_request, _ = dummy_slurm_request_with_artifact
+        setup_commands = "module load cuda/12.0\nexport CUSTOM_VAR=value"
+        dummy_slurm_request.executor.setup_lines = setup_commands
+        sbatch_script = dummy_slurm_request.materialize()
+        assert "module load cuda/12.0" in sbatch_script
+        assert "export CUSTOM_VAR=value" in sbatch_script
+
+    def test_container_env_variables(self, dummy_slurm_request_with_artifact):
+        dummy_slurm_request, _ = dummy_slurm_request_with_artifact
+        dummy_slurm_request.executor.container_image = "test_image"
+        dummy_slurm_request.executor.container_env = ["VAR1", "VAR2", "VAR3"]
+        sbatch_script = dummy_slurm_request.materialize()
+        assert "--container-env VAR1,VAR2,VAR3" in sbatch_script
+
+    def test_rundir_special_name_replacement(self, dummy_slurm_request_with_artifact):
+        dummy_slurm_request, _ = dummy_slurm_request_with_artifact
+        from nemo_run.config import RUNDIR_SPECIAL_NAME
+
+        dummy_slurm_request.executor.container_mounts = [
+            f"{RUNDIR_SPECIAL_NAME}/data:/data",
+            "/regular/mount:/mount",
+        ]
+        dummy_slurm_request.executor.container_image = "test_image"
+        sbatch_script = dummy_slurm_request.materialize()
+
+        # Should replace RUNDIR_SPECIAL_NAME with the actual job directory
+        assert "/root/sample_job/data:/data" in sbatch_script
+        assert "/regular/mount:/mount" in sbatch_script
+
+    def test_het_group_indices(self, het_slurm_request_with_artifact):
+        het_slurm_request, _ = het_slurm_request_with_artifact
+
+        # Set custom het_group_indices
+        het_slurm_request.executor.het_group_indices = [0, 0]  # Both jobs in same het group
+        het_slurm_request.executor.resource_group[0].het_group_index = 0
+        het_slurm_request.executor.resource_group[1].het_group_index = 0
+
+        sbatch_script = het_slurm_request.materialize()
+
+        # Should have --het-group=0 for both commands
+        assert "--het-group=0" in sbatch_script
+        # Should only have one set of SBATCH flags since both are in same group
+        assert sbatch_script.count("#SBATCH --account=your_account") == 1
+
+    def test_het_group_indices_multiple_groups(self, het_slurm_request_with_artifact):
+        het_slurm_request, _ = het_slurm_request_with_artifact
+
+        # Add a third resource group
+        het_slurm_request.executor.resource_group.append(
+            SlurmExecutor.ResourceRequest(
+                packager=GitArchivePackager(),
+                nodes=2,
+                ntasks_per_node=4,
+                container_image="image_3",
+                gpus_per_node=4,
+                env_vars={"CUSTOM_ENV_3": "value3"},
+                container_mounts=[],
+            )
+        )
+        het_slurm_request.jobs.append("sample_job-2")
+        het_slurm_request.command_groups.append(["bash ./scripts/third_job.sh"])
+
+        # Set het_group_indices: job 0 and 1 in group 0, job 2 in group 1
+        het_slurm_request.executor.het_group_indices = [0, 0, 1]
+        het_slurm_request.executor.resource_group[0].het_group_index = 0
+        het_slurm_request.executor.resource_group[1].het_group_index = 0
+        het_slurm_request.executor.resource_group[2].het_group_index = 1
+
+        sbatch_script = het_slurm_request.materialize()
+
+        # Should have two sets of SBATCH flags (one for each het group)
+        assert sbatch_script.count("#SBATCH hetjob") == 1  # Only between different groups
+        assert "--het-group=0" in sbatch_script
+        assert "--het-group=1" in sbatch_script
+
+    def test_stderr_to_stdout_false(self, dummy_slurm_request_with_artifact):
+        dummy_slurm_request, _ = dummy_slurm_request_with_artifact
+        dummy_slurm_request.executor.stderr_to_stdout = False
+
+        sbatch_script = dummy_slurm_request.materialize()
+
+        # Should have separate error file
+        assert "#SBATCH --error=" in sbatch_script
+        assert (
+            "--error /root/sample_job/log-account-account.sample_job_%j_${SLURM_RESTART_COUNT:-0}.err"
+            in sbatch_script
+        )
+
+    def test_wait_time_for_group_job_zero(self, group_slurm_request_with_artifact):
+        group_slurm_request, _ = group_slurm_request_with_artifact
+        group_slurm_request.executor.wait_time_for_group_job = 0
+        group_slurm_request.executor.run_as_group = True
+
+        sbatch_script = group_slurm_request.materialize()
+
+        # Should still have the & pids pattern but no sleep
+        assert "& pids[0]=$!" in sbatch_script
+        assert "& pids[1]=$!" in sbatch_script
+        assert "sleep 0" in sbatch_script  # Sleep 0 is included
+
+    def test_resource_group_with_different_srun_args(
+        self, group_resource_req_slurm_request_with_artifact
+    ):
+        group_req, _ = group_resource_req_slurm_request_with_artifact
+
+        # Set different srun_args for each resource group
+        group_req.executor.resource_group[0].srun_args = ["--cpu-bind=cores"]
+        group_req.executor.resource_group[1].srun_args = ["--mpi=pmix", "--cpu-bind=none"]
+
+        sbatch_script = group_req.materialize()
+
+        # Check that each srun command has its specific args
+        assert "--cpu-bind=cores" in sbatch_script
+        assert "--mpi=pmix --cpu-bind=none" in sbatch_script
+
+    def test_signal_parameter(self, dummy_slurm_request_with_artifact):
+        dummy_slurm_request, _ = dummy_slurm_request_with_artifact
+        dummy_slurm_request.executor.signal = "USR1@60"
+        sbatch_script = dummy_slurm_request.materialize()
+        assert "#SBATCH --signal=USR1@60" in sbatch_script
+
+    def test_container_workdir_override(self, dummy_slurm_request_with_artifact):
+        dummy_slurm_request, _ = dummy_slurm_request_with_artifact
+        dummy_slurm_request.executor.container_image = "test_image"
+        sbatch_script = dummy_slurm_request.materialize()
+
+        # Default workdir should be /nemo_run/code
+        assert "--container-workdir /nemo_run/code" in sbatch_script
+
+    def test_memory_measure_with_multiple_jobs(self, group_slurm_request_with_artifact):
+        group_req, _ = group_slurm_request_with_artifact
+        group_req.executor.memory_measure = True
+        group_req.executor.run_as_group = True
+
+        sbatch_script = group_req.materialize()
+
+        # Should have nvidia-smi monitoring
+        assert "nvidia-smi" in sbatch_script
+        assert "--overlap" in sbatch_script
0 commit comments
