Commit f287bb3

[Feature] Multi-node Ray support for GRPO sota-implementation (#3040)
1 parent 4914302 commit f287bb3

6 files changed: 141 additions & 4 deletions


sota-implementations/grpo/README.md

Lines changed: 14 additions & 0 deletions
@@ -134,6 +134,20 @@ The async mode offers better performance by:
 - Better throughput
 - More flexible buffer management
 
+### Running GRPO on More Than One Node with SLURM
+
+GRPO can be run across more than one node using SLURM, enabling distributed training for moderately scaled workloads.
+
+Two scripts are provided for launching multi-node runs:
+
+- `grpo-sync-multi-node.sbatch`: SLURM job script that launches sync GRPO across multiple nodes using Ray.
+- `grpo-async-multi-node.sbatch`: SLURM job script that launches async GRPO across multiple nodes using Ray.
+
+Example Usage:
+
+```bash
+sbatch sota-implementations/grpo/grpo-sync-multi-node.sbatch
+```
 
 ### KL Divergences in PPO: Reference vs Inference
 
 KL divergence is a key regularization term in policy optimization algorithms like PPO and in LLM post-training. It measures how much the updated policy diverges from a baseline or reference policy, helping to prevent the new policy from drifting too far and ensuring stable learning.

sota-implementations/grpo/grpo-async-multi-node.sbatch

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
#!/bin/bash
#SBATCH --job-name=grpo-async-multi-node
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=96
#SBATCH --exclusive
#SBATCH --output=logs/%x.job%j.out
#SBATCH --time=24:00:00

# Exit on any error
set -euo pipefail

# Ensure logs directory exists
mkdir -p logs

# Environment variables
export LIST_TO_STACK=1
export VLLM_USE_V1=0
export RAY_CLUSTER_MANAGED_EXTERNALLY=1

# Run command in Ray cluster
CMD="python grpo-async.py mode=async train_model.num_devices=8 ref_model.num_devices=4 inference_model.num_devices=4"
srun bash run_in_ray_cluster.sh "$CMD"

echo "Job completed"
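
Both sbatch scripts request 2 exclusive nodes and leave the actual device split to the Hydra overrides in `CMD`: 8 training, 4 reference, and 4 inference devices. That split adds up to 16 GPUs, i.e. 2 nodes with 8 GPUs each; the per-node GPU count is an assumption of the sketch below, not something the scripts enforce (they rely on `--exclusive` rather than an explicit GPU request).

```python
# Sketch: sanity-check that the Hydra device overrides fit the SLURM allocation.
# Assumes 8 GPUs per node; the sbatch scripts do not pin a GPU count themselves.
nodes = 2                  # from "#SBATCH --nodes=2"
gpus_per_node = 8          # assumption, not taken from the scripts
total_gpus = nodes * gpus_per_node

# Device counts taken from the CMD overrides in the sbatch files.
split = {"train_model": 8, "ref_model": 4, "inference_model": 4}

assert sum(split.values()) <= total_gpus, "device split exceeds the allocation"
print(f"{sum(split.values())}/{total_gpus} GPUs used across {nodes} nodes: {split}")
```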

sota-implementations/grpo/grpo-async.py

Lines changed: 9 additions & 2 deletions
@@ -112,8 +112,11 @@ def train(
     model_metadata = vLLMUpdater.get_model_metadata(policy_training)
 
     # Create weight updater with remote LLM
+    ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
     weight_updater: vLLMUpdater = make_weight_updater(
-        master_address="localhost", # Since we're running locally
+        master_address="localhost"
+        if not ray_managed_externally
+        else ray.util.get_node_ip_address(),
         master_port=None, # Will auto-assign an open port
         model_metadata=model_metadata,
         vllm_tp_size=cfg.inference_model.num_devices
@@ -326,7 +329,11 @@ def main(cfg):
         ray_init_config["runtime_env"]["env_vars"]
     )
     torchrl_logger.info(f"Ray init config: {ray_init_config=}")
-    ray.init(**ray_init_config)
+    ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
+    if ray_managed_externally:
+        ray.init(address="auto")
+    else:
+        ray.init(**ray_init_config)
 
     # Check if num_devices is set
     if cfg.inference_model.num_devices is None:
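
The change is the same in both entry points: when `RAY_CLUSTER_MANAGED_EXTERNALLY` is set (as the sbatch scripts do), the driver attaches to the Ray cluster that `run_in_ray_cluster.sh` already started instead of spinning up a local one. Below is a minimal sketch of that branch in isolation; the `ray_init_config` dict here is a placeholder, not the config the script actually builds from Hydra.

```python
import os

import ray

# Placeholder for the init config that grpo-async.py assembles from its Hydra cfg.
ray_init_config = {"runtime_env": {"env_vars": {"VLLM_USE_V1": "0"}}}

ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
if ray_managed_externally:
    # SLURM case: join the cluster started by run_in_ray_cluster.sh on this node.
    ray.init(address="auto")
else:
    # Local case: start a fresh single-node Ray instance with the given config.
    ray.init(**ray_init_config)

print("Driver node IP:", ray.util.get_node_ip_address())
```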

sota-implementations/grpo/grpo-sync-multi-node.sbatch

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
#!/bin/bash
#SBATCH --job-name=grpo-sync-multi-node
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=96
#SBATCH --exclusive
#SBATCH --output=logs/%x.job%j.out
#SBATCH --time=24:00:00

# Exit on any error
set -euo pipefail

# Ensure logs directory exists
mkdir -p logs

# Environment variables
export LIST_TO_STACK=1
export VLLM_USE_V1=0
export RAY_CLUSTER_MANAGED_EXTERNALLY=1

# Run command in Ray cluster
CMD="python grpo-sync.py mode=sync train_model.num_devices=8 ref_model.num_devices=4 inference_model.num_devices=4"
srun bash run_in_ray_cluster.sh "$CMD"

echo "Job completed"

sota-implementations/grpo/grpo-sync.py

Lines changed: 9 additions & 2 deletions
@@ -113,8 +113,11 @@ def train(
     model_metadata = vLLMUpdater.get_model_metadata(policy_training)
 
     # Create weight updater with remote LLM
+    ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
     weight_updater: vLLMUpdater = make_weight_updater(
-        master_address="localhost", # Since we're running locally
+        master_address="localhost"
+        if not ray_managed_externally
+        else ray.util.get_node_ip_address(),
         master_port=None, # Will auto-assign an open port
         model_metadata=model_metadata,
         vllm_tp_size=cfg.inference_model.num_devices
@@ -338,7 +341,11 @@ def main(cfg):
         ray_init_config["runtime_env"]["env_vars"]
     )
     torchrl_logger.info(f"Ray init config: {ray_init_config=}")
-    ray.init(**ray_init_config)
+    ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
+    if ray_managed_externally:
+        ray.init(address="auto")
+    else:
+        ray.init(**ray_init_config)
 
     # Check if num_devices is set
     if cfg.inference_model.num_devices is None:
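
`grpo-sync.py` gets the identical treatment. The other half of the change is the weight-update group's `master_address`: `localhost` is fine when everything runs on one machine, but on a multi-node Ray cluster the trainer has to advertise an address the remote vLLM workers can actually reach. A small sketch of just that selection, using a hypothetical helper name that is not part of the repo:

```python
import os

import ray


def resolve_master_address() -> str:
    """Pick the address remote workers use to reach the trainer.

    Hypothetical helper mirroring the inline conditional in grpo-sync.py /
    grpo-async.py; not a function that exists in the repository.
    """
    if os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY"):
        # Multi-node: advertise this node's IP as seen from the Ray cluster.
        return ray.util.get_node_ip_address()
    # Single-node: every worker lives on this host.
    return "localhost"


print(resolve_master_address())
```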

sota-implementations/grpo/run_in_ray_cluster.sh

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
#!/bin/bash

set -euo pipefail

# Get command from argument
CMD="$1"

# Set up Ray cluster configuration
HEAD_NODE=$(scontrol show hostname "$SLURM_NODELIST" | head -n 1)
RAY_PORT=6379

# Get current node name
CURRENT_NODE=$(hostname | cut -d. -f1)

# Get HEAD_NODE_IP
if [ "$SLURM_NODEID" -eq 0 ]; then
    # We're on the head node, get our own IP
    HEAD_NODE_IP=$(hostname -I | awk '{print $1}')
else
    # We're on a worker, resolve the head node's IP using DNS
    HEAD_NODE_IP=$(getent hosts "$HEAD_NODE" | awk '{print $1}')
fi

# Set up cleanup function
cleanup() {
    if command -v ray &>/dev/null; then
        echo "Stopping Ray on node $CURRENT_NODE"
        ray stop || true
    fi
}
trap cleanup EXIT

# Start Ray based on node role
if [ "$SLURM_NODEID" -eq 0 ]; then
    echo "Starting Ray head node on $CURRENT_NODE"
    ray start --head --disable-usage-stats --port=$RAY_PORT
    echo "Ray head node started at $HEAD_NODE_IP:$RAY_PORT"
else
    echo "Waiting for head node to be ready..."
    sleep 10
    echo "Starting Ray worker on node $CURRENT_NODE (ID: $SLURM_NODEID)"
    ray start --disable-usage-stats --address="$HEAD_NODE_IP:$RAY_PORT"
fi

# Ensure Ray cluster is ready
sleep 2

# Only head node runs the training command
if [ "$SLURM_NODEID" -eq 0 ]; then
    echo "Starting training process on head node $CURRENT_NODE"
    bash -c "$CMD"
else
    # Worker nodes just wait for the head to finish
    while ray status --address="$HEAD_NODE_IP:$RAY_PORT" &>/dev/null; do
        sleep 10
    done
fi

echo "Node $CURRENT_NODE: Done"
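
Because the head node starts Ray before launching `CMD`, the training process only ever needs `ray.init(address="auto")`. The quick check below could be run on the head node (or prepended to `CMD`) to confirm that both nodes joined before training starts; it is a convenience sketch, not part of the commit.

```python
import ray

# Attach to the cluster that run_in_ray_cluster.sh started on this node.
ray.init(address="auto")

alive = [node for node in ray.nodes() if node["Alive"]]
print(f"{len(alive)} node(s) alive")
for node in alive:
    gpus = node["Resources"].get("GPU", 0)
    print(f"  {node['NodeManagerAddress']}: {gpus} GPU(s)")

print("Cluster resources:", ray.cluster_resources())
```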
