Skip to content

Commit 11842c8

Browse files
Authored commit: Support overlapped srun commands in Slurm Ray (#263)
1 parent caf3f12 commit 11842c8

File tree

7 files changed

+272
-5
lines changed

7 files changed

+272
-5
lines changed

nemo_run/core/execution/slurm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,6 @@ def merge(
414414
)
415415
)
416416

417-
main_executor.env_vars = {}
418417
return main_executor
419418

420419
def __post_init__(self):

nemo_run/run/ray/slurm.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ class SlurmRayRequest:
117117
command: Optional[str] = None
118118
workdir: Optional[str] = None
119119
nemo_run_dir: Optional[str] = None
120+
command_groups: Optional[list[list[str]]] = None
120121
launch_cmd: list[str]
121122

122123
@staticmethod
@@ -234,6 +235,60 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
234235
"gres_specification": get_gres_specification(),
235236
}
236237

238+
if self.command_groups:
239+
srun_commands: list[str] = []
240+
group_env_vars: list[list[str]] = []
241+
242+
for idx, group in enumerate(self.command_groups):
243+
if idx == 0:
244+
continue
245+
246+
if self.executor.run_as_group and len(self.executor.resource_group) == len(
247+
self.command_groups
248+
):
249+
req = self.executor.resource_group[idx]
250+
env_list = [f"export {k.upper()}={v}" for k, v in req.env_vars.items()]
251+
group_env_vars.append(env_list)
252+
container_flags = get_srun_flags(req.container_mounts, req.container_image)
253+
srun_args = ["--wait=60", "--kill-on-bad-exit=1", "--overlap"]
254+
srun_args.extend(req.srun_args or [])
255+
else:
256+
container_flags = get_srun_flags(
257+
self.executor.container_mounts, self.executor.container_image
258+
)
259+
srun_args = ["--wait=60", "--kill-on-bad-exit=1", "--overlap"]
260+
srun_args.extend(self.executor.srun_args or [])
261+
group_env_vars.append([])
262+
263+
stdout_path = os.path.join(self.cluster_dir, "logs", f"ray-overlap-{idx}.out")
264+
stderr_flags = []
265+
if not self.executor.stderr_to_stdout:
266+
stderr_flags = [
267+
"--error",
268+
os.path.join(self.cluster_dir, "logs", f"ray-overlap-{idx}.err"),
269+
]
270+
271+
srun_cmd = " ".join(
272+
list(
273+
map(
274+
lambda arg: arg if isinstance(arg, noquote) else shlex.quote(arg),
275+
[
276+
"srun",
277+
"--output",
278+
noquote(stdout_path),
279+
*stderr_flags,
280+
container_flags,
281+
*srun_args,
282+
],
283+
)
284+
)
285+
)
286+
command = " ".join(group)
287+
srun_commands.append(f"{srun_cmd} {command} &")
288+
289+
vars_to_fill["srun_commands"] = srun_commands
290+
vars_to_fill["group_env_vars"] = group_env_vars
291+
237292
if self.pre_ray_start_commands:
238293
vars_to_fill["pre_ray_start_commands"] = "\n".join(self.pre_ray_start_commands)
239294

