17 changes: 17 additions & 0 deletions tests/algorithm/policy_loss_test.py
@@ -115,3 +115,20 @@ def test_mix_policy_loss(self):
        )
        self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))

    def test_sapo_policy_loss(self):
        policy_loss_fn_cls = POLICY_LOSS_FN.get("sapo")
        policy_loss_fn_args = policy_loss_fn_cls.default_args()
        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
        loss, metrics = policy_loss_fn(logprob=self.logprob, **self.input_data.batch)
        sapo_loss = torch.tensor(-0.05128994956612587)
        ppo_kl = torch.tensor(-0.21663446724414825)
        avg_soft_gate = torch.tensor(2.3191137313842773)
        avg_ratio = torch.tensor(1.630766749382019)
        pos_adv_frac = torch.tensor(0.3958333432674408)
        self.assertTrue(torch.allclose(loss, sapo_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["sapo_loss"]), sapo_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
        self.assertTrue(torch.allclose(torch.tensor(metrics["avg_soft_gate"]), avg_soft_gate))
        self.assertTrue(torch.allclose(torch.tensor(metrics["avg_ratio"]), avg_ratio))
        self.assertTrue(torch.allclose(torch.tensor(metrics["pos_adv_frac"]), pos_adv_frac))
27 changes: 27 additions & 0 deletions trinity/algorithm/algorithm.py
@@ -250,6 +250,33 @@ def default_config(cls) -> Dict:
        }


@ALGORITHM_TYPE.register_module("sapo")
class SAPOAlgorithm(AlgorithmType):
    """SAPO (Soft Adaptive Policy Optimization) algorithm.

    SAPO uses a smooth, temperature-controlled soft gate instead of hard clipping
    to stabilize training while maintaining effective learning.
    """

    use_critic: bool = False
    use_reference: bool = True
    compute_advantage_in_trainer: bool = False
    can_balance_batch: bool = True
    schema: str = "experience"

    @classmethod
    def default_config(cls) -> Dict:
        return {
            "repeat_times": 2,
            "advantage_fn": "grpo",
            "sample_strategy": "default",
            "policy_loss_fn": "sapo",
            "kl_penalty_fn": "none",
            "kl_loss_fn": "k2",
            "entropy_loss_fn": "default",
        }


@ALGORITHM_TYPE.register_module("mix")
class MIXAlgorithm(AlgorithmType):
"""MIX algorithm."""
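For reference, a minimal sketch of how the new SAPOAlgorithm registry entry can be exercised; this assumes ALGORITHM_TYPE.get follows the same lookup convention that POLICY_LOSS_FN.get uses in the test above:

from trinity.algorithm.algorithm import ALGORITHM_TYPE

algo_cls = ALGORITHM_TYPE.get("sapo")
config = algo_cls.default_config()
assert config["policy_loss_fn"] == "sapo"  # SAPO pairs its own loss with GRPO advantages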
2 changes: 2 additions & 0 deletions trinity/algorithm/policy_loss_fn/__init__.py
@@ -11,6 +11,7 @@
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
from trinity.algorithm.policy_loss_fn.rec_policy_loss import RECPolicyLossFn
from trinity.algorithm.policy_loss_fn.sapo_policy_loss import SAPOPolicyLossFn
from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
from trinity.algorithm.policy_loss_fn.sppo_loss_fn import sPPOPolicyLossFn
from trinity.algorithm.policy_loss_fn.topr_policy_loss import TOPRPolicyLossFn
@@ -31,4 +32,5 @@
"SFTPhiLossFn",
"sPPOPolicyLossFn",
"RECPolicyLossFn",
"SAPOPolicyLossFn",
]
159 changes: 159 additions & 0 deletions trinity/algorithm/policy_loss_fn/sapo_policy_loss.py
@@ -0,0 +1,159 @@
"""SAPO policy loss function.
Soft Adaptive Policy Optimization (SAPO) is a reinforcement learning algorithm
that uses a smooth, temperature-controlled soft gate instead of hard clipping.

Refer to the SAPO paper for details. https://arxiv.org/abs/2511.20347
"""

from typing import Dict, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import aggregate_loss, masked_mean


@POLICY_LOSS_FN.register_module("sapo")
class SAPOPolicyLossFn(PolicyLossFn):
    def __init__(
        self,
        backend: str = "verl",
        tau_pos: float = 1.0,
        tau_neg: float = 1.05,
        loss_agg_mode: str = "token-mean",
    ) -> None:
        """Initialize the SAPO policy loss function.

        Args:
            backend: The training framework/backend to use (e.g., "verl").
            tau_pos: Temperature for positive advantages (τ_pos); default 1.0.
            tau_neg: Temperature for negative advantages (τ_neg); default 1.05. Should be >= tau_pos.
            loss_agg_mode: Mode for aggregating the loss across tokens.
        """
        super().__init__(backend=backend)
        self.tau_pos = tau_pos
        self.tau_neg = tau_neg
        self.loss_agg_mode = loss_agg_mode

        # Validate that tau_neg >= tau_pos for stability
        assert self.tau_neg >= self.tau_pos, (
            f"tau_neg ({self.tau_neg}) should be >= tau_pos ({self.tau_pos}) "
            "for better training stability"
        )

    def soft_gate_function(self, ratio: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
        """Compute the soft gate function f_{i,t}(x).

        The soft gate function is defined as:
            f_{i,t}(x) = σ(τ_{i,t} * (x - 1)) * 4 / τ_{i,t}

        where:
        - σ is the sigmoid function
        - τ_{i,t} is the asymmetric temperature (tau_pos or tau_neg)
        - x is the importance sampling ratio r_{i,t}(θ)

        Args:
            ratio: Token-level importance sampling ratio r_{i,t}(θ).
            advantages: Normalized advantage function Â_i (same for all tokens in a sequence).

        Returns:
            The soft gate values for each token.
        """
        # Select the temperature based on the advantage sign:
        # tau_{i,t} = tau_pos if Â_i > 0, else tau_neg
        tau = torch.where(
            advantages > 0,
            torch.tensor(self.tau_pos, device=ratio.device, dtype=ratio.dtype),
            torch.tensor(self.tau_neg, device=ratio.device, dtype=ratio.dtype),
        )

        # Compute sigmoid(tau * (ratio - 1))
        sigmoid_input = tau * (ratio - 1)
        sigmoid_output = torch.sigmoid(sigmoid_input)

        # Compute the soft gate: sigma(tau * (x - 1)) * 4 / tau
        soft_gate = sigmoid_output * (4.0 / tau)

        return soft_gate
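    # Illustrative gate values, computed directly from the formula above:
    # with tau = 1.0: f(0.5) ≈ 1.51, f(1.0) = 2.00, f(2.0) ≈ 2.92;
    # with tau = 1.05 the curve is slightly steeper and saturates at 4/1.05 ≈ 3.81.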

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        old_logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute the SAPO policy loss.

        The SAPO objective function is:
            J(θ) = E[(1/G) Σ_i (1/|y_i|) Σ_t f_{i,t}(r_{i,t}(θ)) * Â_i]

        We minimize the negative of this objective.

        Args:
            logprob: Log probabilities from the current policy π_θ.
            old_logprob: Log probabilities from the old policy π_{θ_old}.
            action_mask: Mask indicating valid tokens.
            advantages: Group-normalized advantage function.

        Returns:
            loss: The computed policy loss (negative of the objective).
            metrics: Dictionary of metrics for logging.
        """
        # Compute the token-level importance sampling ratio
        # r_{i,t}(θ) = π_θ(y_{i,t}|q, y_{i,<t}) / π_{θ_old}(y_{i,t}|q, y_{i,<t})
        negative_approx_kl = logprob - old_logprob
        ratio = torch.exp(negative_approx_kl)

        # Compute the approximate KL divergence for monitoring
        ppo_kl = masked_mean(-negative_approx_kl, action_mask)

        # Compute the soft gate function
        soft_gate = self.soft_gate_function(ratio, advantages)

        # SAPO loss: -E[f_{i,t}(r_{i,t}) * Â_i]
        # The soft gate is detached, so it acts as a fixed per-token weight and
        # gradients flow only through log π_θ: ∇loss = -Â_i * f_{i,t}(r_{i,t}) * ∇log π_θ.
        sapo_loss = -advantages * soft_gate.detach() * logprob

        # Aggregate the loss across tokens
        loss = aggregate_loss(sapo_loss, action_mask, loss_agg_mode=self.loss_agg_mode)

        # Compute metrics for logging
        avg_soft_gate = masked_mean(soft_gate, action_mask)
        avg_ratio = masked_mean(ratio, action_mask)

        # Fraction of tokens with a positive advantage
        pos_adv_frac = masked_mean((advantages > 0).float(), action_mask)

        metrics = {
            "sapo_loss": loss.detach().item(),
            "ppo_kl": ppo_kl.detach().item(),
            "avg_soft_gate": avg_soft_gate.detach().item(),
            "avg_ratio": avg_ratio.detach().item(),
            "pos_adv_frac": pos_adv_frac.detach().item(),
        }

        return loss, metrics

    @classmethod
    def default_args(cls) -> Dict:
        """Get the default initialization arguments for SAPO.

        Default configuration (from the SAPO paper):
        - tau_pos: 1.0 (temperature for positive advantages)
        - tau_neg: 1.05 (temperature for negative advantages)
        - loss_agg_mode: "token-mean" (average over tokens)

        The asymmetric temperatures (tau_neg > tau_pos) help stabilize training
        by more aggressively suppressing updates from tokens with negative advantages.

        Returns:
            Dictionary of default arguments.
        """
        return {
            "tau_pos": 1.0,
            "tau_neg": 1.05,
            "loss_agg_mode": "token-mean",
        }
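
Finally, a hedged end-to-end sketch of the new loss on synthetic tensors, mirroring the unit test above; the shapes and random values are illustrative assumptions, not the real test fixture:

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN

fn_cls = POLICY_LOSS_FN.get("sapo")
loss_fn = fn_cls(**fn_cls.default_args())  # tau_pos=1.0, tau_neg=1.05, token-mean

batch_size, seq_len = 2, 4
old_logprob = torch.randn(batch_size, seq_len)
logprob = old_logprob + 0.1 * torch.randn(batch_size, seq_len)  # a small policy update
action_mask = torch.ones(batch_size, seq_len)
advantages = torch.randn(batch_size, 1).expand(batch_size, seq_len)  # one advantage per sequence

loss, metrics = loss_fn(
    logprob=logprob,
    old_logprob=old_logprob,
    action_mask=action_mask,
    advantages=advantages,
)
print(metrics)  # sapo_loss, ppo_kl, avg_soft_gate, avg_ratio, pos_adv_frac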