
Commit 143b17c

Authored by Hemil Desai
Allow customizing folder for SlurmRayRequest (#281)
* Allow customizing folder for SlurmRayRequest
* fix
* ray log prefix

Signed-off-by: Hemil Desai <[email protected]>
1 parent 8d46493 commit 143b17c

File tree

4 files changed: +128 -6 lines changed
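Taken together, these changes let a Ray-on-Slurm job write its logs to a caller-chosen directory and rename the per-role log files via a new `ray_log_prefix` field. A minimal usage sketch (the account name, paths, and the `trial3-` prefix are illustrative; the attributes themselves are introduced or honored by the diffs below):

```python
from nemo_run.core.execution.slurm import SlurmExecutor

executor = SlurmExecutor(account="my_account")  # illustrative account

# `folder` already existed on SlurmJobDetails; this commit makes SlurmRayRequest honor it.
executor.job_details.folder = "/lustre/shared/ray-logs"

# `ray_log_prefix` is new; log files become trial3-head.log, trial3-worker-0.log, ...
executor.job_details.ray_log_prefix = "trial3-"
```

If `folder` is left unset, the generated script falls back to `<cluster_dir>/logs`, matching the previous behavior.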

nemo_run/core/execution/slurm.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -61,6 +61,7 @@ class SlurmJobDetails:
 
     job_name: Optional[str] = None
     folder: Optional[str] = None
+    ray_log_prefix: str = "ray-"
 
     @property
     def stderr(self) -> Path:
```
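The new field defaults to the old hard-coded `ray-` prefix, so existing jobs keep their file names. The default is visible straight off a fresh executor, as the new tests also assert:

```python
from nemo_run.core.execution.slurm import SlurmExecutor

executor = SlurmExecutor(account="test_account")
assert executor.job_details.ray_log_prefix == "ray-"  # default added in this diff
```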

nemo_run/run/ray/slurm.py

Lines changed: 7 additions & 3 deletions

```diff
@@ -152,6 +152,8 @@ def materialize(self) -> str:
         if not job_details.folder:
             job_details.folder = os.path.join(slurm_job_dir, "logs")
 
+        logs_dir: str = job_details.folder  # Single source of truth for log-dir inside this SBATCH
+
         parameters["job_name"] = job_details.job_name
 
         stdout = str(job_details.stdout)
@@ -219,10 +221,11 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
 
             return " ".join(_srun_flags)
 
+        ray_log_prefix = job_details.ray_log_prefix
         vars_to_fill = {
             "sbatch_flags": sbatch_flags,
             "cluster_dir": self.cluster_dir,
-            "log_dir": os.path.join(self.cluster_dir, "logs"),
+            "log_dir": logs_dir,
             "uv_cache_dir": os.path.join(self.cluster_dir, "uv_cache"),
             "num_retries": max(1, self.executor.retries),
             "env_vars": env_vars,
@@ -233,6 +236,7 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
             "command": self.command,
             "command_workdir": self.workdir,
             "gres_specification": get_gres_specification(),
+            "ray_log_prefix": ray_log_prefix,
         }
 
         if self.command_groups:
@@ -260,12 +264,12 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
                 srun_args.extend(self.executor.srun_args or [])
                 group_env_vars.append([])
 
-                stdout_path = os.path.join(self.cluster_dir, "logs", f"ray-overlap-{idx}.out")
+                stdout_path = os.path.join(logs_dir, f"{ray_log_prefix}overlap-{idx}.out")
                 stderr_flags = []
                 if not self.executor.stderr_to_stdout:
                     stderr_flags = [
                         "--error",
-                        os.path.join(self.cluster_dir, "logs", f"ray-overlap-{idx}.err"),
+                        os.path.join(logs_dir, f"{ray_log_prefix}overlap-{idx}.err"),
                     ]
 
                 srun_cmd = " ".join(
```
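The substantive change in `materialize()` is that the rendered SBATCH script now takes its log directory from `job_details.folder` rather than hard-coding `<cluster_dir>/logs`. A standalone restatement of the fallback (simplified; `resolve_logs_dir` is a hypothetical helper, not a function in this diff):

```python
import os

def resolve_logs_dir(folder: str | None, slurm_job_dir: str) -> str:
    # A caller-supplied folder wins; otherwise fall back to <job_dir>/logs as before.
    return folder or os.path.join(slurm_job_dir, "logs")

assert resolve_logs_dir(None, "/tmp/jobs/c1") == "/tmp/jobs/c1/logs"
assert resolve_logs_dir("/custom/logs/location", "/tmp/jobs/c1") == "/custom/logs/location"
```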

nemo_run/run/ray/templates/ray.sub.j2

Lines changed: 3 additions & 3 deletions

```diff
@@ -165,7 +165,7 @@ touch $LOG_DIR/ENDED
 exit 1
 EOF
 )
-srun {{ common_srun_args }} --container-name=ray-head --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-head.log bash -x -c "$head_cmd" &
+srun {{ common_srun_args }} --container-name=ray-head --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}head.log bash -x -c "$head_cmd" &
 
 # Wait for the head node container to start and for Ray to be ready
 while ! (srun --overlap --nodes=1 --ntasks=1 -w $head_node test -f $LOG_DIR/STARTED_RAY_HEAD && srun --overlap --container-name=ray-head --nodes=1 --ntasks=1 -w $head_node ray status --address $ip_head 2>/dev/null); do
@@ -241,7 +241,7 @@ EOF
   if [[ $i -eq 0 ]]; then
     OVERLAP_HEAD_AND_WORKER_ARG="--overlap"
   fi
-  srun {{ common_srun_args }} ${OVERLAP_HEAD_AND_WORKER_ARG:-} --container-name=ray-worker-$i --exact --nodes=1 --ntasks=1 --cpus-per-task=$((16 * gpus_per_node)) -w "$node_i" -o $LOG_DIR/ray-worker-$i.log bash -x -c "$worker_cmd" &
+  srun {{ common_srun_args }} ${OVERLAP_HEAD_AND_WORKER_ARG:-} --container-name=ray-worker-$i --exact --nodes=1 --ntasks=1 --cpus-per-task=$((16 * gpus_per_node)) -w "$node_i" -o $LOG_DIR/{{ ray_log_prefix }}worker-$i.log bash -x -c "$worker_cmd" &
   sleep 3
 done
 
@@ -318,7 +318,7 @@ COMMAND="${COMMAND:-{{ command | default('', true) }}}"
 COMMAND_WORKDIR={{ command_workdir | default('$CONTAINER_CWD') }}
 
 if [[ -n "$COMMAND" ]]; then
-  srun --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-job.log bash -c "$COMMAND"
+  srun --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND"
 else
   echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:"
   cat <<EOF >$CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh
```
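Within the template, `ray_log_prefix` is an ordinary Jinja variable, so the head, worker, overlap, and job logs are all renamed consistently. A minimal rendering sketch (the srun flags are elided with `...`; only `ray_log_prefix` is the real template variable):

```python
from jinja2 import Template

# One representative line from ray.sub.j2, with the srun flags elided for brevity.
line = Template('srun ... -o $LOG_DIR/{{ ray_log_prefix }}head.log bash -x -c "$head_cmd" &')
print(line.render(ray_log_prefix="ray-"))     # default: ... -o $LOG_DIR/ray-head.log ...
print(line.render(ray_log_prefix="trial3-"))  # custom:  ... -o $LOG_DIR/trial3-head.log ...
```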

test/run/ray/test_slurm_ray_request.py

Lines changed: 117 additions & 0 deletions

```diff
@@ -581,3 +581,120 @@ def test_group_env_vars_integration(self):
 
         # The template should include group_env_vars for proper env var handling per command
         # (The actual env var exports per command happen in the template rendering)
+
+    # ------------------------------------------------------------------
+    # Custom log directory tests (added for log-dir diff)
+    # ------------------------------------------------------------------
+
+    @pytest.fixture()
+    def custom_log_request(self) -> tuple[SlurmRayRequest, str]:
+        """Produce a SlurmRayRequest where ``executor.job_details.folder`` is overridden."""
+        executor = SlurmExecutor(account="test_account")
+        tunnel_mock = Mock(spec=SSHTunnel)
+        tunnel_mock.job_dir = "/tmp/test_jobs"
+        executor.tunnel = tunnel_mock
+
+        custom_logs_dir = "/custom/logs/location"
+        executor.job_details.folder = custom_logs_dir
+
+        req = SlurmRayRequest(
+            name="test-ray-custom-logs",
+            cluster_dir="/tmp/test_jobs/test-ray-custom-logs",
+            template_name="ray.sub.j2",
+            executor=executor,
+            command_groups=[["head"], ["echo", "hello"]],
+            launch_cmd=["sbatch", "--parsable"],
+        )
+
+        return req, custom_logs_dir
+
+    def test_log_dir_export_and_sbatch_paths(self, custom_log_request):
+        """Ensure that LOG_DIR and SBATCH paths use the custom directory when provided."""
+        req, custom_logs_dir = custom_log_request
+        script = req.materialize()
+
+        assert f"export LOG_DIR={custom_logs_dir}" in script
+        assert f"#SBATCH --output={custom_logs_dir}/" in script
+        assert os.path.join(custom_logs_dir, "ray-overlap-1.out") in script
+
+    def test_default_log_dir_fallback(self):
+        """Default behaviour: log paths default to <cluster_dir>/logs when not overridden."""
+        executor = SlurmExecutor(account="test_account")
+        tunnel_mock = Mock(spec=SSHTunnel)
+        tunnel_mock.job_dir = "/tmp/test_jobs"
+        executor.tunnel = tunnel_mock
+
+        cluster_dir = "/tmp/test_jobs/default-logs-cluster"
+        req = SlurmRayRequest(
+            name="default-logs-cluster",
+            cluster_dir=cluster_dir,
+            template_name="ray.sub.j2",
+            executor=executor,
+            launch_cmd=["sbatch", "--parsable"],
+        )
+
+        script = req.materialize()
+        default_logs = os.path.join(cluster_dir, "logs")
+        assert f"export LOG_DIR={default_logs}" in script
+        assert f"#SBATCH --output={default_logs}/" in script
+
+    def test_default_ray_log_prefix(self):
+        """Ensure that default ``ray_log_prefix`` is respected in generated scripts."""
+        executor = SlurmExecutor(account="test_account")
+        # Default should be "ray-"
+        assert executor.job_details.ray_log_prefix == "ray-"
+
+        # Attach a mock tunnel so that ``materialize`` works without ssh
+        tunnel_mock = Mock(spec=SSHTunnel)
+        tunnel_mock.job_dir = "/tmp/test_jobs"
+        executor.tunnel = tunnel_mock
+
+        req = SlurmRayRequest(
+            name="default-prefix",
+            cluster_dir="/tmp/test_jobs/default-prefix",
+            template_name="ray.sub.j2",
+            executor=executor,
+            command_groups=[["head"], ["echo", "hi"]],
+            launch_cmd=["sbatch", "--parsable"],
+        )
+
+        script = req.materialize()
+
+        # Head / worker / overlap log paths must include the default prefix
+        assert "ray-head.log" in script
+        assert "ray-worker-" in script
+        assert "ray-overlap-" in script
+        assert "ray-job.log" in script
+
+    def test_custom_ray_log_prefix(self):
+        """Validate that a custom ``ray_log_prefix`` propagates to all log file names."""
+        executor = SlurmExecutor(account="test_account")
+        # Override the prefix
+        custom_prefix = "mycustom-"
+        executor.job_details.ray_log_prefix = custom_prefix
+
+        # Mock tunnel
+        tunnel_mock = Mock(spec=SSHTunnel)
+        tunnel_mock.job_dir = "/tmp/test_jobs"
+        executor.tunnel = tunnel_mock
+
+        req = SlurmRayRequest(
+            name="custom-prefix-cluster",
+            cluster_dir="/tmp/test_jobs/custom-prefix-cluster",
+            template_name="ray.sub.j2",
+            executor=executor,
+            command_groups=[["head"], ["echo", "hi"]],
+            launch_cmd=["sbatch", "--parsable"],
+        )
+
+        script = req.materialize()
+
+        # All log files generated inside the script should use the custom prefix
+        expected_patterns = [
+            f"{custom_prefix}head.log",
+            f"{custom_prefix}worker-",
+            f"{custom_prefix}overlap-1.out",
+            f"{custom_prefix}job.log",
+        ]
+        for pattern in expected_patterns:
+            assert pattern in script, f"Log path missing expected prefix pattern: {pattern}"
```
