We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 80af9db commit b96b53aCopy full SHA for b96b53a
verl/utils/seqlen_balancing.py
@@ -310,7 +310,8 @@ def rearrange_micro_batches(
310
311
assert num_micro_batches <= len(seq_len_effective)
312
313
- workloads = calculate_workload(seq_len_effective)
+ # note that seq_len_effective is a GPU tensor. We need to make it a list to avoid D2H!
314
+ workloads = calculate_workload(seq_len_effective).cpu().tolist()
315
micro_bsz_idx = get_seqlen_balanced_partitions(workloads, num_micro_batches, equal_size=False)
316
317
if use_dynamic_bsz_balance:
0 commit comments