from verl.utils.metric import reduce_metrics
from verl.utils.rollout_skip import RolloutSkip
from verl.utils.seqlen_balancing import (
    calculate_workload,
    get_seqlen_balanced_partitions,
    log_seqlen_unbalance,
)
@@ -678,7 +679,6 @@ def _validate(self):
678679 data_fields = ["input_ids" , "uid" , "reward_model" ],
679680 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
680681 partition_id = f"val_{ self .global_steps - 1 } " ,
681- get_n_samples = False ,
682682 task_name = "get_data" ,
683683 )
684684 )
@@ -697,7 +697,6 @@ def _validate(self):
697697 data_fields = list (test_batch .keys ()), # TODO: (TQ) Get metadata by specified fields
698698 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
699699 partition_id = f"val_{ self .global_steps - 1 } " , # self.global_steps start from 1
700- get_n_samples = False ,
701700 task_name = "generate_sequences" ,
702701 )
703702 )
@@ -727,7 +726,6 @@ def _validate(self):
727726 data_fields = ["responses" ],
728727 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
729728 partition_id = f"val_{ self .global_steps - 1 } " , # self.global_steps start from 1
730- get_n_samples = False ,
731729 task_name = "get_response" ,
732730 )
733731 )
@@ -756,7 +754,6 @@ def _validate(self):
756754 data_fields = compute_reward_fields ,
757755 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
758756 partition_id = f"val_{ self .global_steps - 1 } " ,
759- get_n_samples = False ,
760757 task_name = "compute_reward" ,
761758 )
762759 )
@@ -780,7 +777,6 @@ def _validate(self):
780777 data_fields = ["__num_turns__" ],
781778 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
782779 partition_id = f"val_{ self .global_steps - 1 } " , # self.global_steps start from 1
783- get_n_samples = False ,
784780 task_name = "get_num_turns" ,
785781 )
786782 )
@@ -794,7 +790,6 @@ def _validate(self):
794790 data_fields = ["data_source" ],
795791 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
796792 partition_id = f"val_{ self .global_steps - 1 } " , # self.global_steps start from 1
797- get_n_samples = False ,
798793 task_name = "get_data_source" ,
799794 )
800795 )
@@ -1098,17 +1093,39 @@ def _stop_profiling(self, do_profile: bool) -> None:
10981093 if self .use_rm :
10991094 self .rm_wg .stop_profile ()
11001095
def _balance_batch(
    self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen", keep_minibatch=False
):
    """Reorder the batchmeta on single controller such that each dp rank gets similar total tokens.

    Args:
        batch: BatchMeta handle; the actual tensors are fetched from the data system client.
        data_system_client: client exposing ``async_get_data(batch)`` (awaited via ``asyncio.run``).
        metrics: dict of logging metrics; balance statistics are folded in under ``logging_prefix``.
        logging_prefix: key prefix used when logging sequence-length (un)balance stats.
        keep_minibatch: when True, balance each PPO mini-batch across DP ranks independently,
            so mini-batch membership is preserved; when False, balance over the whole batch.
    """
    # Materialize the batch on the single controller to read attention masks.
    data = asyncio.run(data_system_client.async_get_data(batch))

    attention_mask = data["attention_mask"]
    batch_size = attention_mask.shape[0]
    # Per-sample valid-token counts, kept as a tensor (not .tolist()) for calculate_workload.
    global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1)  # (train_batch_size,)
    # NOTE(review): calculate_workload output is indexed per-sample below (slicing and
    # global_seqlen_lst[x]); presumably it returns one cost value per sample — confirm.
    global_seqlen_lst = calculate_workload(global_seqlen_lst)
    world_size = self.actor_rollout_wg.world_size
    if keep_minibatch:
        # Decouple the DP balancing and mini-batching.
        minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size")
        minibatch_num = len(global_seqlen_lst) // minibatch_size
        # One index list per DP rank; filled mini-batch by mini-batch below.
        global_partition_lst = [[] for _ in range(world_size)]
        for i in range(minibatch_num):
            # Balance this mini-batch's samples across all ranks (equal-size partitions).
            rearrange_minibatch_lst = get_seqlen_balanced_partitions(
                global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size],
                k_partitions=world_size,
                equal_size=True,
            )
            # Shift local (within-mini-batch) indices back to global batch indices.
            for j, part in enumerate(rearrange_minibatch_lst):
                global_partition_lst[j].extend([x + minibatch_size * i for x in part])
        # NOTE(review): if batch size is not a multiple of ppo_mini_batch_size, the trailing
        # samples beyond minibatch_num * minibatch_size never enter any partition and would be
        # dropped from global_idx — confirm upstream guarantees divisibility.
    else:
        global_partition_lst = get_seqlen_balanced_partitions(
            global_seqlen_lst, k_partitions=world_size, equal_size=True
        )
    # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel.
    for idx, partition in enumerate(global_partition_lst):
        # Sort ascending by workload (index x breaks ties deterministically), then interleave:
        # evens ascending + odds descending puts the lightest items at both ends of the rank's list.
        partition.sort(key=lambda x: (global_seqlen_lst[x], x))
        ordered_partition = partition[::2] + partition[1::2][::-1]
        global_partition_lst[idx] = ordered_partition
    # reorder based on index. The data will be automatically equally partitioned by dispatch function
    global_idx = [j for partition in global_partition_lst for j in partition]
11141131 global_balance_stats = log_seqlen_unbalance (
@@ -1248,7 +1265,6 @@ def fit(self):
12481265 base_get_meta_kwargs = dict (
12491266 batch_size = self .config .data .train_batch_size * self .config .actor_rollout_ref .rollout .n ,
12501267 partition_id = f"train_{ self .global_steps - 1 } " , # self.global_steps starts from 1
1251- get_n_samples = False ,
12521268 )
12531269
12541270 with marked_timer ("start_profile" , timing_raw ):
@@ -1646,7 +1662,6 @@ def fit(self):
16461662 batch_size = self .config .data .train_batch_size
16471663 * self .config .actor_rollout_ref .rollout .n ,
16481664 partition_id = f"train_{ self .global_steps - 1 } " ,
1649- get_n_samples = False ,
16501665 task_name = "update_actor" ,
16511666 )
16521667 )
@@ -1672,7 +1687,6 @@ def fit(self):
16721687 data_fields = data_fields ,
16731688 batch_size = self .config .data .train_batch_size * self .config .actor_rollout_ref .rollout .n ,
16741689 partition_id = f"train_{ self .global_steps - 1 } " ,
1675- get_n_samples = False ,
16761690 task_name = "log_rollout" ,
16771691 )
16781692 )
0 commit comments