
Fix ppo #10935

Status: Closed · wants to merge 3 commits

llm/config/qwen/ppo_argument.yaml: 4 changes (2 additions, 2 deletions)
@@ -50,8 +50,8 @@ do_train: true # Whether to perform training
 seed: 42 # Random seed for reproducibility
 global_batch_size: 256 # Global batch size for training (rollouts = rollout_n * global_batch_size)
 global_gen_batch_size: -1 # Global generation batch size for dynamic sampling
-global_mini_batch_size: 64 # Mini-batch size for training, default = (global_batch_size * rollout_n * update_iters) // dataset_world_size
-rollout_n: 1 # Number of rollouts, set rollout_n = 1 for 'ppo'
+global_mini_batch_size: 64 # Mini-batch size for training, default = global_batch_size
+rollout_n: 1 # Number of rollouts
 update_iters: 1 # Number of training iterations for rollout samples
 per_device_logprob_batch_size: 4 # Log probability batch size per device
 per_device_reward_batch_size: 2 # Reward batch size per device
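For orientation, here is a minimal arithmetic sketch of the batch bookkeeping these comments describe. The values are copied from this config; the variable names simply mirror the YAML keys, and the snippet is illustrative rather than trainer code.

```python
# Illustrative only: how rollouts and mini-batches relate under the updated comments,
# using the values from this YAML file.
global_batch_size = 256      # prompts consumed per training step
rollout_n = 1                # responses sampled per prompt (this config keeps 1 for PPO)
global_mini_batch_size = 64  # optimizer mini-batch; defaults to global_batch_size when set to -1

rollouts_per_step = rollout_n * global_batch_size                    # 256 samples to train on
mini_batches_per_step = rollouts_per_step // global_mini_batch_size  # 4 optimizer mini-batches
print(rollouts_per_step, mini_batches_per_step)                      # 256 4
```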
paddlenlp/rl/trainer/ppo_trainer.py: 4 changes (3 additions, 1 deletion)
@@ -1075,7 +1075,9 @@ def init_train_num(
         len_dataloader = None
         if not self._is_iterable_dataset(self.train_dataset):
             len_dataloader = len(train_dataloader)
-            num_train_sub_steps = self.args.global_mini_batch_size // args.per_device_train_batch_size
+            num_train_sub_steps = (
+                args.global_mini_batch_size * args.rollout_n * args.update_iters
+            ) // args.per_device_train_batch_size
             num_update_steps_per_epoch = (num_train_sub_steps // args.gradient_accumulation_steps) * len_dataloader
             num_examples = len(self.train_dataset)
             if args.max_steps > 0:
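The new expression makes the sub-step count scale with rollout_n and update_iters, since each prompt contributes rollout_n samples that are trained for update_iters passes. A hedged sketch of the effect follows; per_device_train_batch_size = 4 and the non-default rollout_n/update_iters values are assumptions for illustration, not part of this diff.

```python
# Sketch of num_train_sub_steps before and after this change.
# per_device_train_batch_size = 4, rollout_n = 2, and update_iters = 2 are assumed
# values for illustration; the shipped PPO config keeps rollout_n = update_iters = 1.
global_mini_batch_size = 64
rollout_n = 2
update_iters = 2
per_device_train_batch_size = 4

old_sub_steps = global_mini_batch_size // per_device_train_batch_size  # 16
new_sub_steps = (
    global_mini_batch_size * rollout_n * update_iters
) // per_device_train_batch_size                                        # 64
print(old_sub_steps, new_sub_steps)
```

With rollout_n = 1 and update_iters = 1, as in the PPO config above, both expressions agree.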
paddlenlp/rl/utils/config_utils.py: 2 changes (1 addition, 1 deletion)
@@ -346,7 +346,7 @@ def __post_init__(self):
         self._post_init_parallel_degree()

         if self.global_mini_batch_size < 0:
-            self.global_mini_batch_size = self.global_batch_size // self.dataset_world_size
+            self.global_mini_batch_size = self.global_batch_size

         if (
             self.global_batch_size % self.dataset_world_size != 0
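A small sketch of the changed default when global_mini_batch_size is left at -1; dataset_world_size = 4 is an assumed value for illustration and does not appear in this diff.

```python
# Illustrative only: the default mini-batch size is now global rather than per data-parallel rank.
global_batch_size = 256
dataset_world_size = 4  # assumed data-parallel degree, for illustration

old_default = global_batch_size // dataset_world_size  # 64, divided across ranks
new_default = global_batch_size                        # 256, kept global
print(old_default, new_default)
```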