Skip to content

Commit 9989e83

Browse files
committed
remove codes
Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
1 parent a513570 commit 9989e83

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

verl/single_controller/base/decorator.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,8 @@ def dispatch_nd_compute_dataproto(dp_rank_mapping: list[int], dp_size, worker_gr
270270
return dispatch_nd_compute(dp_rank_mapping, dp_size, worker_group, *splitted_args, **splitted_kwargs)
271271

272272

273-
def collect_nd_compute_dataproto(worker_group, output):
274-
output = collect_nd_compute(worker_group, output)
273+
def collect_nd_compute_dataproto(collect_mask: list[bool], worker_group, output):
274+
output = collect_nd_compute(collect_mask, worker_group, output)
275275
import ray
276276

277277
from verl.protocol import DataProto
@@ -294,13 +294,9 @@ def dispatch_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs):
294294
assert len(worker_group._dispatch_info[mesh_name]) == worker_group.world_size
295295

296296
dp_rank_mapping = worker_group._dispatch_info[mesh_name]
297-
298-
# a boolean of whether the dp_rank is used for collect
299-
collect_mask = worker_group._collect_info[mesh_name]
300-
301297
# perform dispatch
302298
dp_size = max(dp_rank_mapping) + 1
303-
return dispatch_nd_compute_dataproto(dp_rank_mapping, dp_size, worker_group, collect_mask, *args, **kwargs)
299+
return dispatch_nd_compute_dataproto(dp_rank_mapping, dp_size, worker_group, *args, **kwargs)
304300

305301

306302
def collect_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs):
@@ -314,8 +310,11 @@ def collect_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs):
314310
if mesh_name not in worker_group._collect_info:
315311
worker_group._collect_info[mesh_name] = worker_group._query_collect_info(mesh_name)
316312
assert len(worker_group._collect_info[mesh_name]) == worker_group.world_size
313+
314+
# a boolean of whether the dp_rank is used for collect
315+
collect_mask = worker_group._collect_info[mesh_name]
317316
# perform dispatch
318-
return collect_nd_compute_dataproto(worker_group, *args, **kwargs)
317+
return collect_nd_compute_dataproto(collect_mask, worker_group, *args, **kwargs)
319318

320319

321320
def make_nd_compute_dataproto_dispatch_fn(mesh_name):

0 commit comments

Comments
 (0)