Skip to content

Commit a6d0b6f

Browse files
committed
truncate based on the sign of advantage after ratio clip
1 parent 06023d8 commit a6d0b6f

File tree

4 files changed

+73
-139
lines changed

4 files changed

+73
-139
lines changed

tests/algorithm/policy_loss_test.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,6 @@ def setUp(self):
3030
}
3131
)
3232

33-
def test_dcppo_policy_loss(self):
34-
policy_loss_fn_cls = POLICY_LOSS_FN.get("dcppo")
35-
policy_loss_fn_args = policy_loss_fn_cls.default_args()
36-
policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
37-
loss, metrics = policy_loss_fn(logprob=self.logprob, **self.input_data.batch)
38-
dcppo_loss = torch.tensor(0.26889559626579285)
39-
pg_clipfrac = torch.tensor(0.3541666567325592)
40-
pg_clipfrac_lower = torch.tensor(0.0625)
41-
ppo_kl = torch.tensor(-0.21663446724414825)
42-
self.assertTrue(torch.allclose(loss, dcppo_loss))
43-
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
44-
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac_lower"]), pg_clipfrac_lower))
45-
self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
46-
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), dcppo_loss))
47-
48-
4933
def test_ppo_policy_loss(self):
5034
policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
5135
policy_loss_fn_args = policy_loss_fn_cls.default_args()
@@ -130,7 +114,8 @@ def test_ppo_policy_loss_with_truncate_is(self):
130114
policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
131115
policy_loss_fn_args = policy_loss_fn_cls.default_args()
132116
# Enable truncate large IS with default bounds [0.0, 2.0]
133-
policy_loss_fn_args["truncate_large_is"] = True
117+
policy_loss_fn_args["truncate_adv_pos_is"] = True
118+
policy_loss_fn_args["truncate_adv_neg_is"] = True
134119
policy_loss_fn_args["truncate_is_range_low"] = 0.0
135120
policy_loss_fn_args["truncate_is_range_high"] = 2.0
136121
policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
@@ -140,16 +125,23 @@ def test_ppo_policy_loss_with_truncate_is(self):
140125
ppo_loss_truncated = torch.tensor(0.2230827361345291)
141126
pg_clipfrac_truncated = torch.tensor(0.3541666567325592)
142127
ppo_kl_truncated = torch.tensor(-0.21663446724414825)
143-
is_truncate_frac_expected = torch.tensor(0.2708333432674408)
128+
is_truncate_frac_pos_expected = torch.tensor(0.0)
129+
is_truncate_frac_neg_expected = torch.tensor(0.1041666641831398)
144130

145131
self.assertTrue(torch.allclose(loss, ppo_loss_truncated))
146132
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_truncated))
147133
self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl_truncated))
148134
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_truncated))
149135
# Check that IS truncation metrics are present and have expected values
150-
self.assertIn("is_truncate_frac", metrics)
136+
self.assertIn("is_truncate_frac_pos", metrics)
137+
self.assertIn("is_truncate_frac_neg", metrics)
138+
self.assertTrue(
139+
torch.allclose(torch.tensor(metrics["is_truncate_frac_pos"]), is_truncate_frac_pos_expected)
140+
)
151141
self.assertTrue(
152-
torch.allclose(torch.tensor(metrics["is_truncate_frac"]), is_truncate_frac_expected)
142+
torch.allclose(torch.tensor(metrics["is_truncate_frac_neg"]), is_truncate_frac_neg_expected)
153143
)
154-
self.assertGreaterEqual(metrics["is_truncate_frac"], 0.0)
155-
self.assertLessEqual(metrics["is_truncate_frac"], 1.0)
144+
self.assertGreaterEqual(metrics["is_truncate_frac_pos"], 0.0)
145+
self.assertLessEqual(metrics["is_truncate_frac_pos"], 1.0)
146+
self.assertGreaterEqual(metrics["is_truncate_frac_neg"], 0.0)
147+
self.assertLessEqual(metrics["is_truncate_frac_neg"], 1.0)