@@ -398,6 +453,7 @@ def create(
398453
dryrun: bool = False,
399454
command: Optional[str] = None,
400455
workdir: Optional[str] = None,
456+
command_groups: Optional[list[list[str]]] = None,
401457
) -> Any:
402458
"""Create (or reuse) a Slurm-backed Ray cluster and return its job-id.
403459
@@ -416,6 +472,9 @@ def create(
416472
Optional command executed after the Ray head node is ready (e.g. ``ray job submit``).
417473
workdir : str | None
418474
Remote working directory that becomes the CWD inside the container.
475+
command_groups : list[list[str]] | None
476+
Additional commands (one per group) executed via ``srun`` with ``--overlap``
477+
after the cluster is started.
419478
420479
Returns
421480
-------
@@ -433,6 +492,7 @@ def create(
433492
pre_ray_start_commands=pre_ray_start_commands,
434493
command=command,
435494
workdir=workdir,
495+
command_groups=command_groups,
436496
launch_cmd=["sbatch", "--requeue", "--parsable", "--dependency=singleton"],
437497
).materialize()
438498

@@ -1094,6 +1154,7 @@ def start(
10941154
runtime_env_yaml: Optional[str] | None = None,
10951155
pre_ray_start_commands: Optional[list[str]] = None,
10961156
dryrun: bool = False,
1157+
command_groups: Optional[list[list[str]]] = None,
10971158
):
10981159
"""Submit a Ray job via Slurm and return a *live* SlurmRayJob helper.
10991160
@@ -1106,6 +1167,7 @@ def start(
11061167
executor=my_slurm_executor,
11071168
command="python train.py",
11081169
workdir="./src",
1170+
command_groups=[["echo", "hello"]],
11091171
)
11101172
"""
11111173
# ------------------------------------------------------------------
@@ -1212,6 +1274,7 @@ def start(
12121274
dryrun=dryrun,
12131275
command=command,
12141276
workdir=remote_workdir,
1277+
command_groups=command_groups,
12151278
)
12161279

12171280
self.job_id = job_id

nemo_run/run/ray/templates/ray.sub.j2

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,20 @@ echo "[INFO] Ray cluster information saved to $CLUSTER_DIR/ray_cluster_info.json
295295

296296
########################################################
297297

298+
{% if srun_commands %}
299+
# Run extra commands
300+
{% for srun_command in srun_commands %}
301+
{%- if loop.index <= group_env_vars|length %}
302+
{%- for env_var in group_env_vars[loop.index - 1] %}
303+
{{env_var}}
304+
{%- endfor %}
305+
{%- endif %}
306+
307+
{{srun_command}}
308+
{% endfor %}
309+
########################################################
310+
{% endif -%}
311+
298312
# We can now launch a job on this cluster
299313
# We do so by launching a driver process on the physical node that the head node is on
300314
# This driver process is responsible for launching a job on the Ray cluster

nemo_run/run/torchx_backend/packaging.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,15 +233,20 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
233233
assert isinstance(executor, SlurmExecutor), (
234234
f"{USE_WITH_RAY_CLUSTER_KEY} is only supported for SlurmExecutor"
235235
)
236-
assert len(app_def.roles) == 1, "Only one command is supported for Ray jobs."
237236

238237
app_def.metadata = metadata
239238
return app_def
240239

241240

242241
def merge_executables(app_defs: Iterator[specs.AppDef], name: str) -> specs.AppDef:
243242
result = specs.AppDef(name=name, roles=[])
244-
for app_def in app_defs:
243+
result.metadata = {}
244+
for idx, app_def in enumerate(app_defs):
245+
metadata = app_def.metadata or {}
246+
if USE_WITH_RAY_CLUSTER_KEY in metadata:
247+
assert idx == 0, f"{USE_WITH_RAY_CLUSTER_KEY} is only supported for the first command"
248+
249+
result.metadata.update(metadata)
245250
result.roles.extend(app_def.roles)
246251
return result
247252

nemo_run/run/torchx_backend/schedulers/slurm.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,17 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t
102102

103103
executor.package(packager=executor.packager, job_name=Path(job_dir).name)
104104

105+
values = executor.macro_values()
106+
105107
if app.metadata and app.metadata.get(USE_WITH_RAY_CLUSTER_KEY, False):
106-
assert len(app.roles) == 1, "Only one command is supported for Ray jobs."
108+
srun_cmds: list[list[str]] = []
109+
110+
for role in app.roles:
111+
if values:
112+
role = values.apply(role)
113+
srun_cmd = [role.entrypoint] + role.args
114+
srun_cmds.append([" ".join(srun_cmd)])
115+
107116
command = [app.roles[0].entrypoint] + app.roles[0].args
108117
req = SlurmRayRequest(
109118
name=app.roles[0].name,
@@ -114,12 +123,12 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t
114123
executor=executor,
115124
workdir=f"/{RUNDIR_NAME}/code",
116125
nemo_run_dir=os.path.join(executor.tunnel.job_dir, Path(job_dir).name),
126+
command_groups=srun_cmds,
117127
)
118128
else:
119129
srun_cmds: list[list[str]] = []
120130
jobs = []
121131
envs = {}
122-
values = executor.macro_values()
123132

124133
if values:
125134
executor.env_vars = {

test/core/execution/artifacts/group_resource_req_slurm.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ nodes_array=($nodes)
3030
head_node=${nodes_array[0]}
3131
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
3232

33+
export CUSTOM_ENV_1=some_value_1
3334
export ENV_VAR=value
3435

3536

test/run/ray/test_slurm_ray_request.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,3 +405,179 @@ def test_array_assertion(self):
405405

406406
with pytest.raises(AssertionError, match="array is not supported"):
407407
request.materialize()
408+
409+
def test_command_groups_env_vars(self):
410+
"""Test environment variables are properly set for each command group."""
411+
# Create executor with environment variables
412+
executor = SlurmExecutor(
413+
account="test_account",
414+
env_vars={"GLOBAL_ENV": "global_value"},
415+
)
416+
executor.run_as_group = True
417+
418+
# Create resource groups with different env vars
419+
resource_group = [
420+
SlurmExecutor.ResourceRequest(
421+
packager=Mock(),
422+
nodes=1,
423+
ntasks_per_node=1,
424+
container_image="image1",
425+
env_vars={"GROUP1_ENV": "group1_value"},
426+
container_mounts=["/mount1"],
427+
),
428+
SlurmExecutor.ResourceRequest(
429+
packager=Mock(),
430+
nodes=1,
431+
ntasks_per_node=1,
432+
container_image="image2",
433+
env_vars={"GROUP2_ENV": "group2_value"},
434+
container_mounts=["/mount2"],
435+
),
436+
]
437+
executor.resource_group = resource_group
438+
executor.tunnel = Mock(spec=SSHTunnel)
439+
executor.tunnel.job_dir = "/tmp/test_jobs"
440+
441+
request = SlurmRayRequest(
442+
name="test-ray-cluster",
443+
cluster_dir="/tmp/test_jobs/test-ray-cluster",
444+
template_name="ray.sub.j2",
445+
executor=executor,
446+
command_groups=[["cmd0"], ["cmd1"], ["cmd2"]],
447+
launch_cmd=["sbatch", "--parsable"],
448+
)
449+
450+
script = request.materialize()
451+
452+
# Check global env vars are set in setup section
453+
assert "export GLOBAL_ENV=global_value" in script
454+
455+
# Check that command groups generate srun commands (excluding the first one)
456+
# The template should have a section for srun_commands
457+
assert "# Run extra commands" in script
458+
assert "srun" in script
459+
assert "cmd1" in script # First command group after skipping index 0
460+
assert "cmd2" in script # Second command group
461+
462+
def test_command_groups_without_resource_group(self):
463+
"""Test command groups work without resource groups."""
464+
executor = SlurmExecutor(
465+
account="test_account",
466+
env_vars={"GLOBAL_ENV": "global_value"},
467+
)
468+
executor.tunnel = Mock(spec=SSHTunnel)
469+
executor.tunnel.job_dir = "/tmp/test_jobs"
470+
471+
request = SlurmRayRequest(
472+
name="test-ray-cluster",
473+
cluster_dir="/tmp/test_jobs/test-ray-cluster",
474+
template_name="ray.sub.j2",
475+
executor=executor,
476+
command_groups=[["cmd0"], ["cmd1"]],
477+
launch_cmd=["sbatch", "--parsable"],
478+
)
479+
480+
script = request.materialize()
481+
482+
# Should have global env vars
483+
assert "export GLOBAL_ENV=global_value" in script
484+
485+
# Should have srun commands for overlapping groups (skipping first)
486+
assert "srun" in script
487+
assert "--overlap" in script
488+
assert "cmd1" in script # Second command in the list (index 1)
489+
490+
def test_env_vars_formatting(self):
491+
"""Test that environment variables are properly formatted as export statements."""
492+
executor = SlurmExecutor(
493+
account="test_account",
494+
env_vars={
495+
"VAR_WITH_SPACES": "value with spaces",
496+
"PATH_VAR": "/usr/bin:/usr/local/bin",
497+
"EMPTY_VAR": "",
498+
"NUMBER_VAR": "123",
499+
},
500+
)
501+
executor.tunnel = Mock(spec=SSHTunnel)
502+
executor.tunnel.job_dir = "/tmp/test_jobs"
503+
504+
request = SlurmRayRequest(
505+
name="test-ray-cluster",
506+
cluster_dir="/tmp/test_jobs/test-ray-cluster",
507+
template_name="ray.sub.j2",
508+
executor=executor,
509+
launch_cmd=["sbatch", "--parsable"],
510+
)
511+
512+
script = request.materialize()
513+
514+
# Check all environment variables are properly exported
515+
assert "export VAR_WITH_SPACES=value with spaces" in script
516+
assert "export PATH_VAR=/usr/bin:/usr/local/bin" in script
517+
assert "export EMPTY_VAR=" in script
518+
assert "export NUMBER_VAR=123" in script
519+
520+
def test_group_env_vars_integration(self):
521+
"""Test full integration of group environment variables matching the artifact pattern."""
522+
# This test verifies the behavior seen in group_resource_req_slurm.sh
523+
executor = SlurmExecutor(
524+
account="your_account",
525+
partition="your_partition",
526+
time="00:30:00",
527+
nodes=1,
528+
ntasks_per_node=8,
529+
gpus_per_node=8,
530+
container_image="some-image",
531+
container_mounts=["/some/job/dir/sample_job:/nemo_run"],
532+
env_vars={"ENV_VAR": "value"},
533+
)
534+
executor.run_as_group = True
535+
536+
# Set up resource groups with specific env vars
537+
resource_group = [
538+
# First group (index 0) - for the head/main command
539+
SlurmExecutor.ResourceRequest(
540+
packager=Mock(),
541+
nodes=1,
542+
ntasks_per_node=8,
543+
container_image="some-image",
544+
env_vars={"CUSTOM_ENV_1": "some_value_1"},
545+
container_mounts=["/some/job/dir/sample_job:/nemo_run"],
546+
),
547+
# Second group (index 1)
548+
SlurmExecutor.ResourceRequest(
549+
packager=Mock(),
550+
nodes=1,
551+
ntasks_per_node=8,
552+
container_image="different_container_image",
553+
env_vars={"CUSTOM_ENV_1": "some_value_1"},
554+
container_mounts=["/some/job/dir/sample_job:/nemo_run"],
555+
),
556+
]
557+
executor.resource_group = resource_group
558+
559+
# Mock tunnel
560+
tunnel_mock = Mock(spec=SSHTunnel)
561+
tunnel_mock.job_dir = "/some/job/dir"
562+
executor.tunnel = tunnel_mock
563+
564+
request = SlurmRayRequest(
565+
name="sample_job",
566+
cluster_dir="/some/job/dir/sample_job",
567+
template_name="ray.sub.j2",
568+
executor=executor,
569+
command_groups=[
570+
["bash ./scripts/start_server.sh"],
571+
["bash ./scripts/echo.sh server_host=$het_group_host_0"],
572+
],
573+
launch_cmd=["sbatch", "--parsable"],
574+
)
575+
576+
script = request.materialize()
577+
578+
# Verify the pattern matches the artifact:
579+
# 1. Global env vars should be exported in setup
580+
assert "export ENV_VAR=value" in script
581+
582+
# The template should include group_env_vars for proper env var handling per command
583+
# (The actual env var exports per command happen in the template rendering)

0 commit comments — Comments (0)