@@ -270,8 +270,8 @@ def dispatch_nd_compute_dataproto(dp_rank_mapping: list[int], dp_size, worker_gr
270270 return dispatch_nd_compute (dp_rank_mapping , dp_size , worker_group , * splitted_args , ** splitted_kwargs )
271271
272272
273- def collect_nd_compute_dataproto (worker_group , output ):
274- output = collect_nd_compute (worker_group , output )
273+ def collect_nd_compute_dataproto (collect_mask : list [ bool ], worker_group , output ):
274+ output = collect_nd_compute (collect_mask , worker_group , output )
275275 import ray
276276
277277 from verl .protocol import DataProto
@@ -294,13 +294,9 @@ def dispatch_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs):
294294 assert len (worker_group ._dispatch_info [mesh_name ]) == worker_group .world_size
295295
296296 dp_rank_mapping = worker_group ._dispatch_info [mesh_name ]
297-
298- # a boolean of whether the dp_rank is used for collect
299- collect_mask = worker_group ._collect_info [mesh_name ]
300-
301297 # perform dispatch
302298 dp_size = max (dp_rank_mapping ) + 1
303- return dispatch_nd_compute_dataproto (dp_rank_mapping , dp_size , worker_group , collect_mask , * args , ** kwargs )
299+ return dispatch_nd_compute_dataproto (dp_rank_mapping , dp_size , worker_group , * args , ** kwargs )
304300
305301
306302def collect_lazy_compute_data_proto (mesh_name , worker_group , * args , ** kwargs ):
@@ -314,8 +310,11 @@ def collect_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs):
314310 if mesh_name not in worker_group ._collect_info :
315311 worker_group ._collect_info [mesh_name ] = worker_group ._query_collect_info (mesh_name )
316312 assert len (worker_group ._collect_info [mesh_name ]) == worker_group .world_size
313+
314+ # a boolean of whether the dp_rank is used for collect
315+ collect_mask = worker_group ._collect_info [mesh_name ]
317316 # perform dispatch
318- return collect_nd_compute_dataproto (worker_group , * args , ** kwargs )
317+ return collect_nd_compute_dataproto (collect_mask , worker_group , * args , ** kwargs )
319318
320319
321320def make_nd_compute_dataproto_dispatch_fn (mesh_name ):
0 commit comments