diff --git a/llm/config/qwen/ppo_argument.yaml b/llm/config/qwen/ppo_argument.yaml
index e82e4316290c..2fa9f951b2d7 100644
--- a/llm/config/qwen/ppo_argument.yaml
+++ b/llm/config/qwen/ppo_argument.yaml
@@ -50,8 +50,8 @@ do_train: true # Whether to perform training
 seed: 42 # Random seed for reproducibility
 global_batch_size: 256 # Global batch size for training (rollouts = rollout_n * global_batch_size)
 global_gen_batch_size: -1 # Global generation batch size for dynamic sampling
-global_mini_batch_size: 64 # Mini-batch size for training, default = (global_batch_size * rollout_n * update_iters) // dataset_world_size
-rollout_n: 1 # Number of rollouts, set rollout_n = 1 for 'ppo'
+global_mini_batch_size: 64 # Mini-batch size for training, default = global_batch_size
+rollout_n: 1 # Number of rollouts
 update_iters: 1 # Number of training iterations for rollout samples
 per_device_logprob_batch_size: 4 # Log probability batch size per device
 per_device_reward_batch_size: 2 # Reward batch size per device
diff --git a/paddlenlp/rl/trainer/ppo_trainer.py b/paddlenlp/rl/trainer/ppo_trainer.py
index 2ccd738f44df..32ba1e16d222 100644
--- a/paddlenlp/rl/trainer/ppo_trainer.py
+++ b/paddlenlp/rl/trainer/ppo_trainer.py
@@ -1075,7 +1075,9 @@ def init_train_num(
         len_dataloader = None
         if not self._is_iterable_dataset(self.train_dataset):
             len_dataloader = len(train_dataloader)
-            num_train_sub_steps = self.args.global_mini_batch_size // args.per_device_train_batch_size
+            num_train_sub_steps = (
+                args.global_mini_batch_size * args.rollout_n * args.update_iters
+            ) // args.per_device_train_batch_size
             num_update_steps_per_epoch = (num_train_sub_steps // args.gradient_accumulation_steps) * len_dataloader
             num_examples = len(self.train_dataset)
             if args.max_steps > 0:
diff --git a/paddlenlp/rl/utils/config_utils.py b/paddlenlp/rl/utils/config_utils.py
index 9aae90c9001b..d059228277f0 100644
--- a/paddlenlp/rl/utils/config_utils.py
+++ b/paddlenlp/rl/utils/config_utils.py
@@ -346,7 +346,7 @@ def __post_init__(self):
         self._post_init_parallel_degree()
 
         if self.global_mini_batch_size < 0:
-            self.global_mini_batch_size = self.global_batch_size // self.dataset_world_size
+            self.global_mini_batch_size = self.global_batch_size
 
         if (
            self.global_batch_size % self.dataset_world_size != 0
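
A minimal sketch (illustration only, not part of the patch) of how the revised step arithmetic composes, using the YAML defaults above; per_device_train_batch_size, gradient_accumulation_steps, and len_dataloader are hypothetical values chosen for the example:

# Step arithmetic after this change (standalone sketch, hypothetical values).
global_mini_batch_size = 64       # with -1 it now defaults to global_batch_size (256)
rollout_n = 1                     # number of rollouts
update_iters = 1                  # training iterations over the rollout samples
per_device_train_batch_size = 8   # hypothetical
gradient_accumulation_steps = 2   # hypothetical
len_dataloader = 10               # hypothetical

# New formula in init_train_num: rollout_n and update_iters now scale the
# number of per-device training sub-steps produced by one global mini-batch.
num_train_sub_steps = (
    global_mini_batch_size * rollout_n * update_iters
) // per_device_train_batch_size
num_update_steps_per_epoch = (num_train_sub_steps // gradient_accumulation_steps) * len_dataloader

print(num_train_sub_steps)         # 8
print(num_update_steps_per_epoch)  # 40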