
Commit d20769a

Modify Slurm Launcher to allow arbitrary scripts (#163)
* Alter launcher script
* Refactor to allow arbitrary template
* Add custom slurm ray launcher
* Remove duplicate script
* Fix command being run
* Remove curator specific references
* Rename template

---------

Signed-off-by: Ryan Wolf <[email protected]>
1 parent 39b021c commit d20769a

File tree

2 files changed: +266 −56 lines changed

Lines changed: 111 additions & 56 deletions
@@ -1,8 +1,12 @@
 import os
+import pathlib
 from dataclasses import dataclass, field
 from typing import Optional, Type
 
+import jinja2
+
 from nemo_run.config import ConfigurableMixin, Script
+from nemo_run.core.execution.utils import fill_template
 
 
 @dataclass(kw_only=True)
@@ -56,73 +60,124 @@ class FaultTolerance(Launcher):
 
 
 @dataclass(kw_only=True)
-class SlurmRay(Launcher):
+class SlurmTemplate(Launcher):
     """
-    Transforms a provided cmd into a Ray launcher bash script for SlurmExecutor.
-    The Ray launcher script sets up a Ray cluster on Slurm nodes, with the head node starting Ray head
-    and executing the provided command. Worker nodes start Ray and wait.
+    A generic launcher that uses Jinja2 templates to wrap commands.
+    The template can be provided either as inline content or as a path to a template file.
     """
 
-    port: int = 6379
+    template_path: Optional[str] = None
+    template_inline: Optional[str] = None
+    template_vars: dict = field(default_factory=dict)
+
+    def __post_init__(self):
+        # Ensure at least one template source is provided
+        if not self.template_path and not self.template_inline:
+            raise ValueError("Either template_path or template_inline must be provided")
+
+    def get_template_content(self) -> str:
+        """
+        Get the template content either from the file or inline content.
+        """
+        if self.template_inline:
+            return self.template_inline
+
+        if self.template_path:
+            # Check if the path is absolute
+            path = pathlib.Path(self.template_path)
+            if path.is_absolute():
+                # Read the template from the absolute path
+                with open(path, "r") as f:
+                    return f.read()
+            else:
+                # Use the template from the templates directory
+                template_dir = os.path.join(os.path.dirname(__file__), "templates")
+                template_path = os.path.join(template_dir, self.template_path)
+                if os.path.exists(template_path):
+                    with open(template_path, "r") as f:
+                        return f.read()
+                else:
+                    raise FileNotFoundError(f'Template "{self.template_path}" does not exist.')
+
+        # This should not happen due to the check in __post_init__
+        raise ValueError("No template available")
+
+    def render_template(self, cmd: list[str]) -> str:
+        """
+        Render the template with the command and additional variables.
+        """
+        # If using a template file from the templates directory
+        if self.template_path and not os.path.isabs(self.template_path):
+            # Create variables dictionary with command and additional variables
+            vars_dict = {"command": " ".join(cmd), **self.template_vars}
+            # Use the project's template rendering utility
+            return fill_template(self.template_path, vars_dict)
+
+        # If using inline template or absolute path template
+        template_content = self.get_template_content()
+        template = jinja2.Template(template_content)
+
+        # Create variables dictionary with command and additional variables
+        vars_dict = {"command": " ".join(cmd), **self.template_vars}
+
+        # Render the template
+        return template.render(**vars_dict)
 
     def transform(self, cmd: list[str]) -> Optional[Script]:
         """
-        Transforms the provided cmd into a Ray launcher bash script for SlurmExecutor.
+        Transform the command using the template.
         """
-        cmd_to_run = " ".join(cmd)
-        # Build the Ray launcher bash script. Braces in shell variables are escaped as {{ and }}
-        ray_script = f"""
-# Check that a command was provided.
-if [ "$#" -lt 1 ]; then
-    echo "Usage: $0 <command>"
-    exit 1
-fi
-
-# Function to start the Ray head node.
-start_head() {{
-    echo "Starting Ray head node on ${{HEAD_IP}}"
-    ray start --head --node-ip-address=${{HEAD_IP}} --port={self.port}
-    export RAY_ADDRESS="${{HEAD_IP}}:{self.port}"
-}}
-
-# Function to start a Ray worker node.
-start_worker() {{
-    # Obtain the head node's hostname from the SLURM_NODELIST.
-    echo "Starting Ray worker node. Connecting to head ${{HEAD_IP}}"
-    ray start --address=${{HEAD_IP}}:{self.port}
-}}
-
-# If this is the head node, start the head; otherwise, start a worker.
-if [ -z "$SLURM_NODEID" ] || [ "$SLURM_NODEID" == "0" ]; then
-    start_head
-else
-    start_worker
-fi
-
-# Only the head node executes the command.
-if [ -z "$SLURM_NODEID" ] || [ "$SLURM_NODEID" == "0" ]; then
-    echo "Running command: {cmd_to_run}"
-    # Use eval so the given command is executed with its arguments.
-    eval "{cmd_to_run}"
-    echo "Command finished. Shutting down Ray on head node."
-    ray stop
-    # Optionally, you could touch a file to signal the worker nodes to shut down.
-fi
-
-# For worker nodes, simply wait so that Ray stays active.
-if [ -n "$SLURM_NODEID" ] && [ "$SLURM_NODEID" != "0" ]; then
-    echo "Worker node running. Waiting for the Ray head to finish."
-    while true; do
-        sleep 15
-    done
-fi
-"""
-        # Return a new Script object with the inline content
-        return Script(inline=ray_script)
+        rendered_script = self.render_template(cmd)
+        return Script(inline=rendered_script)
+
+
+@dataclass(kw_only=True)
+class SlurmRay(SlurmTemplate):
+    """
+    Transforms a provided cmd into a Ray launcher bash script for SlurmExecutor.
+    The Ray launcher script sets up a Ray cluster on Slurm nodes, with the head node starting Ray head
+    and executing the provided command. Worker nodes start Ray and wait.
+    """
+
+    gcs_server_port: int = 6379
+    dashboard_port: int = 8265
+    object_manager_port: int = 8076
+    node_manager_port: int = 8077
+    dashboard_agent_port: int = 52365
+    dashboard_agent_grpc_port: int = 52366
+    metrics_port: int = 9002
+    display_nvidia_smi_output: bool = False
+    head_setup: Optional[str] = None
+    head_init_wait_time: int = 10
+    worker_init_wait_time: int = 60
+    env_vars: Optional[dict] = None
+
+    def __post_init__(self):
+        # Set the template path to the Ray template
+        self.template_path = "slurm_ray.sh.j2"
+        # Fill in the template variables
+        self.template_vars["gcs_server_port"] = self.gcs_server_port
+        self.template_vars["dashboard_port"] = self.dashboard_port
+        self.template_vars["object_manager_port"] = self.object_manager_port
+        self.template_vars["node_manager_port"] = self.node_manager_port
+        self.template_vars["dashboard_agent_port"] = self.dashboard_agent_port
+        self.template_vars["dashboard_agent_grpc_port"] = self.dashboard_agent_grpc_port
+        self.template_vars["metrics_port"] = self.metrics_port
+        self.template_vars["display_nvidia_smi_output"] = self.display_nvidia_smi_output
+        self.template_vars["head_setup"] = self.head_setup
+        self.template_vars["head_init_wait_time"] = self.head_init_wait_time
+        self.template_vars["worker_init_wait_time"] = self.worker_init_wait_time
+        if self.env_vars:
+            self.template_vars["env_vars"] = "\n".join(
+                [f'export {k}="{v}"' for k, v in self.env_vars.items()]
+            )
+        # Call parent's post_init
+        super().__post_init__()
 
 
 LAUNCHER_MAP: dict[str, Type[Launcher]] = {
     "torchrun": Torchrun,
     "ft": FaultTolerance,
     "slurm_ray": SlurmRay,
+    "slurm_template": SlurmTemplate,
 }
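
For orientation, here is a minimal usage sketch (not part of the commit) showing how the two launchers above could be driven directly through transform(). The import path, the example command, and the prologue variable are assumptions for illustration, and it presumes the base Launcher fields all have defaults; in practice the launcher is attached to a SlurmExecutor rather than called by hand.

# Hypothetical usage sketch; module path is assumed, not named in the diff.
from nemo_run.core.execution.launcher import SlurmRay, SlurmTemplate

# Wrap a command with a custom inline Jinja2 template. {{ command }} is always
# injected by render_template(); {{ prologue }} comes from template_vars.
template_launcher = SlurmTemplate(
    template_inline="#!/bin/bash\nset -e\n{{ prologue }}\n{{ command }}\n",
    template_vars={"prologue": "module load cuda"},  # hypothetical setup line
)
script = template_launcher.transform(["python", "train.py", "--epochs", "10"])
print(script.inline)  # the rendered bash script returned as Script(inline=...)

# The Ray variant fills the bundled slurm_ray.sh.j2 template instead.
ray_launcher = SlurmRay(
    gcs_server_port=6379,
    env_vars={"RAY_DEDUP_LOGS": "0"},  # rendered as export statements
)
ray_script = ray_launcher.transform(["python", "train.py"])

Both classes are also registered in LAUNCHER_MAP under the "slurm_ray" and "slurm_template" keys, so they can be selected by name wherever the map is consulted.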
Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
+#!/bin/bash
+
+# Required environment variables
+REQUIRED_VARS=("SLURM_NNODES" "HEAD_NODE_ADDR")
+for var in "${REQUIRED_VARS[@]}"; do
+    if [ -z "${!var}" ]; then
+        echo "Error: $var is not set."
+        exit 1
+    fi
+done
+
+echo "Environment Variables:"
+echo "SLURM_NNODES=${SLURM_NNODES}"
+echo "HEAD_NODE_ADDR=${HEAD_NODE_ADDR}"
+{{ env_vars }}
+
+# Extract Ray ports from environment variables or hardcode them
+GCS_SERVER_PORT={{ gcs_server_port }}
+DASHBOARD_PORT={{ dashboard_port }}
+OBJECT_MANAGER_PORT={{ object_manager_port }}
+NODE_MANAGER_PORT={{ node_manager_port }}
+RAY_DASHBOARD_AGENT_PORT={{ dashboard_agent_port }}
+RAY_DASHBOARD_AGENT_GRPC_PORT={{ dashboard_agent_grpc_port }}
+METRICS_PORT={{ metrics_port }}
+
+get_ray_worker_count() {
+    local ray_status_output
+    ray_status_output=$(ray status 2>&1) # Capture both stdout and stderr
+
+    # Check if the output contains the expected "Active" workers section
+    if echo "$ray_status_output" | grep -q "Active:"; then
+        # Extract the number of active workers by counting lines containing "node_"
+        worker_count=$(echo "$ray_status_output" | awk '/Active:/,/Pending:/' | grep -c "node_")
+
+        # Ensure worker_count is valid, otherwise set it to 0
+        if [[ -z "$worker_count" || "$worker_count" -lt 0 ]]; then
+            worker_count=0
+        fi
+    else
+        # Handle the case where "ray status" doesn't return valid data yet
+        worker_count=-1
+    fi
+
+    echo "$worker_count"
+}
+
+display_nvidia_smi() {
+    echo "NVIDIA SMI for $SLURMD_NODENAME"
+    which nvidia-smi && nvidia-smi || echo "nvidia-smi not in container"
+}
+
+ray_pid=""
+
+# Display nvidia-smi output
+{% if display_nvidia_smi_output | default(false) %}
+display_nvidia_smi
+{% endif %}
+
+
+# Function to start the Ray head node.
+start_head() {
+    # Start Ray head node
+
+    echo "Starting Ray head node"
+    ray start --head \
+        --node-ip-address=$(hostname -i) \
+        --port=${GCS_SERVER_PORT} \
+        --object-manager-port=${OBJECT_MANAGER_PORT} \
+        --node-manager-port=${NODE_MANAGER_PORT} \
+        --system-config='{"local_fs_capacity_threshold": 0.90, "object_spilling_config": "{ \"type\": \"filesystem\", \"params\": {\"directory_path\": \"/tmp/ray_spill\", \"buffer_size\": 1000000 } }"}' \
+        --metrics-export-port=${METRICS_PORT} \
+        --dashboard-host 0.0.0.0 --include-dashboard 1 \
+        --disable-usage-stats \
+        --dashboard-agent-grpc-port=${RAY_DASHBOARD_AGENT_GRPC_PORT} \
+        --dashboard-agent-listen-port=${RAY_DASHBOARD_AGENT_PORT} | tee -a /tmp/ray.log
+    ray_pid=$!
+    echo "Ray head node started with PID $ray_pid"
+
+    ready_set=false
+    # Periodically check Ray status
+    while true; do
+        worker_count=$(get_ray_worker_count)
+        echo "Current workers ready: $worker_count"
+        if [[ "$worker_count" -eq -1 ]]; then
+            echo "Ray cluster status not available. Waiting for cluster."
+            sleep 5
+            continue
+        fi
+        if [[ "$worker_count" -eq 1 && "$ready_set" == "false" ]]; then
+            echo "Ray cluster is ready. Setting head node pod status to ready."
+            # TODO: enable once health server is ready
+            # curl -X POST http://localhost:8000/set-ready -H "Content-Type: application/json" -d '{"status": true}'
+            touch /tmp/is_ready
+
+            # Set ready_set to true after the curl request is sent
+            ready_set=true
+        fi
+
+        # Proceed only if the worker_count is a valid integer and >= expected_workers
+        if [[ "$worker_count" -ge "$SLURM_NNODES" ]]; then
+            echo "Enough workers connected. Proceeding to start the Python command."
+            break
+        fi
+
+        echo "Waiting for workers to connect..."
+        sleep {{ head_init_wait_time }}
+    done
+
+    {{ head_setup }}
+}
+
+# Function to start a Ray worker node.
+start_worker() {
+    sleep {{ worker_init_wait_time }}
+    set +x
+
+    # Start Ray worker node and connect to head
+    echo "Starting Ray worker node and connecting to head at ${HEAD_NODE_ADDR}:${GCS_SERVER_PORT}"
+    ray start --address="${HEAD_NODE_ADDR}:${GCS_SERVER_PORT}" \
+        --block \
+        --node-ip-address=$(hostname -i) \
+        --object-manager-port=${OBJECT_MANAGER_PORT} \
+        --node-manager-port=${NODE_MANAGER_PORT} \
+        --metrics-export-port=${METRICS_PORT} \
+        --dashboard-agent-grpc-port=${RAY_DASHBOARD_AGENT_GRPC_PORT} \
+        --dashboard-agent-listen-port=${RAY_DASHBOARD_AGENT_PORT} \
+        --disable-usage-stats
+
+
+    # Check if Ray worker node started successfully by reading the exit code
+    if [ $? -ne 0 ]; then
+        echo "Error: Ray worker node failed to start."
+        exit 1
+    fi
+
+    echo "Ray start --block ... exited"
+
+}
+
+# If this is the head node, start the head; otherwise, start a worker.
+if [ -z "$SLURM_NODEID" ] || [ "$SLURM_NODEID" == "0" ]; then
+    start_head
+else
+    start_worker
+fi
+
+# Only the head node executes the Python command.
+if [ -z "$SLURM_NODEID" ] || [ "$SLURM_NODEID" == "0" ]; then
+    echo "Running Python command: {{ command }}"
+    # Use eval so the given command is executed with its arguments.
+    eval "{{ command }}"
+    echo "Python script finished. Shutting down Ray on head node."
+    ray stop
+    sleep 30
+fi
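
To make the variable substitution above concrete, here is a small rendering sketch (not part of the commit). It mirrors what SlurmTemplate.render_template() does for inline or absolute-path templates by calling jinja2 directly on a miniature excerpt of the script; the snippet and the values are invented for illustration.

# Hypothetical rendering demo; the template snippet is a cut-down excerpt.
import jinja2

snippet = (
    "GCS_SERVER_PORT={{ gcs_server_port }}\n"
    "{% if display_nvidia_smi_output | default(false) %}display_nvidia_smi\n{% endif %}"
    'eval "{{ command }}"\n'
)
rendered = jinja2.Template(snippet).render(
    gcs_server_port=6379,
    display_nvidia_smi_output=False,  # the {% if %} block is skipped
    command="python train.py",        # what render_template() injects from cmd
)
print(rendered)
# GCS_SERVER_PORT=6379
# eval "python train.py"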
