Skip to content

Commit 941207f

Browse files
[rollout] feat: use rollout and validate parallel process (verl-project#4863)
### What does this PR do? > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. validate process and rollout process can be paralled in fully-async mode. This can reduce rollouter ilde and improve timing_s/gen speed when in short-long response_length case. This is below profiler: (qwen2.5-Math-7b) <img width="2496" height="1294" alt="image" src="https://github.com/user-attachments/assets/3f028c27-4257-453c-9fcb-196cf2bac5b5" /> <img width="2500" height="1304" alt="image" src="https://github.com/user-attachments/assets/9265e5ec-5bff-46a9-9cc7-e375849de39b" /> <img width="2498" height="1206" alt="image" src="https://github.com/user-attachments/assets/06e81e89-d7f3-4781-8daf-7e06efeb78c5" /> <img width="2518" height="1362" alt="image" src="https://github.com/user-attachments/assets/24304dc3-a6a9-4f5f-a242-1335b4d5e4b5" /> ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. we can use `dapo_7b_math_fsdp2_8_8.sh`, and add `+async_training.parallel_validate_and_rollout=True` command ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) - [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`. --------- Co-authored-by: Shangwei-Li <lishangwei@mail.ustc.edu.cn>
1 parent 97eac90 commit 941207f

File tree

2 files changed

+72
-14
lines changed

2 files changed

+72
-14
lines changed

verl/experimental/fully_async_policy/fully_async_rollouter.py

Lines changed: 70 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
import functools
17+
import multiprocessing
1618
import os
1719
import time
20+
from concurrent.futures import ThreadPoolExecutor
1821
from pprint import pformat
1922

2023
import numpy as np
@@ -168,6 +171,12 @@ def __init__(
168171
self.active_tasks = set()
169172
self.cancel_queue = asyncio.Queue()
170173

174+
cpu_cores = multiprocessing.cpu_count()
175+
# cpu case use cpu_cores; io case use cpu_cores*2
176+
self.validate_executor = ThreadPoolExecutor(max_workers=cpu_cores)
177+
self.parallel_validate_and_rollout = config.async_training.get("parallel_validate_and_rollout", False)
178+
self.validate_task = None
179+
171180
def _init_async_objects(self):
172181
# Initialize asyncio synchronization primitives.
173182
# We let asyncio.Condition create the Lock internally to ensure they share the same Event Loop.
@@ -245,22 +254,71 @@ async def update_param_version(
245254
f",reset staleness_samples to: {self.staleness_samples}"
246255
f",idle_ratio: {idle_ratio}"
247256
)
248-
val_metrics = None
249-
if (
250-
self.val_reward_fn is not None
251-
and self.config.rollout.test_freq > 0
252-
and self.current_param_version % self.config.rollout.test_freq == 0
253-
and self.current_param_version > 0 # don't test here in the initial parameter sync
254-
) or (validate and self.val_reward_fn is not None):
255-
with marked_timer("rollouter/validate_time", timing_raw, color="green"):
256-
val_metrics: dict = self._validate(use_trainer_do_validate)
257-
data = ValidateMetrics(
258-
timing_raw=timing_raw, metrics=val_metrics, global_steps=global_steps, param_version=version
257+
need_validate = (
258+
(
259+
self.val_reward_fn is not None
260+
and self.config.rollout.test_freq > 0
261+
and self.current_param_version % self.config.rollout.test_freq == 0
262+
and self.current_param_version > 0
263+
) # don't test here in the initial parameter sync
264+
or (validate and self.val_reward_fn is not None)
265+
)
266+
print(
267+
f"[FullyAsyncRollouter] need_validate: {need_validate},"
268+
f"parallel_validate_and_rollout: {self.parallel_validate_and_rollout}"
259269
)
260-
await self.message_queue_client.put_validate(ray.cloudpickle.dumps(data))
270+
if not need_validate:
271+
data = ValidateMetrics(
272+
timing_raw=timing_raw, metrics=None, global_steps=global_steps, param_version=version
273+
)
274+
elif need_validate and not self.parallel_validate_and_rollout:
275+
data = self._validate_wrapper(timing_raw, version, global_steps, use_trainer_do_validate)
276+
277+
if not need_validate or not self.parallel_validate_and_rollout:
278+
await self.message_queue_client.put_validate(ray.cloudpickle.dumps(data))
261279

262280
self.version_start_time = time.time()
263281

282+
if need_validate and self.parallel_validate_and_rollout:
283+
if self.validate_task and not self.validate_task.done():
284+
print("[FullyAsyncRollouter] validate_task is running, wait last validate_task to finish")
285+
self.validate_task.get()
286+
self.validate_task = asyncio.create_task(
287+
self.do_validate_async(timing_raw, version, global_steps, use_trainer_do_validate)
288+
)
289+
290+
def _validate_wrapper(
291+
self, timing_raw: dict, version: int, global_steps: int = 0, use_trainer_do_validate: bool = False
292+
):
293+
val_metrics = None
294+
with marked_timer("rollouter/validate_time", timing_raw, color="green"):
295+
val_metrics: dict = self._validate(use_trainer_do_validate)
296+
data = ValidateMetrics(
297+
timing_raw=timing_raw, metrics=val_metrics, global_steps=global_steps, param_version=version
298+
)
299+
return data
300+
301+
async def do_validate_async(
302+
self,
303+
timing_raw: dict,
304+
version: int,
305+
global_steps: int = 0,
306+
use_trainer_do_validate: bool = False,
307+
):
308+
loop = asyncio.get_running_loop()
309+
310+
data = await loop.run_in_executor(
311+
self.validate_executor,
312+
functools.partial(
313+
self._validate_wrapper,
314+
timing_raw=timing_raw,
315+
version=version,
316+
global_steps=global_steps,
317+
use_trainer_do_validate=use_trainer_do_validate,
318+
),
319+
)
320+
await self.message_queue_client.put_validate(ray.cloudpickle.dumps(data))
321+
264322
async def save_checkpoint(self, local_global_step_folder: str):
265323
# WARNING!: Due to the asynchronous nature, there are some in-flight samples
266324
# (pending/cancel/result queue and message queue).

verl/experimental/fully_async_policy/fully_async_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,12 +269,12 @@ async def init_workers(self):
269269

270270
async def _init_async_rollout_manager(self):
271271
# use async rollout do validate
272+
print(f"[FullyAsyncTrainer] use_trainer_do_validate: {self.config.async_training.use_trainer_do_validate}")
272273
if self.config.async_training.use_trainer_do_validate:
273-
print(f"[FullyAsyncTrainer] use_trainer_do_validate: {self.config.async_training.use_trainer_do_validate}")
274274
assert self.config.actor_rollout_ref.rollout.mode == "async"
275275
self.async_rollout_mode = True
276276
print("[FullyAsyncTrainer] Init async rollout manager")
277-
from recipe.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager
277+
from verl.experimental.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager
278278

279279
self.async_rollout_manager = await FullyAsyncAgentLoopManager.create(
280280
config=self.config, worker_group=self.actor_rollout_wg

0 commit comments

Comments
 (0)