diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py
index b1484d4f41..cc5d95972f 100644
--- a/tests/algorithm/policy_loss_test.py
+++ b/tests/algorithm/policy_loss_test.py
@@ -115,3 +115,20 @@ def test_mix_policy_loss(self):
         )
         self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
         self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
+
+    def test_sapo_policy_loss(self):
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("sapo")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+        sapo_loss = torch.tensor(-0.05128994956612587)
+        ppo_kl = torch.tensor(-0.21663446724414825)
+        avg_soft_gate = torch.tensor(2.3191137313842773)
+        avg_ratio = torch.tensor(1.630766749382019)
+        pos_adv_frac = torch.tensor(0.3958333432674408)
+        self.assertTrue(torch.allclose(loss, sapo_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["sapo_loss"]), sapo_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["avg_soft_gate"]), avg_soft_gate))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["avg_ratio"]), avg_ratio))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pos_adv_frac"]), pos_adv_frac))
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index 4384da7b8e..397336408d 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -250,6 +250,33 @@ def default_config(cls) -> Dict:
         }
 
 
+@ALGORITHM_TYPE.register_module("sapo")
+class SAPOAlgorithm(AlgorithmType):
+    """SAPO (Soft Adaptive Policy Optimization) algorithm.
+
+    SAPO uses a smooth, temperature-controlled soft gate instead of hard clipping
+    to stabilize training while maintaining effective learning.
+ """ + + use_critic: bool = False + use_reference: bool = True + compute_advantage_in_trainer: bool = False + can_balance_batch: bool = True + schema: str = "experience" + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 2, + "advantage_fn": "grpo", + "sample_strategy": "default", + "policy_loss_fn": "sapo", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "default", + } + + @ALGORITHM_TYPE.register_module("mix") class MIXAlgorithm(AlgorithmType): """MIX algorithm.""" diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py index 4f7b70b917..124a23fff3 100644 --- a/trinity/algorithm/policy_loss_fn/__init__.py +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -11,6 +11,7 @@ from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn from trinity.algorithm.policy_loss_fn.rec_policy_loss import RECPolicyLossFn +from trinity.algorithm.policy_loss_fn.sapo_policy_loss import SAPOPolicyLossFn from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn from trinity.algorithm.policy_loss_fn.sppo_loss_fn import sPPOPolicyLossFn from trinity.algorithm.policy_loss_fn.topr_policy_loss import TOPRPolicyLossFn @@ -31,4 +32,5 @@ "SFTPhiLossFn", "sPPOPolicyLossFn", "RECPolicyLossFn", + "SAPOPolicyLossFn", ] diff --git a/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py b/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py new file mode 100644 index 0000000000..7d5c4a9598 --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py @@ -0,0 +1,159 @@ +"""SAPO policy loss function. +Soft Adaptive Policy Optimization (SAPO) is a reinforcement learning algorithm +that uses a smooth, temperature-controlled soft gate instead of hard clipping. + +Refer to the SAPO paper for details. https://arxiv.org/abs/2511.20347 +""" + +from typing import Dict, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.utils import aggregate_loss, masked_mean + + +@POLICY_LOSS_FN.register_module("sapo") +class SAPOPolicyLossFn(PolicyLossFn): + def __init__( + self, + backend: str = "verl", + tau_pos: float = 1.0, + tau_neg: float = 1.05, + loss_agg_mode: str = "token-mean", + ) -> None: + """Initialize SAPO policy loss function. + + Args: + backend: The training framework/backend to use (e.g., "verl") + tau_pos: Temperature for positive advantages (τ_pos), default 1.0 + tau_neg: Temperature for negative advantages (τ_neg), default 1.05, should be >= tau_pos + loss_agg_mode: Mode for aggregating loss across tokens + """ + super().__init__(backend=backend) + self.tau_pos = tau_pos + self.tau_neg = tau_neg + self.loss_agg_mode = loss_agg_mode + + # Validate that tau_neg > tau_pos for stability + assert self.tau_neg >= self.tau_pos, ( + f"tau_neg ({self.tau_neg}) should be >= tau_pos ({self.tau_pos}) " + "for better training stability" + ) + + def soft_gate_function(self, ratio: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor: + """Compute the soft gate function f_{i,t}(x). 
+
+        The soft gate function is defined as:
+            f_{i,t}(x) = σ(τ_{i,t} * (x - 1)) * 4 / τ_{i,t}
+
+        where:
+            - σ is the sigmoid function
+            - τ_{i,t} is the asymmetric temperature (tau_pos or tau_neg)
+            - x is the importance sampling ratio r_{i,t}(θ)
+
+        Args:
+            ratio: Token-level importance sampling ratio r_{i,t}(θ)
+            advantages: Normalized advantage Â_i (same for all tokens in a sequence)
+
+        Returns:
+            The soft gate values for each token
+        """
+        # Select the temperature based on the advantage sign:
+        # τ_{i,t} = tau_pos if Â_i > 0, else tau_neg
+        tau = torch.where(
+            advantages > 0,
+            torch.tensor(self.tau_pos, device=ratio.device, dtype=ratio.dtype),
+            torch.tensor(self.tau_neg, device=ratio.device, dtype=ratio.dtype),
+        )
+
+        # Compute sigmoid(tau * (ratio - 1))
+        sigmoid_input = tau * (ratio - 1)
+        sigmoid_output = torch.sigmoid(sigmoid_input)
+
+        # Compute the soft gate: sigma(tau * (x - 1)) * 4 / tau
+        soft_gate = sigmoid_output * (4.0 / tau)
+
+        return soft_gate
+
+    def __call__(  # type: ignore
+        self,
+        logprob: torch.Tensor,
+        old_logprob: torch.Tensor,
+        action_mask: torch.Tensor,
+        advantages: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict]:
+        """Compute the SAPO policy loss.
+
+        The SAPO objective function is:
+            J(θ) = E[1/G Σ_i 1/|y_i| Σ_t f_{i,t}(r_{i,t}(θ)) * Â_i]
+
+        We minimize the negative of this objective.
+
+        Args:
+            logprob: Log probabilities from the current policy π_θ
+            old_logprob: Log probabilities from the old policy π_{θ_old}
+            action_mask: Mask indicating valid tokens
+            advantages: Group-normalized advantages
+
+        Returns:
+            loss: The computed policy loss (negative of the objective)
+            metrics: Dictionary of metrics for logging
+        """
+        # Compute the token-level importance sampling ratio
+        # r_{i,t}(θ) = π_θ(y_{i,t} | q, y_{i,<t}) / π_{θ_old}(y_{i,t} | q, y_{i,<t})
+        negative_approx_kl = logprob - old_logprob
+        ratio = torch.exp(negative_approx_kl)
+        ppo_kl = masked_mean(-negative_approx_kl, action_mask)
+
+        # Apply the soft gate and aggregate the per-token loss
+        soft_gate = self.soft_gate_function(ratio, advantages)
+        pg_losses = -soft_gate * advantages
+        loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
+
+        # Metrics for logging
+        avg_soft_gate = masked_mean(soft_gate, action_mask)
+        avg_ratio = masked_mean(ratio, action_mask)
+        pos_adv_frac = masked_mean((advantages > 0).float(), action_mask)
+
+        metrics = {
+            "sapo_loss": loss.detach().item(),
+            "ppo_kl": ppo_kl.detach().item(),
+            "avg_soft_gate": avg_soft_gate.detach().item(),
+            "avg_ratio": avg_ratio.detach().item(),
+            "pos_adv_frac": pos_adv_frac.detach().item(),
+        }
+
+        return loss, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        """Get the default initialization arguments for SAPO.
+
+        Default configuration (from the SAPO paper):
+            - tau_pos: 1.0 (temperature for positive advantages)
+            - tau_neg: 1.05 (temperature for negative advantages)
+            - loss_agg_mode: "token-mean" (average over tokens)
+
+        The asymmetric temperatures (tau_neg > tau_pos) help stabilize training
+        by more aggressively suppressing updates from tokens with negative advantages.
+
+        Returns:
+            Dictionary of default arguments
+        """
+        return {
+            "tau_pos": 1.0,
+            "tau_neg": 1.05,
+            "loss_agg_mode": "token-mean",
+        }
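
Collecting the definitions spread across the docstrings in sapo_policy_loss.py, the objective implemented by the new loss (the trainer minimizes its negative) is

$$
J_{\mathrm{SAPO}}(\theta)
= \mathbb{E}\!\left[\frac{1}{G}\sum_{i=1}^{G}\frac{1}{|y_i|}\sum_{t=1}^{|y_i|}
  f_{i,t}\big(r_{i,t}(\theta)\big)\,\hat{A}_i\right],
\qquad
f_{i,t}(x) = \sigma\big(\tau_{i,t}(x-1)\big)\cdot\frac{4}{\tau_{i,t}},
$$

$$
\tau_{i,t} =
\begin{cases}
\tau_{\mathrm{pos}}, & \hat{A}_i > 0,\\
\tau_{\mathrm{neg}}, & \hat{A}_i \le 0,
\end{cases}
\qquad
r_{i,t}(\theta) = \frac{\pi_\theta(y_{i,t}\mid q,\, y_{i,<t})}{\pi_{\theta_{\mathrm{old}}}(y_{i,t}\mid q,\, y_{i,<t})},
$$

where G is the group size, Â_i the group-normalized advantage shared by all tokens of response y_i, and σ the sigmoid.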
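
A minimal standalone sketch (plain PyTorch only, not the Trinity API; the helper name soft_gate is illustrative) that numerically checks two properties implied by the gate formula: the gate has unit slope at r = 1, so near-on-policy tokens receive essentially the same gradient scale as the raw unclipped ratio, while at large ratios the tau_neg = 1.05 branch yields a smaller, flatter gate than the tau_pos = 1.0 branch, which is what suppresses off-policy tokens with negative advantages more aggressively.

import torch

def soft_gate(ratio: torch.Tensor, tau: float) -> torch.Tensor:
    # f(x) = sigmoid(tau * (x - 1)) * 4 / tau, mirroring SAPOPolicyLossFn.soft_gate_function
    return torch.sigmoid(tau * (ratio - 1.0)) * (4.0 / tau)

# Unit slope at r = 1: the gate locally behaves like the unclipped ratio term.
r = torch.tensor(1.0, requires_grad=True)
soft_gate(r, tau=1.0).backward()
print(r.grad)  # tensor(1.)

# Asymmetric temperatures: for the same large ratio, the tau_neg branch
# produces a smaller gate value (and a flatter slope), damping the update.
big_r = torch.tensor(3.0)
print(soft_gate(big_r, tau=1.0))   # tau_pos branch, ~3.52
print(soft_gate(big_r, tau=1.05))  # tau_neg branch, ~3.39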