13 changes: 8 additions & 5 deletions swift/trainers/arguments.py
@@ -517,13 +517,14 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin):
tau_neg (float): The temperature parameter for negative dominance in the SAPO algorithm, controlling the
sharpness of the soft gating function. Typically, `tau_neg` is set > `tau_pos` to impose stronger
constraints on negative dominance. The default value is 1.05.
- advantage_estimator (Literal['grpo', 'rloo', 'reinforce_plus_plus']): The advantage estimation
+ step_advantage_w (float): The weight for the step-level advantage (A^S) in the GiGPO algorithm. Defaults to 1.0.
+ advantage_estimator (Literal['grpo', 'rloo', 'reinforce_plus_plus', 'gigpo']): The advantage estimation
function to use. 'grpo' calculates the relative advantage within a group. Options are 'grpo', 'rloo',
'reinforce_plus_plus'. Defaults to 'grpo'.
kl_in_reward (Optional[bool]): Controls how the KL divergence regularization term is handled. If
`False`, it's an independent term in the loss function. If `True`, KL is directly incorporated into the
- reward (subtracted from it). The default is tied to `advantage_estimator`: `False` for 'grpo', `True` for
- 'rloo' and 'reinforce_plus_plus'.
+ reward (subtracted from it). The default is tied to `advantage_estimator`: `False` for 'grpo' and 'gigpo',
+ `True` for 'rloo' and 'reinforce_plus_plus'.
generation_batch_size (Optional[int]): The batch size for sampling completions. It should be a
multiple of `num_processes * per_device_train_batch_size`. Defaults to `per_device_batch_size *
gradient_accumulation_steps * num_processes`.
@@ -596,10 +597,12 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin):
tau_pos: float = 1.0
tau_neg: float = 1.05

- # RLOO, REINFORCE++
- advantage_estimator: Literal['grpo', 'rloo', 'reinforce_plus_plus'] = 'grpo'
+ # RLOO, REINFORCE++, GiGPO
+ advantage_estimator: Literal['grpo', 'rloo', 'reinforce_plus_plus', 'gigpo'] = 'grpo'
# If false, add KL into loss, otherwise add into reward
kl_in_reward: Optional[bool] = None # rloo/reinforce_plus_plus: true, grpo: false (default)
Contributor (medium):

The comment for kl_in_reward is now outdated with the addition of gigpo. The docstring was updated to reflect that kl_in_reward is False for gigpo, but this inline comment was missed. Please update it for consistency and to avoid confusion for future developers.

Suggested change:
- kl_in_reward: Optional[bool] = None # rloo/reinforce_plus_plus: true, grpo: false (default)
+ kl_in_reward: Optional[bool] = None # rloo/reinforce_plus_plus: true, grpo/gigpo: false (default)

# GiGPO, https://arxiv.org/abs/2405.06708
Collaborator: same here

step_advantage_w = 1.0
Collaborator: gigpo_step_advantage_weight


generation_batch_size: Optional[int] = None
steps_per_generation: Optional[int] = None
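
As a reading aid only (not part of this diff, and the helper name below is hypothetical), the default rule documented above for kl_in_reward can be written out in a few lines of Python:

# Illustrative sketch, not code from this PR: how the documented kl_in_reward
# default follows from advantage_estimator. The function name is hypothetical.
from typing import Optional

def resolve_kl_in_reward(advantage_estimator: str, kl_in_reward: Optional[bool]) -> bool:
    if kl_in_reward is not None:
        return kl_in_reward  # an explicit user setting always wins
    # Documented defaults: 'grpo' and 'gigpo' keep KL as a separate loss term (False);
    # 'rloo' and 'reinforce_plus_plus' fold KL into the reward (True).
    return advantage_estimator in ('rloo', 'reinforce_plus_plus')

assert resolve_kl_in_reward('gigpo', None) is False
assert resolve_kl_in_reward('reinforce_plus_plus', None) is True
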
68 changes: 67 additions & 1 deletion swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -21,6 +21,7 @@
import inspect
import os
import time
import numpy as np
from collections import defaultdict, deque
from contextlib import contextmanager, nullcontext
from copy import copy, deepcopy
@@ -425,6 +426,52 @@ def normalize_advantages(advantages: torch.Tensor, rewards_std: torch.Tensor) ->
return advantages / (rewards_std + 1e-4)
return advantages

def _compute_step_advantages(inputs, trajectory_advantages):
    # Extract step-level reward information from inputs
    # Store (prompt_id, step) -> [rewards] mapping
    step_rewards_dict = {}
    for idx, input_data in enumerate(inputs):
        prompt_id = input_data['prompt_id']
        rollout_info = input_data['rollout_infos']

        # Collect all step rewards for current trajectory
        for traj_info in rollout_info.get('trajectory_info', []):
            step = traj_info.get('step', 0)
            reward = traj_info.get('reward', 0.0)

            # Group rewards by prompt_id and step
            key = (prompt_id, step)
            if key not in step_rewards_dict:
                step_rewards_dict[key] = []
            step_rewards_dict[key].append(reward)
    # Calculate step-level advantage and aggregate
    aggregated_step_advantages = torch.zeros_like(trajectory_advantages)
    for idx, input_data in enumerate(inputs):
        prompt_id = input_data['prompt_id']
        rollout_info = input_data['rollout_infos']

        # Calculate aggregated step-level advantage for current trajectory
        step_advantages = []
        for traj_info in rollout_info.get('trajectory_info', []):
            step = traj_info.get('step', 0)
            reward = traj_info.get('reward', 0.0)

            # Get all rewards for same prompt and step
            key = (prompt_id, step)
            all_rewards = step_rewards_dict.get(key, [reward])

            # Calculate step advantage (compared to group average)
            mean_reward = np.mean(all_rewards)
            step_advantage = reward - mean_reward
            step_advantages.append(step_advantage)

        # Aggregate step-level advantage for current trajectory (use mean of valid steps)
        if step_advantages:
            aggregated_step_advantages[idx] = np.mean(step_advantages)
        else:
            aggregated_step_advantages[idx] = 0.0
    return aggregated_step_advantages

def log_rewards_metrics(rewards: torch.Tensor, rewards_per_func_for_metrics: torch.Tensor):
"""Log reward statistics for monitoring. Only log once per unique request_id."""
# rewards: [prompt_batch_size, num_generations]
@@ -506,6 +553,12 @@ def log_rewards_all(rewards_per_func: torch.Tensor):
        advantages = rewards * K / (K - 1) - group_rewards_mean * K / (K - 1)
    else:
        advantages = rewards - group_rewards_mean
elif self.advantage_estimator == 'gigpo' and self.use_gym_env:
Collaborator: Since Gigpo depends on the Gym environment, it's recommended to add proper checks.

- if self.advantage_estimator == 'gigpo' and self.use_gym_env:
+ if self.advantage_estimator == 'gigpo':
+     assert self.use_gym_env

    # Get trajectory-level advantage (original GRPO advantage)
    trajectory_advantages = rewards - group_rewards_mean
    aggregated_step_advantages = _compute_step_advantages(inputs, trajectory_advantages)
    # Weighted sum of trajectory-level advantage and aggregated step-level advantage
    advantages = trajectory_advantages + self.step_advantage_w * aggregated_step_advantages
else:  # 'grpo' or 'reinforce_plus_plus'
    # Both use group mean as baseline
    advantages = rewards - group_rewards_mean
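
As a standalone toy (not part of the diff) of what the gigpo branch computes: each sampled trajectory carries a prompt_id and per-step rewards under rollout_infos['trajectory_info']; the step-level advantage (A^S) centres every step reward on the group mean for its (prompt_id, step) pair and averages over the trajectory's steps, and the final advantage adds step_advantage_w times that term to the trajectory-level advantage. A minimal sketch, assuming that input layout:

# Standalone toy sketch (not code from this PR); mirrors the gigpo combination above.
import numpy as np
import torch

inputs = [  # two rollouts of the same prompt, with hypothetical per-step rewards
    {'prompt_id': 'p0', 'rollout_infos': {'trajectory_info': [
        {'step': 0, 'reward': 1.0}, {'step': 1, 'reward': 1.0}]}},
    {'prompt_id': 'p0', 'rollout_infos': {'trajectory_info': [
        {'step': 0, 'reward': 0.0}, {'step': 1, 'reward': 1.0}]}},
]
rewards = torch.tensor([1.0, 0.0])                  # trajectory-level (episode) rewards
trajectory_advantages = rewards - rewards.mean()    # [0.5, -0.5]

# Group step rewards by (prompt_id, step), as _compute_step_advantages does.
step_groups = {}
for sample in inputs:
    for info in sample['rollout_infos']['trajectory_info']:
        step_groups.setdefault((sample['prompt_id'], info['step']), []).append(info['reward'])

# Per-trajectory A^S: mean over steps of (step reward - group mean at that step).
step_advantages = torch.tensor([
    float(np.mean([info['reward'] - np.mean(step_groups[(s['prompt_id'], info['step'])])
                   for info in s['rollout_infos']['trajectory_info']]))
    for s in inputs
])                                                  # [0.25, -0.25]

step_advantage_w = 1.0
advantages = trajectory_advantages + step_advantage_w * step_advantages
print(advantages)                                   # tensor([0.7500, -0.7500])
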
@@ -654,6 +707,13 @@ def log_rewards_all(rewards_per_func: torch.Tensor):
indices_in_unique = torch.tensor([rid_to_idx[r] for r in request_ids], device=device)
advantages = request_advantages[indices_in_unique]

if self.advantage_estimator == 'gigpo' and self.use_gym_env:
    # Get trajectory-level advantage (original GRPO advantage)
    trajectory_advantages = advantages
    aggregated_step_advantages = _compute_step_advantages(inputs, trajectory_advantages)
    # Weighted sum of trajectory-level advantage and aggregated step-level advantage
    advantages = trajectory_advantages + self.step_advantage_w * aggregated_step_advantages
Contributor (medium):

This block of code for calculating the GiGPO advantage is nearly identical to the one on lines 556-561. To improve maintainability and avoid code duplication, consider extracting this logic into a helper function.

For example, you could create a helper within _compute_advantages:

def _apply_gigpo_step_advantage(base_advantages, inputs):
    aggregated_step_advantages = _compute_step_advantages(inputs, base_advantages)
    return base_advantages + self.step_advantage_w * aggregated_step_advantages

Then you could call this helper in both places to compute the final advantage for GiGPO.


# Step 5. Log metrics for unique request_ids
log_rewards_metrics(rewards=unique_rewards, rewards_per_func_for_metrics=rewards_per_func[unique_indices])

@@ -2154,6 +2214,9 @@ def _prepare_algorithm_params(self):
self.advantage_estimator = args.advantage_estimator
self.kl_in_reward = args.kl_in_reward

# GiGPO, https://arxiv.org/abs/2405.06708
Collaborator: wrong link

self.step_advantage_w = args.step_advantage_w

# Rollout Importance Sampling Correction
self.rollout_importance_sampling_mode = args.rollout_importance_sampling_mode
self.rollout_importance_sampling_threshold = args.rollout_importance_sampling_threshold
@@ -2227,7 +2290,10 @@ def _prepare_rewards(self, reward_funcs, reward_model=None, reward_templates=Non
f'functions ({len(reward_funcs)})')
self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32).to(device)
else:
- self.reward_weights = torch.ones(len(self.reward_func_names), dtype=torch.float32).to(device)
+ if self.use_gym_env:
+     self.reward_weights = torch.ones(1, dtype=torch.float32).to(device)
+ else:
+     self.reward_weights = torch.ones(len(self.reward_func_names), dtype=torch.float32).to(device)

# after init trainer
for i, reward_func in enumerate(self.reward_funcs):
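
For the default reward_weights branch above, a standalone sketch of the intended behaviour (the helper name and reward-function names are illustrative, not from this PR): with a gym environment the rollout is weighted as a single aggregate reward, otherwise one weight per reward function is used.

# Standalone sketch of the default-weights logic above (not code from this PR).
import torch

def default_reward_weights(use_gym_env: bool, reward_func_names: list) -> torch.Tensor:
    # Hypothetical helper name, for illustration only.
    if use_gym_env:
        return torch.ones(1, dtype=torch.float32)   # single env-level reward
    return torch.ones(len(reward_func_names), dtype=torch.float32)

print(default_reward_weights(True, ['accuracy', 'format']))    # tensor([1.])
print(default_reward_weights(False, ['accuracy', 'format']))   # tensor([1., 1.])
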