|
| 1 | +import os |
| 2 | +import pathlib |
| 3 | +from dataclasses import dataclass, field |
| 4 | +from typing import Optional, Type |
| 5 | + |
| 6 | +import jinja2 |
| 7 | + |
| 8 | +from nemo_run.config import ConfigurableMixin, Script |
| 9 | +from nemo_run.core.execution.utils import fill_template |
| 10 | + |
| 11 | + |
| 12 | +@dataclass(kw_only=True) |
| 13 | +class Launcher(ConfigurableMixin): |
| 14 | + nsys_profile: bool = False |
| 15 | + nsys_folder: str = "nsys_profile" |
| 16 | + nsys_trace: list[str] = field(default_factory=lambda: ["nvtx", "cuda"]) |
| 17 | + |
| 18 | + def get_nsys_prefix(self, profile_dir: str) -> Optional[list[str]]: |
| 19 | + """Make a command prefix for nsys profiling""" |
| 20 | + if self.nsys_profile: |
| 21 | + profile_out_path = os.path.join(profile_dir, self.nsys_folder) |
| 22 | + args = [ |
| 23 | + "profile", |
| 24 | + "-s", |
| 25 | + "none", |
| 26 | + "-t", |
| 27 | + ",".join(self.nsys_trace), |
| 28 | + "-o", |
| 29 | + f"{profile_out_path}/profile_%p", |
| 30 | + "--force-overwrite", |
| 31 | + "true", |
| 32 | + "--capture-range=cudaProfilerApi", |
| 33 | + "--capture-range-end=stop", |
| 34 | + "--cuda-graph-trace=node", |
| 35 | + ] |
| 36 | + return args |
| 37 | + |
| 38 | + def transform(self, cmd: list[str]) -> Optional[Script]: ... |
| 39 | + |
| 40 | + |
| 41 | +@dataclass(kw_only=True) |
| 42 | +class Torchrun(Launcher): |
| 43 | + rdzv_backend: str = "c10d" |
| 44 | + rdzv_port: int = 29500 |
| 45 | + |
| 46 | + |
| 47 | +@dataclass(kw_only=True) |
| 48 | +class FaultTolerance(Launcher): |
| 49 | + cfg_path: str = "" |
| 50 | + finished_flag_file: str = "" |
| 51 | + job_results_file: str = "" |
| 52 | + rdzv_backend: str = "c10d" |
| 53 | + rdzv_port: int = 29500 |
| 54 | + workload_check_interval: Optional[float] = None |
| 55 | + initial_rank_heartbeat_timeout: Optional[float] = None |
| 56 | + rank_heartbeat_timeout: Optional[float] = None |
| 57 | + rank_termination_signal: Optional[str] = None |
| 58 | + log_level: Optional[str] = None |
| 59 | + max_restarts: Optional[int] = None |
| 60 | + |
| 61 | + |
| 62 | +@dataclass(kw_only=True) |
| 63 | +class SlurmTemplate(Launcher): |
| 64 | + """ |
| 65 | + A generic launcher that uses Jinja2 templates to wrap commands. |
| 66 | + The template can be provided either as inline content or as a path to a template file. |
| 67 | + """ |
| 68 | + |
| 69 | + template_path: Optional[str] = None |
| 70 | + template_inline: Optional[str] = None |
| 71 | + template_vars: dict = field(default_factory=dict) |
| 72 | + |
| 73 | + def __post_init__(self): |
| 74 | + # Ensure at least one template source is provided |
| 75 | + if not self.template_path and not self.template_inline: |
| 76 | + raise ValueError("Either template_path or template_inline must be provided") |
| 77 | + |
| 78 | + def get_template_content(self) -> str: |
| 79 | + """ |
| 80 | + Get the template content either from the file or inline content. |
| 81 | + """ |
| 82 | + if self.template_inline: |
| 83 | + return self.template_inline |
| 84 | + |
| 85 | + if self.template_path: |
| 86 | + # Check if the path is absolute |
| 87 | + path = pathlib.Path(self.template_path) |
| 88 | + if path.is_absolute(): |
| 89 | + # Read the template from the absolute path |
| 90 | + with open(path, "r") as f: |
| 91 | + return f.read() |
| 92 | + else: |
| 93 | + # Use the template from the templates directory |
| 94 | + template_dir = os.path.join(os.path.dirname(__file__), "templates") |
| 95 | + template_path = os.path.join(template_dir, self.template_path) |
| 96 | + if os.path.exists(template_path): |
| 97 | + with open(template_path, "r") as f: |
| 98 | + return f.read() |
| 99 | + else: |
| 100 | + raise FileNotFoundError(f'Template "{self.template_path}" does not exist.') |
| 101 | + |
| 102 | + # This should not happen due to the check in __post_init__ |
| 103 | + raise ValueError("No template available") |
| 104 | + |
| 105 | + def render_template(self, cmd: list[str]) -> str: |
| 106 | + """ |
| 107 | + Render the template with the command and additional variables. |
| 108 | + """ |
| 109 | + # If using a template file from the templates directory |
| 110 | + if self.template_path and not os.path.isabs(self.template_path): |
| 111 | + # Create variables dictionary with command and additional variables |
| 112 | + vars_dict = {"command": " ".join(cmd), **self.template_vars} |
| 113 | + # Use the project's template rendering utility |
| 114 | + return fill_template(self.template_path, vars_dict) |
| 115 | + |
| 116 | + # If using inline template or absolute path template |
| 117 | + template_content = self.get_template_content() |
| 118 | + env = jinja2.Environment(autoescape=jinja2.select_autoescape(["html", "xml"])) |
| 119 | + template = env.from_string(template_content) |
| 120 | + |
| 121 | + # Create variables dictionary with command and additional variables |
| 122 | + vars_dict = {"command": " ".join(cmd), **self.template_vars} |
| 123 | + |
| 124 | + # Render the template |
| 125 | + return template.render(**vars_dict) |
| 126 | + |
| 127 | + def transform(self, cmd: list[str]) -> Optional[Script]: |
| 128 | + """ |
| 129 | + Transform the command using the template. |
| 130 | + """ |
| 131 | + rendered_script = self.render_template(cmd) |
| 132 | + return Script(inline=rendered_script) |
| 133 | + |
| 134 | + |
| 135 | +@dataclass(kw_only=True) |
| 136 | +class SlurmRay(SlurmTemplate): |
| 137 | + """ |
| 138 | + Transforms a provided cmd into a Ray launcher bash script for SlurmExecutor. |
| 139 | + The Ray launcher script sets up a Ray cluster on Slurm nodes, with the head node starting Ray head |
| 140 | + and executing the provided command. Worker nodes start Ray and wait. |
| 141 | + """ |
| 142 | + |
| 143 | + gcs_server_port: int = 6379 |
| 144 | + dashboard_port: int = 8265 |
| 145 | + object_manager_port: int = 8076 |
| 146 | + node_manager_port: int = 8077 |
| 147 | + dashboard_agent_port: int = 52365 |
| 148 | + dashboard_agent_grpc_port: int = 52366 |
| 149 | + metrics_port: int = 9002 |
| 150 | + display_nvidia_smi_output: bool = False |
| 151 | + head_setup: Optional[str] = None |
| 152 | + head_init_wait_time: int = 10 |
| 153 | + worker_init_wait_time: int = 60 |
| 154 | + env_vars: Optional[dict] = None |
| 155 | + |
| 156 | + def __post_init__(self): |
| 157 | + # Set the template path to the Ray template |
| 158 | + self.template_path = "slurm_ray.sh.j2" |
| 159 | + # Fill in the template variables |
| 160 | + self.template_vars["gcs_server_port"] = self.gcs_server_port |
| 161 | + self.template_vars["dashboard_port"] = self.dashboard_port |
| 162 | + self.template_vars["object_manager_port"] = self.object_manager_port |
| 163 | + self.template_vars["node_manager_port"] = self.node_manager_port |
| 164 | + self.template_vars["dashboard_agent_port"] = self.dashboard_agent_port |
| 165 | + self.template_vars["dashboard_agent_grpc_port"] = self.dashboard_agent_grpc_port |
| 166 | + self.template_vars["metrics_port"] = self.metrics_port |
| 167 | + self.template_vars["display_nvidia_smi_output"] = self.display_nvidia_smi_output |
| 168 | + self.template_vars["head_setup"] = self.head_setup |
| 169 | + self.template_vars["head_init_wait_time"] = self.head_init_wait_time |
| 170 | + self.template_vars["worker_init_wait_time"] = self.worker_init_wait_time |
| 171 | + if self.env_vars: |
| 172 | + self.template_vars["env_vars"] = "\n".join( |
| 173 | + [f'export {k}="{v}"' for k, v in self.env_vars.items()] |
| 174 | + ) |
| 175 | + # Call parent's post_init |
| 176 | + super().__post_init__() |
| 177 | + |
| 178 | + |
| 179 | +LAUNCHER_MAP: dict[str, Type[Launcher]] = { |
| 180 | + "torchrun": Torchrun, |
| 181 | + "ft": FaultTolerance, |
| 182 | + "slurm_ray": SlurmRay, |
| 183 | + "slurm_template": SlurmTemplate, |
| 184 | +} |
0 commit comments