New Algorithm: MAPO implementation #388
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,7 +11,7 @@ | |
| from einops import rearrange | ||
| from torchdata.stateful_dataloader import StatefulDataLoader | ||
|
|
||
| from areal.api.cli_args import MicroBatchSpec, NormConfig | ||
| from areal.api.cli_args import MicroBatchSpec, NormConfig, PPOActorConfig | ||
| from areal.platforms import current_platform | ||
| from areal.utils import datapack, logging | ||
|
|
||
|
|
@@ -1070,6 +1070,7 @@ def cycle_dataloader(dataloader: StatefulDataLoader): | |
| g = iter(dataloader) | ||
|
|
||
|
|
||
| # base ('native') normalization implementation, used for both reward and advantage normalization | ||
| class Normalization: | ||
| """ | ||
| Adaptive normalization with different levels. | ||
|
|
@@ -1108,7 +1109,11 @@ def __call__( | |
| loss_mask: Optional[torch.Tensor] = None, | ||
| high_precision: bool = True, | ||
| reduce_group=None, | ||
| calculation_base: str = "deviation", | ||
| ) -> torch.Tensor: | ||
|
|
||
| # x can be advantage or reward in shape [bs*self.group_size, max_tokens] | ||
|
|
||
| bs = x.size(0) | ||
| eps = self.eps | ||
|
|
||
|
|
@@ -1200,8 +1205,15 @@ def __call__( | |
| std = torch.ones_like(x) | ||
| eps = 0.0 | ||
|
|
||
| assert calculation_base in [ | ||
| "mean", | ||
| "deviation", | ||
| ], "calculation_base must be either mean or deviation" | ||
| base = std if calculation_base == "deviation" else mean | ||
| # Ensure numerical stability (avoid mutating mean/std in place) | ||
| base = base + eps | ||
| # Normalize | ||
| return (x_centered / (std + eps)).float() | ||
| return (x_centered / base).float() | ||
|
|
||
| @staticmethod | ||
| def _compute_mean( | ||
|
|
@@ -1362,3 +1374,115 @@ def _compute_approx_kl( | |
| if apply_clamp: | ||
| log_ratio = log_ratio.clamp(min=-10, max=10) | ||
| return log_ratio | ||
|
|
||
|
|
||
| # Mixed advantage normalization from the MAPO paper, derived from the base Normalization implementation | ||
| class MAPOAdvNorm(Normalization): | ||
| def __call__(self, advantages, loss_mask=None, **kwargs) -> torch.Tensor: | ||
| # Count the unique non-zero values in the advantages tensor (0 corresponds to advantages over pad tokens) | ||
|
|
||
| # deviation_base_norm shape [batch_size*group_size, max_token] | ||
| deviation_base_norm = super().__call__( | ||
| advantages, loss_mask=loss_mask, calculation_base="deviation", **kwargs | ||
| ) | ||
|
|
||
| unique_values = torch.unique(advantages[advantages != 0]) | ||
| unique_elements = unique_values.numel() | ||
|
|
||
| if unique_elements >= 3 or unique_elements <= 1: | ||
| if unique_elements >= 3: | ||
| logger.warning( | ||
| f"MAPO only supports binary reward modeling, but detected {unique_elements} unique values in the advantages tensor. Please check: " | ||
| "1. the definition of reward_fun returns binary values; " | ||
| "2. overlong_reward_penalty is set to false." | ||
| ) | ||
| # all non-zero advantages share a single value (or the batch is all zeros) | ||
| else: | ||
| logger.info( | ||
| "the advantages are all identical in the batch; please check your reward function" | ||
| ) | ||
|
|
||
| logger.info((f"falling back to native advantage normalization")) | ||
ZiyiTsang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # fall back to native implementation is ok | ||
| return super().__call__( | ||
| advantages, loss_mask=loss_mask, calculation_base="deviation", **kwargs | ||
| ) | ||
|
|
||
| # 'unique_upper_value' is the reward value of a successful trajectory | ||
| unique_upper_value, unique_lower_value = ( | ||
| unique_values.max().item(), | ||
| unique_values.min().item(), | ||
| ) | ||
|
|
||
| assert unique_elements == 2, ( | ||
| f"MAPO only supports binary reward modeling, but detected {unique_elements} unique values in the advantages tensor. Please check: " | ||
| "1. the definition of reward_fun returns binary values; " | ||
| "2. overlong_reward_penalty is set to false." | ||
| ) | ||
|
|
||
| # mean_base_norm shape [batch_size*group_size, max_token] | ||
| mean_base_norm = super().__call__( | ||
| advantages, loss_mask=loss_mask, calculation_base="mean", **kwargs | ||
| ) | ||
|
|
||
| bs, max_token = advantages.shape[0] // self.group_size, advantages.shape[-1] | ||
|
|
||
| # since the advantage is constant within a trajectory, take the trajectory-level advantage from the first token | ||
| # based on the assumption that the advantage is identical along the token dimension | ||
|
|
||
| advantages_ = advantages[:, 0] # advantages shape [batch_size*group_size] | ||
|
Contributor
This line assumes that the advantages tensor is constant across the time dimension for each trajectory.
Collaborator (Author)
We do assume the advantages tensor is constant across the time dimension for each trajectory. This does not hold for PPO, but it makes sense for GRPO.
Collaborator
This line does not take any effect and should be removed.
Collaborator (Author)
No, it is useful; please see the code comment.
Collaborator (Author)
The advantage of the first token is extracted and used in the logic below.
||
|
|
||
| advantages_ = advantages_.reshape( | ||
| bs, self.group_size | ||
| ) # advantages shape [batch_size, group_size] | ||
|
|
||
| # the number of successful trajectories within each group | ||
|
||
| success_trajectory_nums_per_group = (advantages_ == unique_upper_value).sum( | ||
| dim=1 | ||
| ) # success_trajectory_nums shape [batch_size] | ||
| # the total number of trajectories within each group | ||
| total_trajectory_nums_per_group = torch.full_like( | ||
| success_trajectory_nums_per_group, self.group_size | ||
| ) # total_trajectory_nums shape [batch_size] | ||
|
||
| # the empirical success probability p of trajectories within each group | ||
| trajectory_certainty_degree = ( | ||
| success_trajectory_nums_per_group / total_trajectory_nums_per_group | ||
| ) | ||
|
|
||
| # trajectory_reweight shape [batch_size]: lambda(p) = 1 - 4p(1-p), the per-group weight of the mean-based term | ||
| trajectory_reweight = ( | ||
| 1 - 4 * trajectory_certainty_degree * (1 - trajectory_certainty_degree) | ||
| ) | ||
| # expand trajectory_reweight from group granularity to token granularity | ||
| # [batch_size] -> [batch_size*group_size] -> [batch_size*group_size, max_token]; every token of a trajectory shares the same reweight | ||
| trajectory_reweight = ( | ||
| trajectory_reweight.repeat_interleave(self.group_size) | ||
| .unsqueeze(-1) | ||
| .expand(-1, max_token) | ||
| ) | ||
| # 'trajectory_reweight', 'deviation_base_norm', and 'mean_base_norm' now share the same shape, so they combine elementwise | ||
| return ( | ||
| 1 - trajectory_reweight | ||
| ) * deviation_base_norm + trajectory_reweight * mean_base_norm | ||
|
Comment on lines +1469 to +1471
Collaborator
double-check the formula. Since the
Collaborator (Author)
done
Collaborator (Author)
my mistake. thank you
||
|
|
||
|
|
||
| def get_reward_norm(config: PPOActorConfig): | ||
||
| if config.reward_norm: | ||
| return Normalization(config.reward_norm) | ||
| else: | ||
| return None | ||
|
|
||
|
|
||
| def get_adv_norm(config: PPOActorConfig): | ||
| if config.adv_norm: | ||
| if config.adv_norm.adv_norm_mode == "mix": | ||
| return MAPOAdvNorm(config.adv_norm) | ||
| else: | ||
| return Normalization(config.adv_norm) | ||
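For reference, below is a minimal, self-contained sketch of the mixing rule implemented in `MAPOAdvNorm` above. It operates directly on trajectory-level binary rewards rather than the repo's token-level advantage tensors, and the function name, signature, and `eps` default are illustrative assumptions rather than part of AReal's API:

```python
import torch


def mapo_mixed_advantage(rewards: torch.Tensor, group_size: int, eps: float = 1e-6) -> torch.Tensor:
    """Sketch of the MAPO mixing rule for trajectory-level binary rewards.

    rewards: shape [batch_size * group_size], one scalar outcome reward per trajectory.
    """
    r = rewards.reshape(-1, group_size)            # [batch_size, group_size]
    mean = r.mean(dim=1, keepdim=True)
    std = r.std(dim=1, keepdim=True)

    dev_norm = (r - mean) / (std + eps)            # deviation-based (GRPO) normalization
    mean_norm = (r - mean) / (mean + eps)          # mean-based (APD) normalization

    # p: empirical success probability per group (the higher reward value marks success)
    p = (r == r.max(dim=1, keepdim=True).values).float().mean(dim=1, keepdim=True)
    lam = 1 - 4 * p * (1 - p)                      # trajectory-certainty reweight lambda(p)

    mixed = (1 - lam) * dev_norm + lam * mean_norm
    return mixed.reshape(-1)                       # back to [batch_size * group_size]


# toy usage: one group of 4 rollouts with binary rewards
print(mapo_mixed_advantage(torch.tensor([1.0, 1.0, 0.0, 1.0]), group_size=4))
```

Here `lam` follows the Trajectory Certainty Reweight described in the MAPO README: groups whose outcome is nearly certain (p close to 0 or 1) lean on the mean-based term, while uncertain groups fall back to the standard deviation-based GRPO advantage.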
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| # Mixed Advantage Policy Optimization (MAPO) | ||
|
|
||
| Last updated: Sep 27, 2025 | ||
|
||
|
|
||
| Author: [Ziyi ZENG](https://github.com/ZiyiTsang) | ||
|
|
||
|  | ||
|
|
||
| The MAPO paper introduces Mixed Advantage Policy Optimization (MAPO), an improved Group Relative Policy Optimization (GRPO) strategy designed to enhance the reasoning performance of foundation models. While GRPO has been effective in post-training foundation models for reasoning tasks, it suffers from "advantage reversion" and "advantage mirror" problems, which lead to an unreasonable allocation of advantage across different query samples. MAPO addresses these limitations by introducing the concept of "trajectory certainty" and proposing an "Advantage Percent Deviation" (APD) for high-certainty trajectories. Furthermore, it dynamically reweights the advantage function based on trajectory certainty through "Trajectory Certainty Reweight" (TCR). This adaptive approach configures the advantage function to account for sample-specific characteristics, mitigating the shortcomings of prior advantage formulations and producing more stable and accurate reasoning performance across diverse tasks. | ||
|
|
||
| The overall surrogate objective is: | ||
|
|
||
|
|
||
| $$\mathcal{J}_{\mathrm{GRPO}}(\theta)=\mathbb{E}_{q\sim\rho_{Q}}\mathbb{E}_{o\sim\pi_{old}(\cdot|q)}\left[\frac{1}{G}\sum_{i}^{G}f_{\epsilon}\left(\frac{\pi_{\theta}(o_{i}|q)}{\pi_{old}(o_{i}|q)},\hat{A}_{i}\right)\right]-\beta\mathbb{D}_{KL}[\pi_{\theta}||\pi_{ref}],$$ | ||
| where: | ||
| $$f_\epsilon(x,y)=\min(xy,\mathrm{clip}(x,1-\epsilon,1+\epsilon)y)$$ | ||
|
|
||
| $$\lambda(p)=1-4p(1-p)\in[0,1]\quad(p\in[0,1])$$ | ||
| where $p$ is the empirical success rate within a group, so $\lambda(p)$ approaches 1 when the group outcome is highly certain ($p$ near 0 or 1). | ||
|
|
||
| $$\hat{A}_i^*=(1-\lambda(p))\cdot\underbrace{\frac{r_i-\mu}{\sigma}}_{\text{Deviation-based}}+\lambda(p)\cdot\underbrace{\frac{r_i-\mu}{\mu}}_{\text{Mean-based}}.$$ | ||
|
|
||
|
|
||
| For more details: | ||
|
|
||
| - AReal Detail: [Paper of AReal](https://arxiv.org/abs/2505.24298) | ||
|
|
||
| - MAPO Detail: [Paper of MAPO](https://arxiv.org/abs/2509.18849v3) | ||
|
|
||
| ## Algorithm Core Parameters | ||
|
|
||
| - `actor.adv_norm.aggregation_mode`: the advantage-normalization implementation. `native` is the z-score normalization used by GRPO, while `mix` enables the MAPO mixed normalization. | ||
||
|
|
||
| ## Notice | ||
| For the MAPO implementation, the following constraints should be met: | ||
|
|
||
| 1. The reward function should return a binary result (any two values): the higher value marks a successful trajectory, the lower value a failed one. | ||
| 2. `overlong_reward_penalty` should be disabled. | ||
|
||
|
|
||
|
|
||
| ## Example Usage | ||
|
|
||
| We recommend changing the parameters in the configuration file | ||
| (i.e. `gsm8k_mapo.yaml`). | ||
|
|
||
| | Backend | CMD | | ||
| | --------- | -------------------------------------------------------------------------------------------------------------------------------- | | ||
| | **local** | `python3 -m areal.launcher.local examples/experimental/mapo/gsm8k_mapo.py --config examples/experimental/mapo/gsm8k_mapo.yaml --<other_args_to_overwrite>` | | ||
| | **ray** | `python3 -m areal.launcher.ray examples/experimental/mapo/gsm8k_mapo.py --config examples/experimental/mapo/gsm8k_mapo.yaml --<other_args_to_overwrite>` | | ||
| | **slurm** | `python3 -m areal.launcher.slurm examples/experimental/mapo/gsm8k_mapo.py --config examples/experimental/mapo/gsm8k_mapo.yaml --<other_args_to_overwrite>` | | ||
|
|
||
| ## Baselines | ||
|
|
||
| We still lack baselines; contributions are welcome! | ||
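As a concrete worked example of the formulas in the README above (numbers chosen purely for illustration, taking $\sigma$ as the population standard deviation): for a group of $G = 4$ rollouts with binary rewards $r = (1, 1, 1, 0)$, the success rate is $p = 3/4$, so

$$\lambda(p) = 1 - 4\cdot\tfrac{3}{4}\cdot\tfrac{1}{4} = \tfrac{1}{4},$$

and the mixed advantage of a successful rollout ($r_i = 1$, $\mu = 0.75$, $\sigma \approx 0.433$) is

$$\hat{A}_i^* = \tfrac{3}{4}\cdot\frac{1 - 0.75}{0.433} + \tfrac{1}{4}\cdot\frac{1 - 0.75}{0.75} \approx 0.75\cdot 0.577 + 0.25\cdot 0.333 \approx 0.52.$$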
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| experiment_name: gsm8k-grpo | ||
| experiment_name: gsm8k-drgrpo | ||
| trial_name: trial0 | ||
|
|
||
| seed: 1 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| experiment_name: gsm8k-grpo | ||
| experiment_name: gsm8k-liteppo | ||
| trial_name: trial0 | ||
|
|
||
| seed: 1 | ||
|
|
||