Skip to content

Commit 5f1ce2b

Browse files
author
liuximeng
committed
[recipe, TransferQueue] fix: remove unused param get_n_samples & update _balance_batch func
1 parent 9f00d21 commit 5f1ce2b

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

recipe/transfer_queue/ray_trainer.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
from verl.utils.metric import reduce_metrics
8282
from verl.utils.rollout_skip import RolloutSkip
8383
from verl.utils.seqlen_balancing import (
84+
calculate_workload,
8485
get_seqlen_balanced_partitions,
8586
log_seqlen_unbalance,
8687
)
@@ -678,7 +679,6 @@ def _validate(self):
678679
data_fields=["input_ids", "uid", "reward_model"],
679680
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
680681
partition_id=f"val_{self.global_steps - 1}",
681-
get_n_samples=False,
682682
task_name="get_data",
683683
)
684684
)
@@ -697,7 +697,6 @@ def _validate(self):
697697
data_fields=list(test_batch.keys()), # TODO: (TQ) Get metadata by specified fields
698698
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
699699
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
700-
get_n_samples=False,
701700
task_name="generate_sequences",
702701
)
703702
)
@@ -727,7 +726,6 @@ def _validate(self):
727726
data_fields=["responses"],
728727
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
729728
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
730-
get_n_samples=False,
731729
task_name="get_response",
732730
)
733731
)
@@ -756,7 +754,6 @@ def _validate(self):
756754
data_fields=compute_reward_fields,
757755
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
758756
partition_id=f"val_{self.global_steps - 1}",
759-
get_n_samples=False,
760757
task_name="compute_reward",
761758
)
762759
)
@@ -780,7 +777,6 @@ def _validate(self):
780777
data_fields=["__num_turns__"],
781778
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
782779
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
783-
get_n_samples=False,
784780
task_name="get_num_turns",
785781
)
786782
)
@@ -794,7 +790,6 @@ def _validate(self):
794790
data_fields=["data_source"],
795791
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
796792
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
797-
get_n_samples=False,
798793
task_name="get_data_source",
799794
)
800795
)
@@ -1098,17 +1093,39 @@ def _stop_profiling(self, do_profile: bool) -> None:
10981093
if self.use_rm:
10991094
self.rm_wg.stop_profile()
11001095

1101-
def _balance_batch(self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen"):
1096+
def _balance_batch(
1097+
self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen", keep_minibatch=False
1098+
):
11021099
"""Reorder the batchmeta on single controller such that each dp rank gets similar total tokens"""
11031100
data = asyncio.run(data_system_client.async_get_data(batch))
11041101

11051102
attention_mask = data["attention_mask"]
11061103
batch_size = attention_mask.shape[0]
1107-
global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,)
1104+
global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1) # (train_batch_size,)
1105+
global_seqlen_lst = calculate_workload(global_seqlen_lst)
11081106
world_size = self.actor_rollout_wg.world_size
1109-
global_partition_lst = get_seqlen_balanced_partitions(
1110-
global_seqlen_lst, k_partitions=world_size, equal_size=True
1111-
)
1107+
if keep_minibatch:
1108+
# Decouple the DP balancing and mini-batching.
1109+
minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size")
1110+
minibatch_num = len(global_seqlen_lst) // minibatch_size
1111+
global_partition_lst = [[] for _ in range(world_size)]
1112+
for i in range(minibatch_num):
1113+
rearrange_minibatch_lst = get_seqlen_balanced_partitions(
1114+
global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size],
1115+
k_partitions=world_size,
1116+
equal_size=True,
1117+
)
1118+
for j, part in enumerate(rearrange_minibatch_lst):
1119+
global_partition_lst[j].extend([x + minibatch_size * i for x in part])
1120+
else:
1121+
global_partition_lst = get_seqlen_balanced_partitions(
1122+
global_seqlen_lst, k_partitions=world_size, equal_size=True
1123+
)
1124+
# Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel.
1125+
for idx, partition in enumerate(global_partition_lst):
1126+
partition.sort(key=lambda x: (global_seqlen_lst[x], x))
1127+
ordered_partition = partition[::2] + partition[1::2][::-1]
1128+
global_partition_lst[idx] = ordered_partition
11121129
# reorder based on index. The data will be automatically equally partitioned by dispatch function
11131130
global_idx = [j for partition in global_partition_lst for j in partition]
11141131
global_balance_stats = log_seqlen_unbalance(
@@ -1248,7 +1265,6 @@ def fit(self):
12481265
base_get_meta_kwargs = dict(
12491266
batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
12501267
partition_id=f"train_{self.global_steps - 1}", # self.global_steps starts from 1
1251-
get_n_samples=False,
12521268
)
12531269

12541270
with marked_timer("start_profile", timing_raw):
@@ -1646,7 +1662,6 @@ def fit(self):
16461662
batch_size=self.config.data.train_batch_size
16471663
* self.config.actor_rollout_ref.rollout.n,
16481664
partition_id=f"train_{self.global_steps - 1}",
1649-
get_n_samples=False,
16501665
task_name="update_actor",
16511666
)
16521667
)
@@ -1672,7 +1687,6 @@ def fit(self):
16721687
data_fields=data_fields,
16731688
batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
16741689
partition_id=f"train_{self.global_steps - 1}",
1675-
get_n_samples=False,
16761690
task_name="log_rollout",
16771691
)
16781692
)

0 commit comments

Comments
 (0)