|
1 | 1 | import os |
| 2 | +import pathlib |
2 | 3 | from dataclasses import dataclass, field |
3 | 4 | from typing import Optional, Type |
4 | 5 |
|
| 6 | +import jinja2 |
| 7 | + |
5 | 8 | from nemo_run.config import ConfigurableMixin, Script |
| 9 | +from nemo_run.core.execution.utils import fill_template |
6 | 10 |
|
7 | 11 |
|
8 | 12 | @dataclass(kw_only=True) |
@@ -56,73 +60,124 @@ class FaultTolerance(Launcher): |
56 | 60 |
|
57 | 61 |
|
@dataclass(kw_only=True)
class SlurmTemplate(Launcher):
    """
    A generic launcher that uses Jinja2 templates to wrap commands.

    The template can be provided either as inline content (``template_inline``)
    or as a path to a template file (``template_path``). An absolute path is read
    directly; a relative path names a file in the ``templates`` directory located
    next to this module.

    Templates receive ``command`` (the command parts joined with single spaces)
    plus every entry of ``template_vars``. An entry named ``command`` in
    ``template_vars`` overrides the joined command.
    """

    # Template file: an absolute path, or a name relative to the templates directory.
    template_path: Optional[str] = None
    # Template source supplied directly as a string.
    template_inline: Optional[str] = None
    # Extra variables made available to the template alongside ``command``.
    template_vars: dict = field(default_factory=dict)

    def __post_init__(self):
        """
        Validate that a template source was configured.

        Raises:
            ValueError: If both ``template_path`` and ``template_inline`` are empty.
        """
        if not self.template_path and not self.template_inline:
            raise ValueError("Either template_path or template_inline must be provided")

    def _template_context(self, cmd: list[str]) -> dict:
        """Build the variables passed to the template for ``cmd``."""
        # template_vars is unpacked last, so it wins on key collisions.
        return {"command": " ".join(cmd), **self.template_vars}

    def get_template_content(self) -> str:
        """
        Get the template content either from the inline content or from the file.

        ``template_inline`` takes precedence over ``template_path``.

        Returns:
            The raw, unrendered template text.

        Raises:
            FileNotFoundError: If ``template_path`` does not point to an existing file.
            ValueError: If no template source is configured.
        """
        if self.template_inline:
            return self.template_inline

        if self.template_path:
            path = pathlib.Path(self.template_path)
            if path.is_absolute():
                # Explicit encoding keeps decoding independent of the host locale.
                return path.read_text(encoding="utf-8")

            bundled = pathlib.Path(os.path.dirname(__file__)) / "templates" / self.template_path
            try:
                # Open directly instead of checking existence first, which
                # avoids a check-then-use race on the file.
                return bundled.read_text(encoding="utf-8")
            except FileNotFoundError as err:
                raise FileNotFoundError(
                    f'Template "{self.template_path}" does not exist.'
                ) from err

        # Reachable only if the attributes were cleared after __post_init__ ran.
        raise ValueError("No template available")

    def render_template(self, cmd: list[str]) -> str:
        """
        Render the template with the command and additional variables.

        Args:
            cmd: Command parts; they are joined with single spaces and exposed
                to the template as ``command``.

        Returns:
            The rendered template text.
        """
        context = self._template_context(cmd)

        # NOTE(review): a relative template_path is rendered here even when
        # template_inline is also set, whereas get_template_content() prefers
        # template_inline. Kept as-is for backward compatibility; confirm which
        # precedence is intended.
        if self.template_path and not os.path.isabs(self.template_path):
            # Relative names are delegated to the project's template rendering utility.
            return fill_template(self.template_path, context)

        # Inline content or an absolute path: render the text directly with Jinja2.
        template = jinja2.Template(self.get_template_content())
        return template.render(**context)

    def transform(self, cmd: list[str]) -> Optional[Script]:
        """
        Transform the command using the template.

        Args:
            cmd: Command parts to wrap.

        Returns:
            A ``Script`` whose inline content is the rendered template.
        """
        return Script(inline=self.render_template(cmd))
| 132 | + |
| 133 | + |
@dataclass(kw_only=True)
class SlurmRay(SlurmTemplate):
    """
    Wraps a provided cmd in a Ray launcher bash script for SlurmExecutor.

    The script is rendered from the ``slurm_ray.sh.j2`` template. It sets up a Ray
    cluster on Slurm nodes: the head node starts the Ray head and executes the
    provided command, while worker nodes start Ray and wait.

    Every field below except ``env_vars`` is forwarded to the template under its
    own name. ``env_vars`` is forwarded as newline-separated ``export`` lines.
    """

    gcs_server_port: int = 6379
    dashboard_port: int = 8265
    object_manager_port: int = 8076
    node_manager_port: int = 8077
    dashboard_agent_port: int = 52365
    dashboard_agent_grpc_port: int = 52366
    metrics_port: int = 9002
    display_nvidia_smi_output: bool = False
    head_setup: Optional[str] = None
    head_init_wait_time: int = 10
    worker_init_wait_time: int = 60
    env_vars: Optional[dict] = None

    def __post_init__(self):
        """Select the Ray template and publish this launcher's settings to it."""
        # The Ray launcher always uses its own template; any caller-supplied
        # template_path is replaced.
        self.template_path = "slurm_ray.sh.j2"

        # Copy each setting into template_vars under the same name.
        forwarded = (
            "gcs_server_port",
            "dashboard_port",
            "object_manager_port",
            "node_manager_port",
            "dashboard_agent_port",
            "dashboard_agent_grpc_port",
            "metrics_port",
            "display_nvidia_smi_output",
            "head_setup",
            "head_init_wait_time",
            "worker_init_wait_time",
        )
        for name in forwarded:
            self.template_vars[name] = getattr(self, name)

        # env_vars is only published when non-empty, as one export line per entry.
        if self.env_vars:
            export_lines = [f'export {k}="{v}"' for k, v in self.env_vars.items()]
            self.template_vars["env_vars"] = "\n".join(export_lines)

        # Run the parent's validation now that template_path is set.
        super().__post_init__()
122 | 176 |
|
123 | 177 |
|
# Registry mapping launcher name strings to their launcher classes.
LAUNCHER_MAP: dict[str, Type[Launcher]] = {
    "torchrun": Torchrun,
    "ft": FaultTolerance,
    "slurm_ray": SlurmRay,
    "slurm_template": SlurmTemplate,
}
0 commit comments