trinity/algorithm/policy_loss_fn/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
1515
from trinity.algorithm.policy_loss_fn.sppo_loss_fn import sPPOPolicyLossFn
1616
from trinity.algorithm.policy_loss_fn.topr_policy_loss import TOPRPolicyLossFn
17-
from trinity.algorithm.policy_loss_fn.dcppo_policy_loss import DualClipPPOPolicyLossFn
1817

1918
__all__ = [
2019
"POLICY_LOSS_FN",
@@ -32,5 +31,4 @@
3231
"SFTPhiLossFn",
3332
"sPPOPolicyLossFn",
3433
"RECPolicyLossFn",
35-
"DualClipPPOPolicyLossFn",
3634
]

trinity/algorithm/policy_loss_fn/dcppo_policy_loss.py

Lines changed: 0 additions & 90 deletions
This file was deleted.

trinity/algorithm/policy_loss_fn/ppo_policy_loss.py

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ def __init__(
2020
clip_range_low: Optional[float] = None,
2121
clip_range_high: Optional[float] = None,
2222
loss_agg_mode: Optional[str] = "token-mean",
23-
truncate_large_is: bool = False,
23+
truncate_adv_pos_is: bool = False,
24+
truncate_adv_neg_is: bool = False,
2425
truncate_is_range_low: Optional[float] = 0.0,
2526
truncate_is_range_high: Optional[float] = 2.0,
2627
) -> None:
@@ -33,8 +34,12 @@ def __init__(
3334
clip_range_low: Lower bound for clipping (1.0 - clip_range_low)
3435
clip_range_high: Upper bound for clipping (1.0 + clip_range_high)
3536
loss_agg_mode: Loss aggregation mode (default: "token-mean")
36-
truncate_large_is: Whether to truncate large importance sampling ratios
37-
to handle calculation discrepancies between rollout and training engines
37+
truncate_adv_pos_is: Whether to truncate large importance sampling ratios
38+
when advantage is positive to handle calculation discrepancies between
39+
rollout and training engines
40+
truncate_adv_neg_is: Whether to truncate large importance sampling ratios
41+
when advantage is negative to handle calculation discrepancies between
42+
rollout and training engines
3843
truncate_is_range_low: Lower bound for IS ratio truncation (default: 0.0)
3944
truncate_is_range_high: Upper bound for IS ratio truncation (default: 2.0)
4045
"""
@@ -52,17 +57,27 @@ def __init__(
5257
self.loss_agg_mode = loss_agg_mode
5358

5459
# Truncate large IS configuration
55-
self.truncate_large_is = truncate_large_is
56-
if truncate_large_is:
60+
self.truncate_adv_pos_is = truncate_adv_pos_is
61+
self.truncate_adv_neg_is = truncate_adv_neg_is
62+
if truncate_adv_pos_is:
5763
self.truncate_is_range_low = truncate_is_range_low
58-
self.truncate_is_range_high = truncate_is_range_high
5964
assert (
6065
self.truncate_is_range_low is not None
6166
), "truncate_is_range_low must be specified."
67+
assert (
68+
self.truncate_is_range_low >= 0.0
69+
), "truncate_is_range_low must be non-negative."
70+
assert (self.truncate_is_range_low < 1.0-self.clip_range_low
71+
), "truncate_is_range_low must be less than 1.0 - clip_range_low."
72+
if truncate_adv_neg_is:
73+
self.truncate_is_range_high = truncate_is_range_high
6274
assert (
6375
self.truncate_is_range_high is not None
6476
), "truncate_is_range_high must be specified."
65-
assert self.truncate_is_range_low >= 0.0, "truncate_is_range_low must be non-negative."
77+
assert (
78+
self.truncate_is_range_high > 1.0+self.clip_range_high
79+
), "truncate_is_range_high must be greater than clip_range_high + 1.0."
80+
if truncate_adv_pos_is and truncate_adv_neg_is:
6681
assert (
6782
self.truncate_is_range_high > self.truncate_is_range_low
6883
), "truncate_is_range_high must be greater than truncate_is_range_low."
@@ -79,36 +94,54 @@ def __call__( # type: ignore
7994
ratio = torch.exp(negative_approx_kl)
8095
ppo_kl = masked_mean(-negative_approx_kl, action_mask)
8196

82-
# Truncate large IS ratios if enabled
83-
# This helps stabilize training when there are calculation discrepancies between
84-
# rollout and training engines, especially for small probabilities
85-
if self.truncate_large_is:
86-
# Track how often truncation occurs (before actually truncating)
87-
# More efficient than cloning: directly check which values fall outside bounds
88-
ratio_detached = ratio.detach()
89-
is_truncate_frac = masked_mean(
90-
(ratio_detached < self.truncate_is_range_low).float(), action_mask
91-
) + masked_mean((ratio_detached > self.truncate_is_range_high).float(), action_mask)
92-
ratio = torch.clamp(ratio, self.truncate_is_range_low, self.truncate_is_range_high)
93-
94-
pg_losses = -advantages * ratio
97+
# First clip by clip_range and calculate pg_clipfrac
98+
pg_losses1 = -advantages * ratio
9599
pg_losses2 = -advantages * torch.clamp(
96100
ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high # type: ignore
97101
)
102+
pg_losses_clip = torch.maximum(pg_losses1, pg_losses2)
103+
pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses1).float(), action_mask)
104+
105+
# After clipping by clip_range, further truncate IS ratios if enabled
106+
# This helps stabilize training when there are calculation discrepancies between
107+
# rollout and training engines, especially for small probabilities
108+
pg_truncfrac_pos, pg_truncfrac_neg = 0.0, 0.0
109+
pg_losses_trunc = pg_losses_clip
110+
111+
# Add IS truncation for positive advantages
112+
if self.truncate_adv_pos_is:
113+
pg_losses_pos_trunc = -advantages * self.truncate_is_range_low
114+
pg_truncfrac_pos = masked_mean(
115+
torch.lt(pg_losses_pos_trunc, pg_losses_trunc) * (advantages > 0).float(),
116+
action_mask,
117+
)
118+
pg_losses_pos = torch.minimum(pg_losses_trunc, pg_losses_pos_trunc)
119+
pg_losses_trunc = torch.where(advantages > 0, pg_losses_pos, pg_losses_trunc)
120+
121+
# Add IS truncation for negative advantages
122+
if self.truncate_adv_neg_is:
123+
pg_losses_neg_trunc = -advantages * self.truncate_is_range_high
124+
pg_truncfrac_neg = masked_mean(
125+
torch.lt(pg_losses_neg_trunc, pg_losses_trunc) * (advantages < 0).float(),
126+
action_mask,
127+
)
128+
pg_losses_neg = torch.minimum(pg_losses_trunc, pg_losses_neg_trunc)
129+
pg_losses_trunc = torch.where(advantages < 0, pg_losses_neg, pg_losses_trunc)
98130

99131
pg_loss = masked_loss(
100-
torch.max(pg_losses, pg_losses2), action_mask, loss_agg_mode=self.loss_agg_mode
132+
pg_losses_trunc, action_mask, loss_agg_mode=self.loss_agg_mode
101133
)
102-
pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask)
103134
metrics = {
104135
"pg_clipfrac": pg_clipfrac.detach().item(),
105136
"ppo_kl": ppo_kl.detach().item(),
106137
"pg_loss": pg_loss.detach().item(),
107138
}
108139

109140
# Add IS truncation metrics if enabled
110-
if self.truncate_large_is:
111-
metrics["is_truncate_frac"] = is_truncate_frac.detach().item()
141+
if self.truncate_adv_pos_is:
142+
metrics["is_truncate_frac_pos"] = pg_truncfrac_pos.detach().item()
143+
if self.truncate_adv_neg_is:
144+
metrics["is_truncate_frac_neg"] = pg_truncfrac_neg.detach().item()
112145

113146
return pg_loss, metrics
114147

@@ -117,7 +150,8 @@ def default_args(cls) -> Dict:
117150
return {
118151
"clip_range": 0.2,
119152
"loss_agg_mode": "token-mean",
120-
"truncate_large_is": False,
153+
"truncate_adv_pos_is": False,
154+
"truncate_adv_neg_is": False,
121155
"truncate_is_range_low": 0.0,
122156
"truncate_is_range_high": 2.0,
123157
}

0 commit comments

Comments
 (0)