
Commit e95df71

hemildesai, ryantwolf, and github-advanced-security[bot] authored
Add SlurmRay launcher and transform API for launchers (#159)
* Add SlurmRay launcher and transform API for launchers
  Signed-off-by: Hemil Desai <[email protected]>

* Modify Slurm Launcher to allow arbitrary scripts (#163)
  * Alter launcher script
    Signed-off-by: Ryan Wolf <[email protected]>
  * Refactor to allow arbitrary template
    Signed-off-by: Ryan Wolf <[email protected]>
  * Add custom slurm ray launcher
    Signed-off-by: Ryan Wolf <[email protected]>
  * Remove duplicate script
    Signed-off-by: Ryan Wolf <[email protected]>
  * Fix command being run
    Signed-off-by: Ryan Wolf <[email protected]>
  * Remove curator specific references
    Signed-off-by: Ryan Wolf <[email protected]>
  * Rename template
    Signed-off-by: Ryan Wolf <[email protected]>
  ---------
  Signed-off-by: Ryan Wolf <[email protected]>
  Signed-off-by: Hemil Desai <[email protected]>

* fix spelling action
  Signed-off-by: Hemil Desai <[email protected]>

* fixes
  Signed-off-by: Hemil Desai <[email protected]>

* Potential fix for code scanning alert no. 240: Jinja2 templating with autoescape=False
  Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
  Signed-off-by: Hemil Desai <[email protected]>

* fix
  Signed-off-by: Hemil Desai <[email protected]>

---------

Signed-off-by: Hemil Desai <[email protected]>
Signed-off-by: Ryan Wolf <[email protected]>
Co-authored-by: Ryan Wolf <[email protected]>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
1 parent 4d05653 commit e95df71

File tree: 13 files changed, +748 −87 lines
.github/workflows/config/typos.toml

Lines changed: 4 additions & 0 deletions

@@ -4,3 +4,7 @@ extend-exclude = [
     "test/",
 ]
 ignore-hidden = false
+
+
+[default.extend-words]
+typ = "typ"

.github/workflows/spelling.yml

Lines changed: 1 addition & 1 deletion

@@ -15,4 +15,4 @@ jobs:
       - uses: crate-ci/typos@master
         with:
           files: .
-          config: ./.github/workflows/config/typos.yml
+          config: ./.github/workflows/config/typos.toml

src/nemo_run/__init__.py

Lines changed: 3 additions & 2 deletions

@@ -19,12 +19,11 @@
 from nemo_run.core.execution.base import (
     Executor,
     ExecutorMacros,
-    FaultTolerance,
-    Torchrun,
     import_executor,
 )
 from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
 from nemo_run.core.execution.docker import DockerExecutor
+from nemo_run.core.execution.launcher import FaultTolerance, SlurmRay, SlurmTemplate, Torchrun
 from nemo_run.core.execution.local import LocalExecutor
 from nemo_run.core.execution.skypilot import SkypilotExecutor
 from nemo_run.core.execution.slurm import SlurmExecutor
@@ -69,6 +68,8 @@
     "SlurmExecutor",
     "SSHTunnel",
     "Torchrun",
+    "SlurmRay",
+    "SlurmTemplate",
 ]

 try:

src/nemo_run/core/execution/base.py

Lines changed: 5 additions & 52 deletions

@@ -18,67 +18,17 @@
 import os
 from dataclasses import asdict, dataclass, field
 from string import Template
-from typing import Optional, Protocol, Type, Union, runtime_checkable
+from typing import Optional, Protocol, Union, runtime_checkable

 import fiddle as fdl
 from torchx.specs import Role
 from typing_extensions import Self

 from nemo_run.config import NEMORUN_HOME, ConfigurableMixin
+from nemo_run.core.execution.launcher import LAUNCHER_MAP, Launcher
 from nemo_run.core.packaging.base import Packager


-@dataclass(kw_only=True)
-class Launcher(ConfigurableMixin):
-    nsys_profile: bool = False
-    nsys_folder: str = "nsys_profile"
-    nsys_trace: list[str] = field(default_factory=lambda: ["nvtx", "cuda"])
-
-    def get_nsys_prefix(self, profile_dir: str) -> Optional[list[str]]:
-        """Make a command prefix for nsys profiling"""
-        if self.nsys_profile:
-            profile_out_path = os.path.join(profile_dir, self.nsys_folder)
-            args = [
-                "profile",
-                "-s",
-                "none",
-                "-t",
-                ",".join(self.nsys_trace),
-                "-o",
-                f"{profile_out_path}/profile_%p",
-                "--force-overwrite",
-                "true",
-                "--capture-range=cudaProfilerApi",
-                "--capture-range-end=stop",
-                "--cuda-graph-trace=node",
-            ]
-            return args
-
-
-@dataclass(kw_only=True)
-class Torchrun(Launcher):
-    rdzv_backend: str = "c10d"
-    rdzv_port: int = 29500
-
-
-@dataclass(kw_only=True)
-class FaultTolerance(Launcher):
-    cfg_path: str = ""
-    finished_flag_file: str = ""
-    job_results_file: str = ""
-    rdzv_backend: str = "c10d"
-    rdzv_port: int = 29500
-    workload_check_interval: Optional[float] = None
-    initial_rank_heartbeat_timeout: Optional[float] = None
-    rank_heartbeat_timeout: Optional[float] = None
-    rank_termination_signal: Optional[str] = None
-    log_level: Optional[str] = None
-    max_restarts: Optional[int] = None
-
-
-LAUNCHER_MAP: dict[str, Type[Launcher]] = {"torchrun": Torchrun, "ft": FaultTolerance}
-
-
 @dataclass(kw_only=True)
 class ExecutorMacros(ConfigurableMixin):
     """
@@ -215,6 +165,9 @@ def get_launcher_prefix(self) -> Optional[list[str]]:
             os.makedirs(os.path.join(self.job_dir, launcher.nsys_folder), exist_ok=True)
             return launcher.get_nsys_prefix(profile_dir=self.job_dir)

+    def supports_launcher_transform(self) -> bool:
+        return False
+
     def package_configs(self, *cfgs: tuple[str, str]) -> list[str]:
         filenames = []
         basepath = os.path.join(self.job_dir, "configs")
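For reference, the nsys prefix logic relocated by this hunk is exercised as in the minimal sketch below, assuming the launcher classes are imported from the new nemo_run.core.execution.launcher module shown in the next file. The profile directory path is made up for illustration.

from nemo_run.core.execution.launcher import Torchrun

# Enable nsys profiling on any Launcher subclass; the prefix is only
# produced when nsys_profile is True.
launcher = Torchrun(nsys_profile=True, nsys_trace=["nvtx", "cuda"])
prefix = launcher.get_nsys_prefix(profile_dir="/tmp/job")
# prefix -> ["profile", "-s", "none", "-t", "nvtx,cuda", "-o",
#            "/tmp/job/nsys_profile/profile_%p", "--force-overwrite", "true",
#            "--capture-range=cudaProfilerApi", "--capture-range-end=stop",
#            "--cuda-graph-trace=node"]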
src/nemo_run/core/execution/launcher.py

Lines changed: 184 additions & 0 deletions

@@ -0,0 +1,184 @@
+import os
+import pathlib
+from dataclasses import dataclass, field
+from typing import Optional, Type
+
+import jinja2
+
+from nemo_run.config import ConfigurableMixin, Script
+from nemo_run.core.execution.utils import fill_template
+
+
+@dataclass(kw_only=True)
+class Launcher(ConfigurableMixin):
+    nsys_profile: bool = False
+    nsys_folder: str = "nsys_profile"
+    nsys_trace: list[str] = field(default_factory=lambda: ["nvtx", "cuda"])
+
+    def get_nsys_prefix(self, profile_dir: str) -> Optional[list[str]]:
+        """Make a command prefix for nsys profiling"""
+        if self.nsys_profile:
+            profile_out_path = os.path.join(profile_dir, self.nsys_folder)
+            args = [
+                "profile",
+                "-s",
+                "none",
+                "-t",
+                ",".join(self.nsys_trace),
+                "-o",
+                f"{profile_out_path}/profile_%p",
+                "--force-overwrite",
+                "true",
+                "--capture-range=cudaProfilerApi",
+                "--capture-range-end=stop",
+                "--cuda-graph-trace=node",
+            ]
+            return args
+
+    def transform(self, cmd: list[str]) -> Optional[Script]: ...
+
+
+@dataclass(kw_only=True)
+class Torchrun(Launcher):
+    rdzv_backend: str = "c10d"
+    rdzv_port: int = 29500
+
+
+@dataclass(kw_only=True)
+class FaultTolerance(Launcher):
+    cfg_path: str = ""
+    finished_flag_file: str = ""
+    job_results_file: str = ""
+    rdzv_backend: str = "c10d"
+    rdzv_port: int = 29500
+    workload_check_interval: Optional[float] = None
+    initial_rank_heartbeat_timeout: Optional[float] = None
+    rank_heartbeat_timeout: Optional[float] = None
+    rank_termination_signal: Optional[str] = None
+    log_level: Optional[str] = None
+    max_restarts: Optional[int] = None
+
+
+@dataclass(kw_only=True)
+class SlurmTemplate(Launcher):
+    """
+    A generic launcher that uses Jinja2 templates to wrap commands.
+    The template can be provided either as inline content or as a path to a template file.
+    """
+
+    template_path: Optional[str] = None
+    template_inline: Optional[str] = None
+    template_vars: dict = field(default_factory=dict)
+
+    def __post_init__(self):
+        # Ensure at least one template source is provided
+        if not self.template_path and not self.template_inline:
+            raise ValueError("Either template_path or template_inline must be provided")
+
+    def get_template_content(self) -> str:
+        """
+        Get the template content either from the file or inline content.
+        """
+        if self.template_inline:
+            return self.template_inline
+
+        if self.template_path:
+            # Check if the path is absolute
+            path = pathlib.Path(self.template_path)
+            if path.is_absolute():
+                # Read the template from the absolute path
+                with open(path, "r") as f:
+                    return f.read()
+            else:
+                # Use the template from the templates directory
+                template_dir = os.path.join(os.path.dirname(__file__), "templates")
+                template_path = os.path.join(template_dir, self.template_path)
+                if os.path.exists(template_path):
+                    with open(template_path, "r") as f:
+                        return f.read()
+                else:
+                    raise FileNotFoundError(f'Template "{self.template_path}" does not exist.')
+
+        # This should not happen due to the check in __post_init__
+        raise ValueError("No template available")
+
+    def render_template(self, cmd: list[str]) -> str:
+        """
+        Render the template with the command and additional variables.
+        """
+        # If using a template file from the templates directory
+        if self.template_path and not os.path.isabs(self.template_path):
+            # Create variables dictionary with command and additional variables
+            vars_dict = {"command": " ".join(cmd), **self.template_vars}
+            # Use the project's template rendering utility
+            return fill_template(self.template_path, vars_dict)
+
+        # If using inline template or absolute path template
+        template_content = self.get_template_content()
+        env = jinja2.Environment(autoescape=jinja2.select_autoescape(["html", "xml"]))
+        template = env.from_string(template_content)
+
+        # Create variables dictionary with command and additional variables
+        vars_dict = {"command": " ".join(cmd), **self.template_vars}
+
+        # Render the template
+        return template.render(**vars_dict)
+
+    def transform(self, cmd: list[str]) -> Optional[Script]:
+        """
+        Transform the command using the template.
+        """
+        rendered_script = self.render_template(cmd)
+        return Script(inline=rendered_script)
+
+
+@dataclass(kw_only=True)
+class SlurmRay(SlurmTemplate):
+    """
+    Transforms a provided cmd into a Ray launcher bash script for SlurmExecutor.
+    The Ray launcher script sets up a Ray cluster on Slurm nodes, with the head node starting Ray head
+    and executing the provided command. Worker nodes start Ray and wait.
+    """
+
+    gcs_server_port: int = 6379
+    dashboard_port: int = 8265
+    object_manager_port: int = 8076
+    node_manager_port: int = 8077
+    dashboard_agent_port: int = 52365
+    dashboard_agent_grpc_port: int = 52366
+    metrics_port: int = 9002
+    display_nvidia_smi_output: bool = False
+    head_setup: Optional[str] = None
+    head_init_wait_time: int = 10
+    worker_init_wait_time: int = 60
+    env_vars: Optional[dict] = None
+
+    def __post_init__(self):
+        # Set the template path to the Ray template
+        self.template_path = "slurm_ray.sh.j2"
+        # Fill in the template variables
+        self.template_vars["gcs_server_port"] = self.gcs_server_port
+        self.template_vars["dashboard_port"] = self.dashboard_port
+        self.template_vars["object_manager_port"] = self.object_manager_port
+        self.template_vars["node_manager_port"] = self.node_manager_port
+        self.template_vars["dashboard_agent_port"] = self.dashboard_agent_port
+        self.template_vars["dashboard_agent_grpc_port"] = self.dashboard_agent_grpc_port
+        self.template_vars["metrics_port"] = self.metrics_port
+        self.template_vars["display_nvidia_smi_output"] = self.display_nvidia_smi_output
+        self.template_vars["head_setup"] = self.head_setup
+        self.template_vars["head_init_wait_time"] = self.head_init_wait_time
+        self.template_vars["worker_init_wait_time"] = self.worker_init_wait_time
+        if self.env_vars:
+            self.template_vars["env_vars"] = "\n".join(
+                [f'export {k}="{v}"' for k, v in self.env_vars.items()]
+            )
+        # Call parent's post_init
+        super().__post_init__()
+
+
+LAUNCHER_MAP: dict[str, Type[Launcher]] = {
+    "torchrun": Torchrun,
+    "ft": FaultTolerance,
+    "slurm_ray": SlurmRay,
+    "slurm_template": SlurmTemplate,
+}
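A short usage sketch of the transform API added in this file. The class names, fields, and Script wrapping come from the diff above; the inline Jinja2 template text and the command are made up for illustration.

from nemo_run.core.execution.launcher import SlurmTemplate

# Hypothetical inline template; {{ command }} is filled with the
# space-joined command passed to transform().
template = SlurmTemplate(
    template_inline="#!/bin/bash\nset -e\n{{ command }}\n",
    template_vars={},
)

script = template.transform(["python", "train.py", "--epochs", "1"])
# transform() renders the template and wraps it in a Script, so
# script.inline holds the generated bash with the command substituted.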

src/nemo_run/core/execution/skypilot.py

Lines changed: 1 addition & 2 deletions

@@ -26,9 +26,8 @@
 from nemo_run.core.execution.base import (
     Executor,
     ExecutorMacros,
-    FaultTolerance,
-    Torchrun,
 )
+from nemo_run.core.execution.launcher import FaultTolerance, Torchrun
 from nemo_run.core.packaging.base import Packager
 from nemo_run.core.packaging.git import GitArchivePackager

src/nemo_run/core/execution/slurm.py

Lines changed: 7 additions & 4 deletions

@@ -33,10 +33,8 @@
 from nemo_run.core.execution.base import (
     Executor,
     ExecutorMacros,
-    FaultTolerance,
-    Launcher,
-    Torchrun,
 )
+from nemo_run.core.execution.launcher import FaultTolerance, Launcher, SlurmTemplate, Torchrun
 from nemo_run.core.execution.utils import fill_template
 from nemo_run.core.frontend.console.api import CONSOLE
 from nemo_run.core.packaging.base import Packager
@@ -544,6 +542,9 @@ def get_launcher_prefix(self) -> Optional[list[str]]:
         if launcher.nsys_profile:
             return launcher.get_nsys_prefix(profile_dir=f"/{RUNDIR_NAME}")

+    def supports_launcher_transform(self) -> bool:
+        return True if isinstance(self.get_launcher(), SlurmTemplate) else False
+
     def package_configs(self, *cfgs: tuple[str, str]) -> list[str]:
         filenames = []
         basepath = os.path.join(self.job_dir, "configs")
@@ -825,7 +826,9 @@ def materialize(self) -> str:

         sbatch_flags = []
         if self.slurm_config.heterogeneous:
-            assert len(self.jobs) == len(self.slurm_config.resource_group)
+            assert (
+                len(self.jobs) == len(self.slurm_config.resource_group)
+            ), f"Number of jobs {len(self.jobs)} must match number of resource group requests {len(self.slurm_config.resource_group)}.\nIf you are just submitting a single job, make sure that heterogeneous=False in the executor."
             final_group_index = len(self.slurm_config.resource_group) - 1
             if self.slurm_config.het_group_indices:
                 final_group_index = self.slurm_config.het_group_indices.index(
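A hedged sketch of how the new supports_launcher_transform hook might be used with the SlurmRay launcher. The SlurmExecutor constructor arguments (account, partition, nodes, ntasks_per_node), the launcher attribute assignment, and the env_vars value are assumptions for illustration; only SlurmRay, SlurmTemplate, and supports_launcher_transform come from this diff.

from nemo_run.core.execution.launcher import SlurmRay
from nemo_run.core.execution.slurm import SlurmExecutor

# Assumed constructor arguments; check SlurmExecutor's actual signature
# for the fields your cluster needs.
executor = SlurmExecutor(account="my_account", partition="gpu", nodes=2, ntasks_per_node=8)

# Attach the Ray launcher (assuming the executor exposes a `launcher`
# field, as implied by get_launcher() in the diff above).
executor.launcher = SlurmRay(head_init_wait_time=10, env_vars={"RAY_DEDUP_LOGS": "0"})

# Only SlurmTemplate-based launchers (SlurmRay included) opt in to the
# command transform on this executor.
assert executor.supports_launcher_transform()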

0 commit comments