Skip to content

Commit b96b53a

Browse files
authored
[misc] feat: optimize rearrange_micro_batches (verl-project#4451)
1 parent 80af9db commit b96b53a

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

verl/utils/seqlen_balancing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,8 @@ def rearrange_micro_batches(
310310

311311
assert num_micro_batches <= len(seq_len_effective)
312312

313-
workloads = calculate_workload(seq_len_effective)
313+
# note that seq_len_effective is a GPU tensor. We need to make it a list to avoid D2H!
314+
workloads = calculate_workload(seq_len_effective).cpu().tolist()
314315
micro_bsz_idx = get_seqlen_balanced_partitions(workloads, num_micro_batches, equal_size=False)
315316

316317
if use_dynamic_bsz_balance:

0 commit comments

Comments
 (0)