
Commit 0c06358

[data] feat: TransferQueue - remove redundant data collect for both TQ and DataProto (verl-project#4618)
### What does this PR do?

This PR optimizes the dataflow between the single controller (`RayPPOTrainer`) and worker processes. By applying the controller's result-filtering (de-duplication) logic on the individual workers in advance, we significantly reduce redundant data transmission.

In the current architecture, worker functions use the `@register` decorator to manage data dispatching and collection:

```python
# in megatron_workers.py
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@GPUMemoryLogger(role="update_actor", logger=logger)
@DistProfiler.annotate(color="red", role="actor_update")
def update_actor(self, data: DataProto):
    assert self._is_actor
    if self._is_offload_param:
        load_megatron_model_to_gpu(self.actor_module)
        log_gpu_memory_usage("After load actor params and grad during update_actor", logger=logger)
    if self._is_offload_optimizer:
        load_megatron_optimizer(self.actor_optimizer)
        log_gpu_memory_usage("After load actor optimizer during update_actor", logger=logger)
```

Specifically, `make_nd_compute_dataproto_dispatch_fn` is used to:

1. Dispatch: shard and distribute input data from the controller to workers.
2. Collect: gather results back from all workers and de-duplicate them (usually keeping only one copy per DP group) using a `collect_mask`.

```python
def make_nd_compute_dataproto_dispatch_fn(mesh_name):
    return {
        "dispatch_fn": partial(dispatch_lazy_compute_data_proto, mesh_name),
        "collect_fn": partial(collect_lazy_compute_data_proto, mesh_name),
    }
```

```python
def collect_nd_compute(collect_mask: list[bool], worker_group, output):
    from verl.single_controller.base.worker_group import WorkerGroup

    assert isinstance(worker_group, WorkerGroup)
    assert len(output) == worker_group.world_size
    output_in_dp = []
    for global_rank in range(worker_group.world_size):
        collect_dp_rank = collect_mask[global_rank]
        if collect_dp_rank:
            output_in_dp.append(output[global_rank])
    return output_in_dp
```

Previously, the de-duplication process happened entirely on the controller side:

- Redundant transmission: every worker sent its full output back to the controller, regardless of whether that data would be kept or discarded.
- Overhead: for large-scale LLM training (large DP groups), this led to massive, unnecessary network traffic or repeated put operations to the TransferQueue.
- Bottleneck: the controller had to receive and process redundant data chunks before applying the mask.

This PR shifts the filtering logic "left", to the worker side. Each worker now determines in advance, according to the `collect_mask`, whether to return real data or an empty object to the single controller, reducing data-transfer overhead.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
Co-authored-by: Jianjun Zhong <87791082+jianjunzhong@users.noreply.github.com>
Co-authored-by: jianjunzhong <jianjunzhong@foxmail.com>
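The controller-side collect and the new worker-side filtering can be compared in a minimal, self-contained sketch. `collect_mask` comes from the PR description; the helper names and data here are illustrative stand-ins, not verl's actual API:

```python
# Minimal sketch of the "shift-left" filtering described in this PR.
# The names below are illustrative stand-ins, not verl's actual API.

def controller_side_collect(collect_mask, outputs):
    """Old path: every worker ships its full output to the controller,
    which then drops duplicates using the mask."""
    return [out for keep, out in zip(collect_mask, outputs) if keep]

def worker_side_filter(collect_mask, rank, output, empty="<empty>"):
    """New path: each worker consults the mask itself and returns a cheap
    empty placeholder when its copy would be discarded anyway."""
    return output if collect_mask[rank] else empty

# Example: 8 ranks with TP size 2, so only one rank per DP pair is kept.
mask = [True, False, True, False, True, False, True, False]
outputs = [f"result_dp{rank // 2}" for rank in range(8)]

sent = [worker_side_filter(mask, r, outputs[r]) for r in range(8)]
kept = [s for s in sent if s != "<empty>"]

# Both paths yield the same de-duplicated result, but the new path never
# transmits the redundant copies over the network.
assert kept == controller_side_collect(mask, outputs)
```

The placeholder plays the role of the empty `BatchMeta`/`DataProto` the patched workers return; the final collected result is unchanged.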
1 parent b19b749 commit 0c06358

File tree

3 files changed: +111, -10 lines


verl/single_controller/base/decorator.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -447,7 +447,7 @@ def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocki
     _check_execute_mode(execute_mode=execute_mode)

     def decorator(func):
-        func = tqbridge()(func)
+        func = tqbridge(dispatch_mode=dispatch_mode)(func)

         @wraps(func)
         def inner(*args, **kwargs):
```
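The one-line change above forwards `dispatch_mode` from `register` into `tqbridge`, so the bridge wrapper can see how its results will be collected. A simplified sketch of that decorator-factory pattern follows; the implementations are stand-ins, and only the forwarding shown in the diff is taken from the source:

```python
from functools import wraps

def tqbridge(dispatch_mode=None, put_data=True):
    """Stand-in: the factory call captures configuration, and the returned
    decorator wraps the worker method with that configuration in scope."""
    def decorator(func):
        @wraps(func)
        def inner(*args, **kwargs):
            # dispatch_mode is now visible here, so the wrapper can decide
            # locally whether this worker's output will actually be kept.
            inner.seen_dispatch_mode = dispatch_mode
            return func(*args, **kwargs)
        return inner
    return decorator

def register(dispatch_mode=None):
    def decorator(func):
        # Before the patch: tqbridge()(func) silently dropped dispatch_mode.
        # After the patch: the mode is forwarded to the bridge.
        return tqbridge(dispatch_mode=dispatch_mode)(func)
    return decorator

@register(dispatch_mode={"collect_fn": "collect_lazy_compute_data_proto"})
def update_actor(x):
    return x + 1

result = update_actor(1)
```

After one call, the wrapper has recorded the dispatch mode that `register` was given, which is exactly the information the real `tqbridge` previously lacked.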

verl/single_controller/base/worker.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -117,6 +117,9 @@ def _query_dispatch_info(self, mesh_name: str):

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def _query_collect_info(self, mesh_name: str):
+        return self.query_collect_info(mesh_name)
+
+    def query_collect_info(self, mesh_name: str):
         """Query the collect info for a given mesh name.

         Args:
```
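The new plain `query_collect_info` method sits beside the `@register`-decorated `_query_collect_info` wrapper: the wrapper remains the controller-facing RPC entry point, while worker-local code (such as the new `_compute_need_collect`) can call the plain method directly, without going through the dispatch machinery. A toy sketch of that split; the `register` stand-in here merely tags the method, whereas verl's real decorator also manages dispatch and collection:

```python
# Sketch of the wrapper-plus-plain-method split from the diff above.
# `register` is a simplified stand-in, not verl's real decorator.

def register(func):
    func._is_rpc_entry = True  # mark as a controller-facing entry point
    return func

class Worker:
    def __init__(self, collect_info):
        # e.g. {"actor": True} -- whether this rank's output is kept
        self._collect_info = collect_info

    @register
    def _query_collect_info(self, mesh_name):
        # RPC-facing wrapper: the controller calls this across processes.
        return self.query_collect_info(mesh_name)

    def query_collect_info(self, mesh_name):
        # Plain method: worker-local callers use it directly, with no
        # dispatch machinery involved.
        return self._collect_info[mesh_name]

w = Worker({"actor": True})
local = w.query_collect_info("actor")    # local fast path
remote = w._query_collect_info("actor")  # RPC-style path, same answer
```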

verl/utils/transferqueue_utils.py

Lines changed: 107 additions & 9 deletions

```diff
@@ -13,12 +13,16 @@
 # limitations under the License.

 import asyncio
+import functools
 import inspect
 import logging
 import os
 import threading
 from functools import wraps
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable
+
+if TYPE_CHECKING:
+    from verl.single_controller.base.decorator import Dispatch

 from tensordict import TensorDict

@@ -144,7 +148,95 @@ def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta", fun
     return updated_batch_meta


-def tqbridge(put_data: bool = True):
+def _compute_need_collect(dispatch_mode: "dict | Dispatch", args: list) -> bool:
+    """Compute whether data collection is needed for the current worker.
+
+    This function determines whether the current worker should collect data based on
+    the dispatch mode configuration and worker parameters. It's used to optimize
+    distributed data collection by ensuring only the appropriate rank collects data.
+
+    Args:
+        dispatch_mode: Controls data collection logic for the current worker. Can be None,
+            a Dispatch instance, or a dict with 'collect_fn' key. If None or Dispatch,
+            always returns True (current worker should collect). If dict, checks
+            collect_fn for lazy compute optimization.
+        args: List of arguments passed to the function. Should contain a Worker instance
+            as the first argument when using lazy compute mode.
+
+    Returns:
+        bool: True if data collection is needed, False otherwise.
+
+    Note:
+        Only checks worker attributes when dispatch_mode is a dict with 'collect_fn',
+        the collect_fn is 'collect_lazy_compute_data_proto', and args[0] is a Worker.
+        Otherwise, returns True. For the lazy compute case, checks the worker's
+        data parallel rank for the mesh specified in collect_fn.args[0] to determine
+        if this worker should collect data.
+    """
+    from verl.single_controller.base.decorator import Dispatch
+    from verl.single_controller.base.worker import Worker
+
+    if dispatch_mode is None or isinstance(dispatch_mode, Dispatch):
+        return True
+
+    assert "collect_fn" in dispatch_mode.keys(), "collect_fn should be in dispatch_mode."
+
+    collect_fn = dispatch_mode["collect_fn"]
+
+    # Check if collect_fn is a functools.partial and handle gracefully
+    if isinstance(collect_fn, functools.partial):
+        collect_fn_name = collect_fn.func.__name__
+        if collect_fn_name != "collect_lazy_compute_data_proto" or len(args) < 1 or not isinstance(args[0], Worker):
+            return True
+
+        collect_mesh_name = collect_fn.args[0] if collect_fn.args else None
+        if collect_mesh_name is None:
+            return True
+
+        return args[0].query_collect_info(collect_mesh_name)
+    else:
+        # If collect_fn is not a partial, we can't extract mesh_name information
+        # Fall back to default behavior (collect data)
+        return True
+
+
+def _postprocess_common(output, put_data, need_collect):
+    """Common post-processing logic for function outputs in TransferQueue bridge.
+
+    This function handles the final return value based on whether data should be
+    put into storage (put_data) and whether collection is needed (need_collect).
+    It ensures proper return types based on the execution context.
+
+    Args:
+        output: The original output from the decorated function. Can be any type.
+        put_data: bool, indicating whether the output should be put into TransferQueue.
+            If True, output will be put to TQ and return the corresponding BatchMeta;
+            if False, output will not be put into TQ.
+        need_collect: bool, indicating whether this process needs to collect data.
+            If False, the output will be replaced by an empty BatchMeta or DataProto
+            to avoid redundant communication.
+
+    Returns:
+        - BatchMeta.empty(): When put_data=True but need_collect=False, indicating
+            no data should be stored but BatchMeta structure is expected.
+        - DataProto(): When put_data=False, need_collect=False, and output is DataProto,
+            returning an empty DataProto.
+        - output: In all other cases, returns the original output unchanged.
+
+    Note:
+        This function is used in the tqbridge decorator to normalize return values
+        across different execution paths and avoid redundant data operations in
+        distributed scenarios.
+    """
+    if put_data and not need_collect:
+        return BatchMeta.empty()
+    elif not put_data and not need_collect and isinstance(output, DataProto):
+        return DataProto()
+    else:
+        return output
+
+
+def tqbridge(dispatch_mode: "dict | Dispatch" = None, put_data: bool = True):
     """Creates a decorator for bridging BatchMeta and DataProto.

     This decorator automatically handles conversions between `BatchMeta` and

@@ -155,6 +247,9 @@ def tqbridge(put_data: bool = True):
     simply calls the original function as-is).

     Args:
+        dispatch_mode: Controls data collection behavior for the current worker. Passed to
+            _compute_need_collect to determine if current worker should collect data.
+            If None, _compute_need_collect will return True to fallback default logics.
         put_data: Whether put the DataProto into Storage after func return.
             If True, after function execution, the output result will be
             updated to `BatchMeta` and `BatchMeta` will be returned;

@@ -181,11 +276,11 @@ def inner(*args, **kwargs):
             args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
             kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()}
             output = func(*args, **kwargs)
-            if put_data:
+            need_collect = _compute_need_collect(dispatch_mode, args)
+            if put_data and need_collect:
                 updated_batch_meta = _update_batchmeta_with_output(output, batchmeta, func.__name__)
                 return updated_batch_meta
-            else:
-                return output
+            return _postprocess_common(output, put_data, need_collect)

         @wraps(func)
         async def async_inner(*args, **kwargs):

@@ -203,18 +298,21 @@ async def async_inner(*args, **kwargs):
                 for k, v in kwargs.items()
             }
             output = await func(*args, **kwargs)
-            if put_data:
+            need_collect = _compute_need_collect(dispatch_mode, args)
+            if put_data and need_collect:
                 updated_batchmeta = await _async_update_batchmeta_with_output(output, batchmeta, func.__name__)
                 return updated_batchmeta
-            return output
+            return _postprocess_common(output, put_data, need_collect)

         @wraps(func)
         def dummy_inner(*args, **kwargs):
-            return func(*args, **kwargs)
+            output = func(*args, **kwargs)
+            return output

         @wraps(func)
         async def dummy_async_inner(*args, **kwargs):
-            return await func(*args, **kwargs)
+            output = await func(*args, **kwargs)
+            return output

         wrapper_inner = inner if is_transferqueue_enabled else dummy_inner
         wrapper_async_inner = async_inner if is_transferqueue_enabled else dummy_async_inner
```