[grpo] support gigpo with gym #7364
```diff
@@ -517,13 +517,14 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin):
         tau_neg (float): The temperature parameter for negative dominance in the SAPO algorithm, controlling the
             sharpness of the soft gating function. Typically, `tau_neg` is set > `tau_pos` to impose stronger
             constraints on negative dominance. The default value is 1.05.
-        advantage_estimator (Literal['grpo', 'rloo', 'reinforce_plus_plus']): The advantage estimation
+        step_advantage_w (float): The weight for the step-level advantage (A^S) in the GiGPO algorithm. Defaults to 1.0.
+        advantage_estimator (Literal['grpo', 'rloo', 'reinforce_plus_plus', 'gigpo']): The advantage estimation
             function to use. 'grpo' calculates the relative advantage within a group. Options are 'grpo', 'rloo',
             'reinforce_plus_plus'. Defaults to 'grpo'.
         kl_in_reward (Optional[bool]): Controls how the KL divergence regularization term is handled. If
             `False`, it's an independent term in the loss function. If `True`, KL is directly incorporated into the
-            reward (subtracted from it). The default is tied to `advantage_estimator`: `False` for 'grpo', `True` for
-            'rloo' and 'reinforce_plus_plus'.
+            reward (subtracted from it). The default is tied to `advantage_estimator`: `False` for 'grpo' and 'gigpo',
+            `True` for 'rloo' and 'reinforce_plus_plus'.
         generation_batch_size (Optional[int]): The batch size for sampling completions. It should be a
             multiple of `num_processes * per_device_train_batch_size`. Defaults to `per_device_batch_size *
             gradient_accumulation_steps * num_processes`.
```
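For orientation: under the new 'gigpo' estimator this PR combines the usual group-relative (trajectory-level) GRPO advantage with an aggregated step-level advantage, weighted by `step_advantage_w`. A minimal sketch of that combination, in notation chosen here for readability rather than taken verbatim from the paper:

```latex
% A_i^E : trajectory-level (episode) advantage of rollout i, the standard GRPO group-relative advantage
% A_i^S : mean of the step-level advantages collected along rollout i
% w     : step_advantage_w (default 1.0)
A_i = A_i^E + w \, A_i^S
```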
```diff
@@ -596,10 +597,12 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin):
     tau_pos: float = 1.0
     tau_neg: float = 1.05

-    # RLOO, REINFORCE++
-    advantage_estimator: Literal['grpo', 'rloo', 'reinforce_plus_plus'] = 'grpo'
+    # RLOO, REINFORCE++, GiGPO
+    advantage_estimator: Literal['grpo', 'rloo', 'reinforce_plus_plus', 'gigpo'] = 'grpo'
     # If false, add KL into loss, otherwise add into reward
     kl_in_reward: Optional[bool] = None  # rloo/reinforce_plus_plus: true, grpo: false (default)
+    # GiGPO, https://arxiv.org/abs/2405.06708
+    step_advantage_w = 1.0

     generation_batch_size: Optional[int] = None
     steps_per_generation: Optional[int] = None
```
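The updated docstring implies a simple default-resolution rule for `kl_in_reward`. A minimal sketch of that rule as documented; the helper name `resolve_kl_in_reward` is hypothetical and not part of this PR:

```python
from typing import Optional


def resolve_kl_in_reward(advantage_estimator: str, kl_in_reward: Optional[bool]) -> bool:
    """Resolve the documented default: False for 'grpo' and 'gigpo' (KL stays a separate
    loss term), True for 'rloo' and 'reinforce_plus_plus' (KL is subtracted from the reward)."""
    if kl_in_reward is not None:  # an explicit user choice always wins
        return kl_in_reward
    return advantage_estimator in ('rloo', 'reinforce_plus_plus')


assert resolve_kl_in_reward('gigpo', None) is False
assert resolve_kl_in_reward('rloo', None) is True
```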
```diff
@@ -21,6 +21,7 @@
 import inspect
 import os
 import time
+import numpy as np
 from collections import defaultdict, deque
 from contextlib import contextmanager, nullcontext
 from copy import copy, deepcopy
```
```diff
@@ -425,6 +426,52 @@ def normalize_advantages(advantages: torch.Tensor, rewards_std: torch.Tensor) ->
         return advantages / (rewards_std + 1e-4)
     return advantages


+def _compute_step_advantages(inputs, trajectory_advantages):
+    # Extract step-level reward information from inputs
+    # Store (prompt_id, step) -> [rewards] mapping
+    step_rewards_dict = {}
+    for idx, input_data in enumerate(inputs):
+        prompt_id = input_data['prompt_id']
+        rollout_info = input_data['rollout_infos']
+
+        # Collect all step rewards for the current trajectory
+        for traj_info in rollout_info.get('trajectory_info', []):
+            step = traj_info.get('step', 0)
+            reward = traj_info.get('reward', 0.0)
+
+            # Group rewards by prompt_id and step
+            key = (prompt_id, step)
+            if key not in step_rewards_dict:
+                step_rewards_dict[key] = []
+            step_rewards_dict[key].append(reward)
+
+    # Calculate the step-level advantage and aggregate it per trajectory
+    aggregated_step_advantages = torch.zeros_like(trajectory_advantages)
+    for idx, input_data in enumerate(inputs):
+        prompt_id = input_data['prompt_id']
+        rollout_info = input_data['rollout_infos']
+
+        # Calculate the aggregated step-level advantage for the current trajectory
+        step_advantages = []
+        for traj_info in rollout_info.get('trajectory_info', []):
+            step = traj_info.get('step', 0)
+            reward = traj_info.get('reward', 0.0)
+
+            # Get all rewards for the same prompt and step
+            key = (prompt_id, step)
+            all_rewards = step_rewards_dict.get(key, [reward])
+
+            # Calculate the step advantage (deviation from the group average)
+            mean_reward = np.mean(all_rewards)
+            step_advantage = reward - mean_reward
+            step_advantages.append(step_advantage)
+
+        # Aggregate the step-level advantage for the current trajectory (mean of valid steps)
+        if step_advantages:
+            aggregated_step_advantages[idx] = np.mean(step_advantages)
+        else:
+            aggregated_step_advantages[idx] = 0.0
+    return aggregated_step_advantages
+
+
 def log_rewards_metrics(rewards: torch.Tensor, rewards_per_func_for_metrics: torch.Tensor):
     """Log reward statistics for monitoring. Only log once per unique request_id."""
     # rewards: [prompt_batch_size, num_generations]
```
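To make the expected rollout schema concrete, here is a small usage sketch of `_compute_step_advantages`. The dictionary keys (`prompt_id`, `rollout_infos`, `trajectory_info`, `step`, `reward`) come from the function body above; the numeric values are invented for illustration, and the helper is assumed to be in scope together with `numpy` and `torch`:

```python
import torch

# Two rollouts of the same prompt; the gym env reports a reward at every step.
inputs = [
    {'prompt_id': 'p0',
     'rollout_infos': {'trajectory_info': [{'step': 0, 'reward': 1.0}, {'step': 1, 'reward': 1.0}]}},
    {'prompt_id': 'p0',
     'rollout_infos': {'trajectory_info': [{'step': 0, 'reward': 0.0}, {'step': 1, 'reward': 0.0}]}},
]
trajectory_advantages = torch.tensor([1.0, -1.0])  # ordinary GRPO group-relative advantages

step_adv = _compute_step_advantages(inputs, trajectory_advantages)
# Per-step group means are 0.5, so rollout 0 gets mean(+0.5, +0.5) = 0.5 and rollout 1 gets -0.5.
advantages = trajectory_advantages + 1.0 * step_adv  # with step_advantage_w = 1.0 -> tensor([1.5, -1.5])
```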
```diff
@@ -506,6 +553,12 @@ def log_rewards_all(rewards_per_func: torch.Tensor):
                 advantages = rewards * K / (K - 1) - group_rewards_mean * K / (K - 1)
             else:
                 advantages = rewards - group_rewards_mean
+        elif self.advantage_estimator == 'gigpo' and self.use_gym_env:
+            # Get trajectory-level advantage (original GRPO advantage)
+            trajectory_advantages = rewards - group_rewards_mean
+            aggregated_step_advantages = _compute_step_advantages(inputs, trajectory_advantages)
+            # Weighted sum of trajectory-level advantage and aggregated step-level advantage
+            advantages = trajectory_advantages + self.step_advantage_w * aggregated_step_advantages
         else:  # 'grpo' or 'reinforce_plus_plus'
             # Both use group mean as baseline
             advantages = rewards - group_rewards_mean
```
```diff
@@ -654,6 +707,13 @@ def log_rewards_all(rewards_per_func: torch.Tensor):
         indices_in_unique = torch.tensor([rid_to_idx[r] for r in request_ids], device=device)
         advantages = request_advantages[indices_in_unique]

+        if self.advantage_estimator == 'gigpo' and self.use_gym_env:
+            # Get trajectory-level advantage (original GRPO advantage)
+            trajectory_advantages = advantages
+            aggregated_step_advantages = _compute_step_advantages(inputs, trajectory_advantages)
+            # Weighted sum of trajectory-level advantage and aggregated step-level advantage
+            advantages = trajectory_advantages + self.step_advantage_w * aggregated_step_advantages
+
         # Step 5. Log metrics for unique request_ids
         log_rewards_metrics(rewards=unique_rewards, rewards_per_func_for_metrics=rewards_per_func[unique_indices])
```
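The context lines above gather per-request advantages back into per-sample order before the step-level term is added, which is why `trajectory_advantages = advantages` is already the group-relative advantage here. A small standalone illustration of that gather pattern; the variable names mirror the context lines but the values are illustrative:

```python
import torch

request_advantages = torch.tensor([0.8, -0.3])   # one advantage per unique request_id
rid_to_idx = {'req-a': 0, 'req-b': 1}            # request_id -> row in request_advantages
request_ids = ['req-a', 'req-a', 'req-b']        # per-sample request ids (duplicates allowed)

indices_in_unique = torch.tensor([rid_to_idx[r] for r in request_ids])
advantages = request_advantages[indices_in_unique]  # tensor([0.8, 0.8, -0.3])
```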
```diff
@@ -2154,6 +2214,9 @@ def _prepare_algorithm_params(self):
         self.advantage_estimator = args.advantage_estimator
         self.kl_in_reward = args.kl_in_reward

+        # GiGPO, https://arxiv.org/abs/2405.06708
+        self.step_advantage_w = args.step_advantage_w
+
         # Rollout Importance Sampling Correction
         self.rollout_importance_sampling_mode = args.rollout_importance_sampling_mode
         self.rollout_importance_sampling_threshold = args.rollout_importance_sampling_threshold
```
```diff
@@ -2227,7 +2290,10 @@ def _prepare_rewards(self, reward_funcs, reward_model=None, reward_templates=Non
                                  f'functions ({len(reward_funcs)})')
             self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32).to(device)
         else:
-            self.reward_weights = torch.ones(len(self.reward_func_names), dtype=torch.float32).to(device)
+            if self.use_gym_env:
+                self.reward_weights = torch.ones(1, dtype=torch.float32).to(device)
+            else:
+                self.reward_weights = torch.ones(len(self.reward_func_names), dtype=torch.float32).to(device)

         # after init trainer
         for i, reward_func in enumerate(self.reward_funcs):
```
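One detail in the hunk above: with a gym environment the trainer keeps a single reward weight, which is consistent with the env producing one scalar reward per rollout rather than one score per configured reward function (my reading of the change; the PR does not state this explicitly). A minimal standalone sketch of the shape selection, with a hypothetical helper name:

```python
import torch


def make_reward_weights(use_gym_env: bool, reward_func_names: list) -> torch.Tensor:
    # Hypothetical standalone version of the branch added in _prepare_rewards.
    if use_gym_env:
        # Gym rollouts yield a single scalar reward -> a single weight.
        return torch.ones(1, dtype=torch.float32)
    # Otherwise, one weight per configured reward function.
    return torch.ones(len(reward_func_names), dtype=torch.float32)


print(make_reward_weights(True, ['acc', 'format']).shape)   # torch.Size([1])
print(make_reward_weights(False, ['acc', 'format']).shape)  # torch.Size([2])
```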
Review comment:
The comment for `kl_in_reward` is now outdated with the addition of `gigpo`. The docstring was updated to reflect that `kl_in_reward` is `False` for `gigpo`, but this inline comment was missed. Please update it for consistency and to avoid confusion for future developers.