
Commit b2fd301

Add SAPO algorithm (#422)
Co-authored-by: 问昊 <[email protected]>
1 parent b9ff286 commit b2fd301

File tree

4 files changed: +205 -0 lines changed

    tests/algorithm/policy_loss_test.py
    trinity/algorithm/algorithm.py
    trinity/algorithm/policy_loss_fn/__init__.py
    trinity/algorithm/policy_loss_fn/sapo_policy_loss.py

tests/algorithm/policy_loss_test.py

Lines changed: 17 additions & 0 deletions
@@ -115,3 +115,20 @@ def test_mix_policy_loss(self):
         )
         self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
         self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
+
+    def test_sapo_policy_loss(self):
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("sapo")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+        sapo_loss = torch.tensor(-0.05128994956612587)
+        ppo_kl = torch.tensor(-0.21663446724414825)
+        avg_soft_gate = torch.tensor(2.3191137313842773)
+        avg_ratio = torch.tensor(1.630766749382019)
+        pos_adv_frac = torch.tensor(0.3958333432674408)
+        self.assertTrue(torch.allclose(loss, sapo_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["sapo_loss"]), sapo_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["avg_soft_gate"]), avg_soft_gate))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["avg_ratio"]), avg_ratio))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pos_adv_frac"]), pos_adv_frac))

trinity/algorithm/algorithm.py

Lines changed: 27 additions & 0 deletions
@@ -250,6 +250,33 @@ def default_config(cls) -> Dict:
         }


+@ALGORITHM_TYPE.register_module("sapo")
+class SAPOAlgorithm(AlgorithmType):
+    """SAPO (Soft Adaptive Policy Optimization) algorithm.
+
+    SAPO uses a smooth, temperature-controlled soft gate instead of hard clipping
+    to stabilize training while maintaining effective learning.
+    """
+
+    use_critic: bool = False
+    use_reference: bool = True
+    compute_advantage_in_trainer: bool = False
+    can_balance_batch: bool = True
+    schema: str = "experience"
+
+    @classmethod
+    def default_config(cls) -> Dict:
+        return {
+            "repeat_times": 2,
+            "advantage_fn": "grpo",
+            "sample_strategy": "default",
+            "policy_loss_fn": "sapo",
+            "kl_penalty_fn": "none",
+            "kl_loss_fn": "k2",
+            "entropy_loss_fn": "default",
+        }
+
+
 @ALGORITHM_TYPE.register_module("mix")
 class MIXAlgorithm(AlgorithmType):
     """MIX algorithm."""

trinity/algorithm/policy_loss_fn/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
 from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
 from trinity.algorithm.policy_loss_fn.rec_policy_loss import RECPolicyLossFn
+from trinity.algorithm.policy_loss_fn.sapo_policy_loss import SAPOPolicyLossFn
 from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
 from trinity.algorithm.policy_loss_fn.sppo_loss_fn import sPPOPolicyLossFn
 from trinity.algorithm.policy_loss_fn.topr_policy_loss import TOPRPolicyLossFn
@@ -31,4 +32,5 @@
     "SFTPhiLossFn",
     "sPPOPolicyLossFn",
     "RECPolicyLossFn",
+    "SAPOPolicyLossFn",
 ]
trinity/algorithm/policy_loss_fn/sapo_policy_loss.py

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
"""SAPO policy loss function.

Soft Adaptive Policy Optimization (SAPO) is a reinforcement learning algorithm
that uses a smooth, temperature-controlled soft gate instead of hard clipping.

Refer to the SAPO paper for details: https://arxiv.org/abs/2511.20347
"""

from typing import Dict, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import aggregate_loss, masked_mean


@POLICY_LOSS_FN.register_module("sapo")
class SAPOPolicyLossFn(PolicyLossFn):
    def __init__(
        self,
        backend: str = "verl",
        tau_pos: float = 1.0,
        tau_neg: float = 1.05,
        loss_agg_mode: str = "token-mean",
    ) -> None:
        """Initialize SAPO policy loss function.

        Args:
            backend: The training framework/backend to use (e.g., "verl")
            tau_pos: Temperature for positive advantages (τ_pos), default 1.0
            tau_neg: Temperature for negative advantages (τ_neg), default 1.05, should be >= tau_pos
            loss_agg_mode: Mode for aggregating loss across tokens
        """
        super().__init__(backend=backend)
        self.tau_pos = tau_pos
        self.tau_neg = tau_neg
        self.loss_agg_mode = loss_agg_mode

        # Validate that tau_neg >= tau_pos for stability
        assert self.tau_neg >= self.tau_pos, (
            f"tau_neg ({self.tau_neg}) should be >= tau_pos ({self.tau_pos}) "
            "for better training stability"
        )

    def soft_gate_function(self, ratio: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
        """Compute the soft gate function f_{i,t}(x).

        The soft gate function is defined as:
            f_{i,t}(x) = σ(τ_{i,t} * (x - 1)) * 4 / τ_{i,t}

        where:
            - σ is the sigmoid function
            - τ_{i,t} is the asymmetric temperature (tau_pos or tau_neg)
            - x is the importance sampling ratio r_{i,t}(θ)

        Args:
            ratio: Token-level importance sampling ratio r_{i,t}(θ)
            advantages: Normalized advantage function Â_i (same for all tokens in a sequence)

        Returns:
            The soft gate values for each token
        """
        # Select temperature based on advantage sign:
        # tau_{i,t} = tau_pos if A_i > 0, else tau_neg
        tau = torch.where(
            advantages > 0,
            torch.tensor(self.tau_pos, device=ratio.device, dtype=ratio.dtype),
            torch.tensor(self.tau_neg, device=ratio.device, dtype=ratio.dtype),
        )

        # Compute sigmoid(tau * (ratio - 1))
        sigmoid_input = tau * (ratio - 1)
        sigmoid_output = torch.sigmoid(sigmoid_input)

        # Compute the soft gate: sigma(tau * (x - 1)) * 4 / tau
        soft_gate = sigmoid_output * (4.0 / tau)

        return soft_gate

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        old_logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute SAPO policy loss.

        The SAPO objective function is:
            J(θ) = E[1/G Σ 1/|y_i| Σ f_{i,t}(r_{i,t}(θ)) * Â_i]

        We minimize the negative of this objective.

        Args:
            logprob: Log probabilities from current policy π_θ
            old_logprob: Log probabilities from old policy π_{θ_old}
            action_mask: Mask indicating valid tokens
            advantages: Group-normalized advantage function

        Returns:
            loss: The computed policy loss (negative of objective)
            metrics: Dictionary of metrics for logging
        """
        # Compute token-level importance sampling ratio
        # r_{i,t}(θ) = π_θ(y_{i,t}|q, y_{i,<t}) / π_{θ_old}(y_{i,t}|q, y_{i,<t})
        negative_approx_kl = logprob - old_logprob
        ratio = torch.exp(negative_approx_kl)

        # Compute approximate KL divergence for monitoring
        ppo_kl = masked_mean(-negative_approx_kl, action_mask)

        # Compute soft gate function
        soft_gate = self.soft_gate_function(ratio, advantages)

        # SAPO loss: -E[f_{i,t}(r_{i,t}) * Â_i]
        # We multiply by logprob to get the policy gradient
        # The gradient of log π_θ gives us the policy gradient direction
        sapo_loss = -advantages * soft_gate.detach() * logprob

        # Aggregate loss across tokens
        loss = aggregate_loss(sapo_loss, action_mask, loss_agg_mode=self.loss_agg_mode)

        # Compute metrics for logging
        avg_soft_gate = masked_mean(soft_gate, action_mask)
        avg_ratio = masked_mean(ratio, action_mask)

        # Compute fraction of tokens with positive advantages
        pos_adv_frac = masked_mean((advantages > 0).float(), action_mask)

        metrics = {
            "sapo_loss": loss.detach().item(),
            "ppo_kl": ppo_kl.detach().item(),
            "avg_soft_gate": avg_soft_gate.detach().item(),
            "avg_ratio": avg_ratio.detach().item(),
            "pos_adv_frac": pos_adv_frac.detach().item(),
        }

        return loss, metrics

    @classmethod
    def default_args(cls) -> Dict:
        """Get default initialization arguments for SAPO.

        Default configuration (from the SAPO paper):
            - tau_pos: 1.0 (temperature for positive advantages)
            - tau_neg: 1.05 (temperature for negative advantages)
            - loss_agg_mode: "token-mean" (average over tokens)

        The asymmetric temperatures (tau_neg > tau_pos) help stabilize training
        by more aggressively suppressing updates from tokens with negative advantages.

        Returns:
            Dictionary of default arguments
        """
        return {
            "tau_pos": 1.0,
            "tau_neg": 1.05,
            "loss_agg_mode": "token-mean",
        }
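To make the soft gate above concrete, here is a small standalone sketch (illustrative, not part of the commit) that evaluates f(x) = sigmoid(tau * (x - 1)) * 4 / tau at a few importance ratios with the default temperatures. The gate rises smoothly and monotonically with the ratio, equals 2 / tau at a ratio of 1, and saturates at 4 / tau for large ratios, so extreme importance ratios are damped smoothly rather than hard-clipped as in PPO:

# Standalone illustration of the SAPO soft gate f(x) = sigmoid(tau * (x - 1)) * 4 / tau.
# Not part of the commit; mirrors SAPOPolicyLossFn.soft_gate_function with a fixed scalar tau.
import torch


def soft_gate(ratio: torch.Tensor, tau: float) -> torch.Tensor:
    return torch.sigmoid(tau * (ratio - 1)) * (4.0 / tau)


ratios = torch.tensor([0.5, 0.9, 1.0, 1.1, 1.5, 3.0])
print(soft_gate(ratios, tau=1.0))   # tau_pos: used for tokens with positive advantages
print(soft_gate(ratios, tau=1.05))  # tau_neg: slightly sharper damping for negative advantages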
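And a toy end-to-end call (again a sketch, not part of the commit), constructing the loss directly with the keyword names from the __call__ signature above. The unit test instead drives it through the framework's batch wrapper (log_prob= plus a prepared batch), so real usage may differ:

# Sketch only: direct call with toy tensors matching the __call__ signature
# (logprob, old_logprob, action_mask, advantages).
import torch
from trinity.algorithm.policy_loss_fn import SAPOPolicyLossFn

fn = SAPOPolicyLossFn(**SAPOPolicyLossFn.default_args())  # tau_pos=1.0, tau_neg=1.05, token-mean

batch_size, seq_len = 2, 4
logprob = 0.1 * torch.randn(batch_size, seq_len)
old_logprob = logprob - 0.05 * torch.randn(batch_size, seq_len)       # keep ratios near 1
action_mask = torch.ones(batch_size, seq_len)
advantages = torch.randn(batch_size, 1).expand(batch_size, seq_len)   # one advantage per sequence

loss, metrics = fn(
    logprob=logprob,
    old_logprob=old_logprob,
    action_mask=action_mask,
    advantages=advantages,
)
print(loss.item(), metrics["avg_soft_gate"], metrics["avg_ratio"], metrics["pos_adv_frac"])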
