Skip to content

Commit b5d6a30

Browse files
[rollout,vllm] feat: disable sleep mode in fully-async mode (verl-project#4521)
### What does this PR do? > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. - 在fully-async 模式下,添加enable_sleep_mode 参数支持,可以使用虚拟显存功能,可充分利用显存,减少OOM 现象 - 默认下是True,与原来保持一致 ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. 可参考[recipe/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh] 脚本进行使用,新增`+actor_rollout_ref.rollout.enable_sleep_mode=False\` 命令即可; ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent 379f296 commit b5d6a30

File tree

3 files changed

+201
-1
lines changed

3 files changed

+201
-1
lines changed
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
#!/usr/bin/env bash
2+
set -xeuo pipefail
3+
4+
project_name='DAPO-Qwen3-30B-A3B-Base-Async'
5+
exp_name='Fsdp2-tp4sp4'
6+
7+
# Ray
8+
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
9+
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
10+
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
11+
# Paths
12+
DATA_PATH=${RAY_DATA_HOME:-"${HOME}/verl"}
13+
DATA_PATH=${DATA_PATH:-"/mnt/bn/${BYTENAS}"}
14+
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
15+
MODEL_PATH=${MODEL_PATH:-"${DATA_PATH}/shared/models/Qwen3-30B-A3B-Base"}
16+
CKPTS_DIR=${CKPTS_DIR:-"${DATA_PATH}/ckpts/${project_name}/${exp_name}"}
17+
TRAIN_FILE=${TRAIN_FILE:-"${DATA_PATH}/shared/data/dapo-math/dapo-math-17k.parquet"}
18+
TEST_FILE=${TEST_FILE:-"${DATA_PATH}/shared/data/dapo-math/aime-2024.parquet"}
19+
20+
21+
rollout_mode="async"
22+
rollout_name="vllm" # sglang or vllm
23+
if [ "$rollout_mode" = "async" ]; then
24+
export VLLM_USE_V1=1
25+
return_raw_chat="True"
26+
fi
27+
28+
# Algorithm parameters
29+
adv_estimator=grpo
30+
31+
use_kl_in_reward=False
32+
kl_coef=0.0
33+
use_kl_loss=False
34+
kl_loss_coef=0.0
35+
36+
clip_ratio_low=0.2
37+
clip_ratio_high=0.28
38+
39+
# Response length parameters
40+
max_prompt_length=$((1024 * 2))
41+
max_response_length=$((1024 * 20))
42+
enable_overlong_buffer=True
43+
overlong_buffer_len=$((1024 * 4))
44+
overlong_penalty_factor=1.0
45+
46+
# Training parameters
47+
loss_agg_mode="token-mean"
48+
enable_filter_groups=True
49+
filter_groups_metric=acc
50+
max_num_gen_batches=10
51+
52+
# Algorithm
53+
temperature=1.0
54+
top_p=1.0
55+
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
56+
val_top_p=0.7
57+
58+
59+
NNODES=${NNODES:-4}
60+
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
61+
62+
# Fully async specific parameters
63+
n_gpus_rollout=8
64+
n_gpus_training=8
65+
n_nodes_rollout=2
66+
n_nodes_train=2 # $((NNODES - n_nodes_rollout))
67+
68+
train_bsz=512
69+
train_prompt_bsz=0
70+
gen_prompt_bsz=1
71+
n_resp_per_prompt=16
72+
train_prompt_mini_bsz=32
73+
total_rollout_steps=$(((train_bsz * 400)))
74+
test_freq=25
75+
staleness_threshold=0.6 # 0 0.3 1
76+
require_batches=1
77+
total_train_gpus=$((n_gpus_training * n_nodes_train))
78+
total_rollout_gpus=$((n_gpus_rollout * n_nodes_rollout))
79+
trigger_parameter_sync_step=$((train_bsz / ( train_prompt_mini_bsz * require_batches))) # 8 16 32
80+
partial_rollout=True
81+
enforce_eager=False
82+
nccl_timeout=72000
83+
enable_sleep_mode=False
84+
85+
# Performance Related Parameter
86+
sp_size=4
87+
use_dynamic_bsz=True
88+
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
89+
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
90+
ref_offload=True
91+
actor_offload=False
92+
gen_tp=4
93+
fsdp_size=-1
94+
95+
96+
ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
97+
--working-dir "${WORKING_DIR}" \
98+
--address "${RAY_ADDRESS}" \
99+
-- python3 -m recipe.fully_async_policy.fully_async_main \
100+
--config-path=config \
101+
--config-name='fully_async_dapo_trainer.yaml' \
102+
data.train_files="${TRAIN_FILE}" \
103+
data.val_files="${TEST_FILE}" \
104+
data.prompt_key=prompt \
105+
data.truncation='left' \
106+
actor_rollout_ref.actor.strategy=fsdp \
107+
critic.strategy=fsdp \
108+
data.max_prompt_length=${max_prompt_length} \
109+
data.max_response_length=${max_response_length} \
110+
data.train_batch_size=${train_prompt_bsz} \
111+
data.gen_batch_size=${gen_prompt_bsz} \
112+
data.return_raw_chat=${return_raw_chat} \
113+
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
114+
algorithm.adv_estimator=${adv_estimator} \
115+
algorithm.use_kl_in_reward=${use_kl_in_reward} \
116+
algorithm.kl_ctrl.kl_coef=${kl_coef} \
117+
actor_rollout_ref.rollout.calculate_log_probs=True \
118+
actor_rollout_ref.nccl_timeout=${nccl_timeout} \
119+
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
120+
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
121+
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
122+
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
123+
actor_rollout_ref.actor.clip_ratio_c=10.0 \
124+
actor_rollout_ref.model.use_remove_padding=True \
125+
actor_rollout_ref.hybrid_engine=False \
126+
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
127+
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
128+
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
129+
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
130+
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
131+
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
132+
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
133+
actor_rollout_ref.model.path="${MODEL_PATH}" \
134+
actor_rollout_ref.actor.optim.lr=1e-6 \
135+
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
136+
actor_rollout_ref.actor.optim.weight_decay=0.1 \
137+
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
138+
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
139+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
140+
actor_rollout_ref.actor.entropy_coeff=0 \
141+
actor_rollout_ref.actor.grad_clip=1.0 \
142+
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
143+
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
144+
actor_rollout_ref.rollout.gpu_memory_utilization=0.50 \
145+
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
146+
actor_rollout_ref.rollout.enable_chunked_prefill=True \
147+
+actor_rollout_ref.rollout.enable_sleep_mode=${enable_sleep_mode} \
148+
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
149+
actor_rollout_ref.rollout.enforce_eager=${enforce_eager} \
150+
actor_rollout_ref.rollout.temperature=${temperature} \
151+
actor_rollout_ref.rollout.top_p=${top_p} \
152+
actor_rollout_ref.rollout.top_k=${top_k} \
153+
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
154+
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
155+
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
156+
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
157+
actor_rollout_ref.rollout.val_kwargs.n=1 \
158+
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
159+
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
160+
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
161+
actor_rollout_ref.rollout.name=${rollout_name} \
162+
actor_rollout_ref.rollout.mode=${rollout_mode} \
163+
reward_model.reward_manager=dapo \
164+
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
165+
reward_model.overlong_buffer.len=${overlong_buffer_len} \
166+
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
167+
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
168+
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
169+
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
170+
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
171+
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
172+
trainer.logger=['console','wandb'] \
173+
trainer.project_name="${project_name}" \
174+
trainer.experiment_name="${exp_name}-i${total_rollout_gpus}_t${total_train_gpus}_s${staleness_threshold}" \
175+
trainer.val_before_train=True \
176+
trainer.test_freq="${test_freq}" \
177+
trainer.save_freq=-1 \
178+
trainer.default_local_dir="${CKPTS_DIR}" \
179+
trainer.resume_mode=auto \
180+
trainer.nnodes="${n_nodes_train}" \
181+
trainer.n_gpus_per_node="${n_gpus_training}" \
182+
rollout.nnodes="${n_nodes_rollout}" \
183+
rollout.n_gpus_per_node="${n_gpus_rollout}" \
184+
rollout.total_rollout_steps="${total_rollout_steps}" \
185+
rollout.test_freq=${test_freq} \
186+
rollout.total_epochs=10 \
187+
async_training.require_batches=${require_batches} \
188+
async_training.staleness_threshold="${staleness_threshold}" \
189+
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
190+
async_training.partial_rollout="${partial_rollout}" \
191+
async_training.use_rollout_log_probs=True

verl/workers/config/rollout.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ class RolloutConfig(BaseConfig):
207207

208208
enable_rollout_routing_replay: bool = False
209209

210+
enable_sleep_mode: bool = True
211+
210212
def __post_init__(self):
211213
"""Validate the rollout config"""
212214
if self.expert_parallel_size > 1:

verl/workers/rollout/vllm_rollout/vllm_async_server.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,13 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
255255
max_new_tokens=self.config.response_length,
256256
)
257257
logger.info(f"override_generation_config: {override_generation_config}")
258+
259+
logger.info(f"enable_sleep_mode: {self.config.enable_sleep_mode}")
260+
if not self.config.enable_sleep_mode:
261+
from verl.utils.device import set_expandable_segments
262+
263+
set_expandable_segments(True)
264+
258265
quantization = self.config.quantization
259266
if quantization is not None:
260267
if quantization == "fp8":
@@ -280,7 +287,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
280287
"enable_chunked_prefill": self.config.enable_chunked_prefill,
281288
"max_num_batched_tokens": self.config.max_num_batched_tokens,
282289
"enable_prefix_caching": self.config.enable_prefix_caching,
283-
"enable_sleep_mode": True,
290+
"enable_sleep_mode": self.config.enable_sleep_mode,
284291
"disable_custom_all_reduce": True,
285292
"enforce_eager": self.config.enforce_eager,
286293
"gpu_memory_utilization": self.config.gpu_memory_utilization,

0 commit comments

Comments
 (0)