|
| 1 | +"""SAPO policy loss function. |
| 2 | +Soft Adaptive Policy Optimization (SAPO) is a reinforcement learning algorithm |
| 3 | +that uses a smooth, temperature-controlled soft gate instead of hard clipping. |
| 4 | +
|
| 5 | +Refer to the SAPO paper for details. https://arxiv.org/abs/2511.20347 |
| 6 | +""" |
| 7 | + |
| 8 | +from typing import Dict, Tuple |
| 9 | + |
| 10 | +import torch |
| 11 | + |
| 12 | +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn |
| 13 | +from trinity.algorithm.utils import aggregate_loss, masked_mean |
| 14 | + |
| 15 | + |
| 16 | +@POLICY_LOSS_FN.register_module("sapo") |
| 17 | +class SAPOPolicyLossFn(PolicyLossFn): |
| 18 | + def __init__( |
| 19 | + self, |
| 20 | + backend: str = "verl", |
| 21 | + tau_pos: float = 1.0, |
| 22 | + tau_neg: float = 1.05, |
| 23 | + loss_agg_mode: str = "token-mean", |
| 24 | + ) -> None: |
| 25 | + """Initialize SAPO policy loss function. |
| 26 | +
|
| 27 | + Args: |
| 28 | + backend: The training framework/backend to use (e.g., "verl") |
| 29 | + tau_pos: Temperature for positive advantages (τ_pos), default 1.0 |
| 30 | + tau_neg: Temperature for negative advantages (τ_neg), default 1.05, should be >= tau_pos |
| 31 | + loss_agg_mode: Mode for aggregating loss across tokens |
| 32 | + """ |
| 33 | + super().__init__(backend=backend) |
| 34 | + self.tau_pos = tau_pos |
| 35 | + self.tau_neg = tau_neg |
| 36 | + self.loss_agg_mode = loss_agg_mode |
| 37 | + |
| 38 | + # Validate that tau_neg > tau_pos for stability |
| 39 | + assert self.tau_neg >= self.tau_pos, ( |
| 40 | + f"tau_neg ({self.tau_neg}) should be >= tau_pos ({self.tau_pos}) " |
| 41 | + "for better training stability" |
| 42 | + ) |
| 43 | + |
| 44 | + def soft_gate_function(self, ratio: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor: |
| 45 | + """Compute the soft gate function f_{i,t}(x). |
| 46 | +
|
| 47 | + The soft gate function is defined as: |
| 48 | + f_{i,t}(x) = σ(τ_{i,t} * (x - 1)) * 4 / τ_{i,t} |
| 49 | +
|
| 50 | + where: |
| 51 | + - σ is the sigmoid function |
| 52 | + - τ_{i,t} is the asymmetric temperature (tau_pos or tau_neg) |
| 53 | + - x is the importance sampling ratio r_{i,t}(θ) |
| 54 | +
|
| 55 | + Args: |
| 56 | + ratio: Token-level importance sampling ratio r_{i,t}(θ) |
| 57 | + advantages: Normalized advantage function Â_i (same for all tokens in a sequence) |
| 58 | +
|
| 59 | + Returns: |
| 60 | + The soft gate values for each token |
| 61 | + """ |
| 62 | + # Select temperature based on advantage sign |
| 63 | + # tau_i,t = tau_pos if A_i > 0, else tau_neg |
| 64 | + tau = torch.where( |
| 65 | + advantages > 0, |
| 66 | + torch.tensor(self.tau_pos, device=ratio.device, dtype=ratio.dtype), |
| 67 | + torch.tensor(self.tau_neg, device=ratio.device, dtype=ratio.dtype), |
| 68 | + ) |
| 69 | + |
| 70 | + # Compute sigmoid(tau * (ratio - 1)) |
| 71 | + sigmoid_input = tau * (ratio - 1) |
| 72 | + sigmoid_output = torch.sigmoid(sigmoid_input) |
| 73 | + |
| 74 | + # Compute the soft gate: sigma(tau * (x - 1)) * 4 / tau |
| 75 | + soft_gate = sigmoid_output * (4.0 / tau) |
| 76 | + |
| 77 | + return soft_gate |
| 78 | + |
| 79 | + def __call__( # type: ignore |
| 80 | + self, |
| 81 | + logprob: torch.Tensor, |
| 82 | + old_logprob: torch.Tensor, |
| 83 | + action_mask: torch.Tensor, |
| 84 | + advantages: torch.Tensor, |
| 85 | + **kwargs, |
| 86 | + ) -> Tuple[torch.Tensor, Dict]: |
| 87 | + """Compute SAPO policy loss. |
| 88 | +
|
| 89 | + The SAPO objective function is: |
| 90 | + J(θ) = E[1/G Σ 1/|y_i| Σ f_{i,t}(r_{i,t}(θ)) * Â_i] |
| 91 | +
|
| 92 | + We minimize the negative of this objective. |
| 93 | +
|
| 94 | + Args: |
| 95 | + logprob: Log probabilities from current policy π_θ |
| 96 | + old_logprob: Log probabilities from old policy π_{θ_old} |
| 97 | + action_mask: Mask indicating valid tokens |
| 98 | + advantages: Group-normalized advantage function |
| 99 | +
|
| 100 | + Returns: |
| 101 | + loss: The computed policy loss (negative of objective) |
| 102 | + metrics: Dictionary of metrics for logging |
| 103 | + """ |
| 104 | + # Compute token-level importance sampling ratio |
| 105 | + # r_{i,t}(θ) = π_θ(y_{i,t}|q, y_{i,<t}) / π_{θ_old}(y_{i,t}|q, y_{i,<t}) |
| 106 | + negative_approx_kl = logprob - old_logprob |
| 107 | + ratio = torch.exp(negative_approx_kl) |
| 108 | + |
| 109 | + # Compute approximate KL divergence for monitoring |
| 110 | + ppo_kl = masked_mean(-negative_approx_kl, action_mask) |
| 111 | + |
| 112 | + # Compute soft gate function |
| 113 | + soft_gate = self.soft_gate_function(ratio, advantages) |
| 114 | + |
| 115 | + # SAPO loss: -E[f_{i,t}(r_{i,t}) * Â_i] |
| 116 | + # We multiply by logprob to get the policy gradient |
| 117 | + # The gradient of log π_θ gives us the policy gradient direction |
| 118 | + sapo_loss = -advantages * soft_gate.detach() * logprob |
| 119 | + |
| 120 | + # Aggregate loss across tokens |
| 121 | + loss = aggregate_loss(sapo_loss, action_mask, loss_agg_mode=self.loss_agg_mode) |
| 122 | + |
| 123 | + # Compute metrics for logging |
| 124 | + avg_soft_gate = masked_mean(soft_gate, action_mask) |
| 125 | + avg_ratio = masked_mean(ratio, action_mask) |
| 126 | + |
| 127 | + # Compute fraction of tokens with positive/negative advantages |
| 128 | + pos_adv_frac = masked_mean((advantages > 0).float(), action_mask) |
| 129 | + |
| 130 | + metrics = { |
| 131 | + "sapo_loss": loss.detach().item(), |
| 132 | + "ppo_kl": ppo_kl.detach().item(), |
| 133 | + "avg_soft_gate": avg_soft_gate.detach().item(), |
| 134 | + "avg_ratio": avg_ratio.detach().item(), |
| 135 | + "pos_adv_frac": pos_adv_frac.detach().item(), |
| 136 | + } |
| 137 | + |
| 138 | + return loss, metrics |
| 139 | + |
| 140 | + @classmethod |
| 141 | + def default_args(cls) -> Dict: |
| 142 | + """Get default initialization arguments for SAPO. |
| 143 | +
|
| 144 | + Default configuration (from the SAPO paper): |
| 145 | + - tau_pos: 1.0 (temperature for positive advantages) |
| 146 | + - tau_neg: 1.05 (temperature for negative advantages) |
| 147 | + - loss_agg_mode: "token-mean" (average over tokens) |
| 148 | +
|
| 149 | + The asymmetric temperatures (tau_neg > tau_pos) help stabilize training |
| 150 | + by more aggressively suppressing updates from tokens with negative advantages. |
| 151 | +
|
| 152 | + Returns: |
| 153 | + Dictionary of default arguments |
| 154 | + """ |
| 155 | + return { |
| 156 | + "tau_pos": 1.0, |
| 157 | + "tau_neg": 1.05, |
| 158 | + "loss_agg_mode": "token-mean", |
| 159 | + } |
0 commit comments