
Commit 88cac0a

update code
1 parent 29e9219 commit 88cac0a


5 files changed: +69 −45 lines


docs/en/platform_support/amd_tutorial.md

Lines changed: 26 additions & 34 deletions
@@ -1,6 +1,6 @@
 # AMD
 
-⚠️ If you encounter problems on AMD instinct, feel free to reach out [Yusheng Su](https://yushengsu-thu.github.io/).
+⚠️ If you encounter problems on AMD instinct, feel free to reach out [Yusheng Su](https://yushengsu-thu.github.io/).
 
 
 ## Introduction
@@ -12,7 +12,7 @@ If you are running slime on AMD's Instinct, please refer to the following materi
 
 ## Docker
 
-You can download the prebuilt image from DockerHub: [rlsys/slime](https://hub.docker.com/r/rlsys/slime/tags).
+You can download the prebuilt image from DockerHub: [rlsys/slime](https://hub.docker.com/r/rlsys/slime/tags).
 ```bash
 docker pull rlsys/slime:latest
 ```
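
Once inside the container, a quick way to confirm that PyTorch's ROCm build actually sees the Instinct GPUs is a short check like the one below. This is a hedged suggestion, not part of this commit; it only assumes a ROCm build of PyTorch is installed in the image.

```python
# Sanity check inside the container (illustrative, not from this commit).
import torch

print("HIP runtime:", torch.version.hip)           # non-None only on ROCm builds of PyTorch
print("Visible GPUs:", torch.cuda.device_count())  # HIP devices are exposed through the cuda API
print("Device 0:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none")
```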
@@ -72,32 +72,31 @@ source scripts/models/qwen3-4B.sh
 MEGATRON_LM_PATH=$(pip list | grep megatron-core | awk '{print $NF}')
 PYTHONPATH=${MEGATRON_LM_PATH} python tools/convert_hf_to_torch_dist.py \
    ${MODEL_ARGS[@]} \
+   --no-gradient-accumulation-fusion \
    --hf-checkpoint model/Qwen3-4B \
    --save model/Qwen3-4B_torch_dist
 ```
 
-Note: You might encounter some issue in the current model convert script on AMD GPUs. You can go [here](https://huggingface.co/zyzshishui0627/models) to dowload the converted models.
+Note: We implemented a dedicated AMD conversion script that forces a CPU-only conversion workflow using the Gloo backend to bypass hardware-specific issues. A GPU-based script for ROCm is currently in development.
 
 ⚠️ If you encounter an issue where slime cannot be found, please run `pip install -e .` in the slime directory.
 
 
 ### Example: Qwen3-4B
 
 We provide examples to use [Qwen3-4B](https://huggingface.co/Qwen/Qwen3-4B), please refer to:
-- [Example: Qwen3-4B Model](scripts/run-qwen3-4B-amd.sh): Just run `scripts/run-qwen3-4B-amd.sh`
+- [Example: Qwen3-4B Model](../../../scripts/run-qwen3-4B-amd.sh): Just run `scripts/run-qwen3-4B-amd.sh`
 
-⚠️ TODO: The [ROCm-version torch_memory_saver](https://github.com/yushengsu-thu/torch_memory_saver.git) does not seem to clear memory properly; thus, we set `--sglang-mem-fraction-static` as `0.4` currently. We will continue investigating and focus on ROCm's virtual memory management for further modifications.
-
-⚠️ TODO: ROCM seems to not support `apex` yet. Thus, we need to disable `--no-gradient-accumulation-fusion` currently. We will continue investigating how to enable this.
+⚠️ TODO: ROCm does not seem to support `apex` yet. Thus, we currently need to disable gradient accumulation fusion by adding the `--no-gradient-accumulation-fusion` flag in the training script. We will continue investigating how to enable this.
 
 ⚠️ Note: The main difference between ROCm's training script and NVIDIA's script is that you need to set `RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES` and `HIP_VISIBLE_DEVICES` for ray to function properly on AMD GPUs.
 
-- We show the training script below:
+- We show the training script below:
 
 ```bash
 #!/bin/bash
 
-####clear before training
+# for rerun the task
 pkill -9 sglang
 sleep 3
 ray stop --force
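
The note in this hunk describes the dedicated AMD conversion path as a CPU-only workflow over the Gloo backend. The sketch below illustrates that idea only; the single-rank setup, rendezvous address, and port are assumptions for illustration, not the actual interface of the conversion script.

```python
# Minimal sketch: CPU-only process group over Gloo, as described in the note above.
# The rendezvous settings are placeholders; the real script may configure these differently.
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

# Gloo needs no GPU, so the conversion can run even where ROCm kernels misbehave.
dist.init_process_group(backend="gloo", rank=0, world_size=1)

device = torch.device("cpu")  # keep all conversion tensors on the CPU
print(f"Backend: {dist.get_backend()}, device: {device}")

dist.destroy_process_group()
```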
@@ -107,34 +106,35 @@ sleep 3
 pkill -9 ray
 pkill -9 python
 
+
 set -euxo pipefail
 
-### ROCm Support ###
-SLIME_DIR="/home/yushensu/projects/slime" # Need to change to your own path
-export SLIME_DIR=$SLIME_DIR
 
-MODEL_DIR="/home/yushensu/projects/model" # Need to change to your own path
-export MODEL_DIR=$MODEL_DIR
+### AMD Support ###
+SLIME_DIR="${SLIME_DIR:-/home/yushensu/projects/slime}" # Default path if not set in environment
+export SLIME_DIR
+
+MODEL_DIR="${MODEL_DIR:-/home/yushensu/projects/model}" # Default path if not set in environment
+export MODEL_DIR
 
-DATA_DIR="/home/yushensu/projects/data" # Need to change to your own path
-export DATA_DIR=$DATA_DIR
+DATA_DIR="${DATA_DIR:-/home/yushensu/projects/data}" # Default path if not set in environment
+export DATA_DIR
 
 # For AMD GPU
 export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} # Must set to 1
 export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} #You can choose which gpus to use
 ####################
 
+
 # will prevent ray from buffering stdout/stderr
 export PYTHONBUFFERED=16
 
-
 SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
 source "${SCRIPT_DIR}/models/qwen3-4B.sh"
 
 CKPT_ARGS=(
    --hf-checkpoint ${MODEL_DIR}/Qwen3-4B
-   #--hf-checkpoint /root/Qwen3-4B-FP8
-   --ref-load ${MODEL_DIR}/Qwen3-4B_torch
+   --ref-load ${MODEL_DIR}/Qwen3-4B_torch_dist
    --load ${MODEL_DIR}/Qwen3-4B_slime/
    --save ${MODEL_DIR}/Qwen3-4B_slime/
    --save-interval 20
@@ -146,9 +146,7 @@ ROLLOUT_ARGS=(
    --label-key label
    --apply-chat-template
    --rollout-shuffle
-
    --rm-type deepscaler
-
    --num-rollout 3000
    --rollout-batch-size 32
    --n-samples-per-prompt 8
@@ -204,24 +202,16 @@ OPTIMIZER_ARGS=(
 )
 
 WANDB_ARGS=(
-   #--use-wandb
+   # --use-wandb
    # --wandb-project slime-dev
    # --wandb-group qwen3-4B-test
    # --wandb-key ${WANDB_KEY}
 )
 
-### AMD Support ###
-# Need to fix some issue with torch_memory_saver in rocm to support larger --sglang-mem-fraction-static
-# SGLANG_ARGS=(
-#   --rollout-num-gpus-per-engine 2
-#   --sglang-mem-fraction-static 0.7
-# )
 SGLANG_ARGS=(
    --rollout-num-gpus-per-engine 2
-   --sglang-mem-fraction-static 0.4
+   --sglang-mem-fraction-static 0.7
 )
-####################
-
 
 MISC_ARGS=(
    # default dropout in megatron is 0.1
@@ -242,14 +232,16 @@ MISC_ARGS=(
 export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
 
 NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l)
-ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats
+ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265
+
 
+# "PYTHONPATH": "/workspace/Megatron-LM/",
+MEGATRON_LM_PATH=$(pip list | grep megatron-core | awk '{print $NF}')
 
-# "PYTHONPATH": "$(dirname $(python3 -c 'import megatron.core; print(megatron.core.__file__)'))"
 ray job submit --address="http://127.0.0.1:8265" \
    --runtime-env-json='{
       "env_vars": {
-         "PYTHONPATH": "'${SLIME_DIR}'/Megatron-LM/",
+         "PYTHONPATH": "/workspace/Megatron-LM/",
          "CUDA_DEVICE_MAX_CONNECTIONS": "1"
       }
    }' \

scripts/run-qwen3-4B-amd.sh

File mode changed from 100644 to 100755
Lines changed: 1 addition & 11 deletions
@@ -33,16 +33,12 @@ export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} #You can ch
 # will prevent ray from buffering stdout/stderr
 export PYTHONBUFFERED=16
 
-# Current Model convert script on AMD GPU has some issue, please download the converted model from here: https://huggingface.co/zyzshishui0627/models
-
 SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
 source "${SCRIPT_DIR}/models/qwen3-4B.sh"
 
 CKPT_ARGS=(
    --hf-checkpoint ${MODEL_DIR}/Qwen3-4B
-   #--hf-checkpoint /root/Qwen3-4B-FP8
    --ref-load ${MODEL_DIR}/Qwen3-4B_torch_dist
-   # --ref-load ${MODEL_DIR}/Qwen3-4B_torch_dist_amd_new
    --load ${MODEL_DIR}/Qwen3-4B_slime/
    --save ${MODEL_DIR}/Qwen3-4B_slime/
    --save-interval 20
@@ -116,12 +112,6 @@ WANDB_ARGS=(
    # --wandb-key ${WANDB_KEY}
 )
 
-### AMD Support ###
-# Need to fix some issue with torch_memory_saver in rocm to support larger --sglang-mem-fraction-static
-# SGLANG_ARGS=(
-#   --rollout-num-gpus-per-engine 2
-#   --sglang-mem-fraction-static 0.7
-# )
 SGLANG_ARGS=(
    --rollout-num-gpus-per-engine 2
    --sglang-mem-fraction-static 0.7
@@ -172,4 +162,4 @@ ray job submit --address="http://127.0.0.1:8265" \
    ${PERF_ARGS[@]} \
    ${EVAL_ARGS[@]} \
    ${SGLANG_ARGS[@]} \
-   ${MISC_ARGS[@]}
+   ${MISC_ARGS[@]}

slime/backends/megatron_utils/model.py

Lines changed: 8 additions & 0 deletions
@@ -714,6 +714,14 @@ def initialize_model_and_optimizer(
         tuple[list[DDP], MegatronOptimizer, OptimizerParamScheduler, int]:
             DDP-wrapped model chunks, optimizer, scheduler, and iteration index.
     """
+
+    if torch.version.hip:
+        import megatron.core.dist_checkpointing.strategies.filesystem_async as filesystem_async_module
+        from slime.utils.rocm_checkpoint_writer import ROCmFileSystemWriterAsync
+
+        filesystem_async_module.FileSystemWriterAsync = ROCmFileSystemWriterAsync
+        print("[ROCm] Applied FileSystemWriterAsync patch for HIP compatibility")
+
     model, optimizer, opt_param_scheduler = setup_model_and_optimizer(args, role)
     model[0].role = role
     clear_memory()
slime/utils/rocm_checkpoint_writer.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+import torch
+from megatron.core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync
+
+
+class ROCmFileSystemWriterAsync(FileSystemWriterAsync):
+    """
+    FileSystemWriterAsync wrapper for ROCm compatibility.
+
+    On ROCm/HIP, using non_blocking=True causes tensors to be stored in pinned memory,
+    which triggers segmentation faults when forking subprocesses afterward.
+    """
+
+    @staticmethod
+    def preload_tensors(*args, **kwargs):
+        # Change argument non_blocking to False on HIP platform
+        # The tensors will be stored in pinned memory if non_blocking=True
+        # Currently on the ROCm platform, forking a subprocess afterward
+        # with pinned_memory=True will trigger segmentation fault
+        if torch.version.hip:
+            print("HIP/ROCm detected: setting non_blocking=False in preload_tensors")
+            if "non_blocking" in kwargs:
+                kwargs["non_blocking"] = False
+            elif len(args) > 1 and isinstance(args[-1], bool):
+                # non_blocking is typically the last argument
+                args = args[:-1] + (False,)
+
+        return FileSystemWriterAsync.preload_tensors(*args, **kwargs)
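
The wrapper above only takes effect once it is assigned over Megatron's module attribute, as the patches in `model.py` and `convert_hf_to_torch_dist.py` do. A small check like the following (not part of the commit; it assumes megatron-core and slime are importable on a ROCm build) can confirm the override is in place:

```python
# Illustrative check that the ROCm writer override is active (not from this commit).
import torch
import megatron.core.dist_checkpointing.strategies.filesystem_async as filesystem_async_module
from slime.utils.rocm_checkpoint_writer import ROCmFileSystemWriterAsync

if torch.version.hip:
    filesystem_async_module.FileSystemWriterAsync = ROCmFileSystemWriterAsync
    assert filesystem_async_module.FileSystemWriterAsync is ROCmFileSystemWriterAsync
    print("[ROCm] FileSystemWriterAsync now routes through ROCmFileSystemWriterAsync")
```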

tools/convert_hf_to_torch_dist.py

Lines changed: 7 additions & 0 deletions
@@ -72,6 +72,13 @@ def ceildiv(a, b):
 
 
 def main():
+    if torch.version.hip:
+        import megatron.core.dist_checkpointing.strategies.filesystem_async as filesystem_async_module
+        from slime.utils.rocm_checkpoint_writer import ROCmFileSystemWriterAsync
+
+        filesystem_async_module.FileSystemWriterAsync = ROCmFileSystemWriterAsync
+        print("[ROCm] Applied FileSystemWriterAsync patch for HIP compatibility")
+
     configure_logger()
 
     # Initialize distributed environment
