Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ class GRPOConfig(TrainingArguments):
log-probability ratios across valid tokens to produce a single ratio per sequence. The [GSPO
paper](https://huggingface.co/papers/2507.18071) shows that sequence-level sampling often yields more
stable training and better alignment with sequence-level rewards.
mgpo_reward_base (`float`, *optional*):
    The reward threshold at or above which a completion is considered successful for MGPO advantage scaling.
    If `None`, regular advantages are used. MGPO is introduced in
    [Tiny Model, Big Logic](https://huggingface.co/papers/2511.06221) to stabilize training and encourage
    exploration within the GRPO framework.
reward_weights (`list[float]`, *optional*):
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
weighted equally with weight `1.0`.
Expand Down Expand Up @@ -645,6 +649,12 @@ class GRPOConfig(TrainingArguments):
"sequence-level rewards."
},
)
mgpo_reward_base: float | None = field(
default=None,
metadata={
"help": "The reward amount considered to be successful for MGPO scaling. If `None`, regular advantages are used"
},
)
reward_weights: list[float] | None = field(
default=None,
metadata={
Expand Down
8 changes: 8 additions & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ def __init__(
self.multi_objective_aggregation = args.multi_objective_aggregation
self.scale_rewards = args.scale_rewards
self.importance_sampling_level = args.importance_sampling_level
self.mgpo_reward_base = args.mgpo_reward_base
self.off_policy_mask_threshold = args.off_policy_mask_threshold
if self.use_liger_kernel and self.off_policy_mask_threshold is not None:
raise ValueError("Liger kernel does not support off-policy sequence masking yet.")
Expand Down Expand Up @@ -1777,6 +1778,13 @@ def _generate_and_score_completions(
f"Invalid multi_objective_aggregation: {self.multi_objective_aggregation}. Must be "
"'sum_then_normalize' or 'normalize_then_sum'."
)

# MGPO advantage re-weighting (https://huggingface.co/papers/2511.06221).
# Must be an explicit `None` check: the documented contract is "If `None`, regular advantages
# are used", and a truthiness test would silently disable MGPO for a valid threshold of 0.0.
if self.mgpo_reward_base is not None:
    # pc: per-prompt success rate — the fraction of the `num_generations` completions whose
    # (detached) reward reaches the success threshold.
    pc = (rewards.detach() >= self.mgpo_reward_base).float().view(-1, self.num_generations).mean(dim=1)
    # d_me: KL divergence between Bernoulli(pc) and Bernoulli(p0 = 0.5). The 1e-8 guards
    # log(0) when a prompt's completions are all-success or all-failure.
    p0 = 0.5
    d_me = pc * torch.log(pc / p0 + 1e-8) + (1 - pc) * torch.log((1 - pc) / (1 - p0) + 1e-8)
    # w_me in (0, 1]: down-weights prompts whose success rate is far from p0 (too easy or too
    # hard); repeat_interleave broadcasts the per-prompt weight back to each completion.
    w_me = torch.exp(-0.5 * d_me).repeat_interleave(self.num_generations)
    advantages = advantages * w_me

# Slice to keep only the local part of the data
process_slice = slice(
Expand Down