Skip to content

Commit ec4d46c

Browse files
garyzhang99 and lehaoqu
authored and committed
PR #334 &
truncate based on the sign of advantage after ratio clip
1 parent 7964099 commit ec4d46c

File tree

2 files changed

+127
-3
lines changed

2 files changed

+127
-3
lines changed

tests/algorithm/policy_loss_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,40 @@ def test_mix_policy_loss(self):
108108
self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss))
109109
self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
110110
self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
111+
112+
def test_ppo_policy_loss_with_truncate_is(self):
    """PPO policy loss with IS-ratio truncation enabled on both advantage signs.

    Builds the "ppo" loss fn from its default args, switches on truncation
    with the default [0.0, 2.0] bounds, and pins the resulting loss and
    metrics against precomputed regression values.
    """
    loss_cls = POLICY_LOSS_FN.get("ppo")
    # Start from the defaults and overlay the truncation settings in one step.
    fn_kwargs = {
        **loss_cls.default_args(),
        "truncate_adv_pos_is": True,
        "truncate_adv_neg_is": True,
        "truncate_is_range_low": 0.0,
        "truncate_is_range_high": 2.0,
    }
    loss_fn = loss_cls(**fn_kwargs)
    loss, metrics = loss_fn(log_prob=self.logprob, **self.input_data.batch)

    # Regression values computed with truncation range [0.0, 2.0].
    expected = {
        "loss": torch.tensor(0.2230827361345291),
        "pg_clipfrac": torch.tensor(0.3541666567325592),
        "ppo_kl": torch.tensor(-0.21663446724414825),
        "is_truncate_frac_pos": torch.tensor(0.0),
        "is_truncate_frac_neg": torch.tensor(0.1041666641831398),
    }

    self.assertTrue(torch.allclose(loss, expected["loss"]))
    self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), expected["pg_clipfrac"]))
    self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), expected["ppo_kl"]))
    self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), expected["loss"]))

    # Truncation metrics must be reported when the feature is enabled,
    # match the regression values, and be valid fractions in [0, 1].
    for key in ("is_truncate_frac_pos", "is_truncate_frac_neg"):
        self.assertIn(key, metrics)
        self.assertTrue(torch.allclose(torch.tensor(metrics[key]), expected[key]))
        self.assertGreaterEqual(metrics[key], 0.0)
        self.assertLessEqual(metrics[key], 1.0)

trinity/algorithm/policy_loss_fn/ppo_policy_loss.py

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,29 @@ def __init__(
2020
clip_range_low: Optional[float] = None,
2121
clip_range_high: Optional[float] = None,
2222
loss_agg_mode: Optional[str] = "token-mean",
23+
truncate_adv_pos_is: bool = False,
24+
truncate_adv_neg_is: bool = False,
25+
truncate_is_range_low: Optional[float] = 0.0,
26+
truncate_is_range_high: Optional[float] = 2.0,
2327
) -> None:
28+
"""
29+
Initialize PPO policy loss function.
30+
31+
Args:
32+
backend: Backend framework (default: "verl")
33+
clip_range: Symmetric clipping range for PPO
34+
clip_range_low: Lower bound for clipping (1.0 - clip_range_low)
35+
clip_range_high: Upper bound for clipping (1.0 + clip_range_high)
36+
loss_agg_mode: Loss aggregation mode (default: "token-mean")
37+
truncate_adv_pos_is: Whether to truncate large importance sampling ratios
38+
when advantage is positive to handle calculation discrepancies between
39+
rollout and training engines
40+
truncate_adv_neg_is: Whether to truncate large importance sampling ratios
41+
when advantage is negative to handle calculation discrepancies between
42+
rollout and training engines
43+
truncate_is_range_low: Lower bound for IS ratio truncation (default: 0.0)
44+
truncate_is_range_high: Upper bound for IS ratio truncation (default: 2.0)
45+
"""
2446
super().__init__(backend=backend)
2547
if clip_range_low is None:
2648
self.clip_range_low = clip_range
@@ -34,6 +56,32 @@ def __init__(
3456
assert self.clip_range_high is not None, "clip_range_high must be specified."
3557
self.loss_agg_mode = loss_agg_mode
3658

59+
# Truncate large IS configuration
60+
self.truncate_adv_pos_is = truncate_adv_pos_is
61+
self.truncate_adv_neg_is = truncate_adv_neg_is
62+
if truncate_adv_pos_is:
63+
self.truncate_is_range_low = truncate_is_range_low
64+
assert (
65+
self.truncate_is_range_low is not None
66+
), "truncate_is_range_low must be specified."
67+
assert (
68+
self.truncate_is_range_low >= 0.0
69+
), "truncate_is_range_low must be non-negative."
70+
assert (self.truncate_is_range_low < 1.0-self.clip_range_low
71+
), "truncate_is_range_low must be less than 1.0 - clip_range_low."
72+
if truncate_adv_neg_is:
73+
self.truncate_is_range_high = truncate_is_range_high
74+
assert (
75+
self.truncate_is_range_high is not None
76+
), "truncate_is_range_high must be specified."
77+
assert (
78+
self.truncate_is_range_high > 1.0+self.clip_range_high
79+
), "truncate_is_range_high must be greater than clip_range_high + 1.0."
80+
if truncate_adv_pos_is and truncate_adv_neg_is:
81+
assert (
82+
self.truncate_is_range_high > self.truncate_is_range_low
83+
), "truncate_is_range_high must be greater than truncate_is_range_low."
84+
3785
def __call__( # type: ignore
3886
self,
3987
logprob: torch.Tensor,
@@ -46,25 +94,64 @@ def __call__( # type: ignore
4694
ratio = torch.exp(negative_approx_kl)
4795
ppo_kl = masked_mean(-negative_approx_kl, action_mask)
4896

49-
pg_losses = -advantages * ratio
97+
# First clipping by clip_range, and calculate pg_clipfrac
98+
pg_losses1 = -advantages * ratio
5099
pg_losses2 = -advantages * torch.clamp(
51100
ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high # type: ignore
52101
)
102+
pg_losses_clip = torch.maximum(pg_losses1, pg_losses2)
103+
pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses1).float(), action_mask)
104+
105+
# After clipped by clip_range, further truncate IS ratios if enabled
106+
# This helps stabilize training when there are calculation discrepancies between
107+
# rollout and training engines, especially for small probabilities
108+
pg_truncfrac_pos, pg_truncfrac_neg = 0.0, 0.0
109+
pg_losses_trunc = pg_losses_clip
110+
111+
# Add IS truncation for positive advantages
112+
if self.truncate_adv_pos_is:
113+
pg_losses_pos_trunc = -advantages * self.truncate_is_range_low
114+
pg_truncfrac_pos = masked_mean(
115+
torch.lt(pg_losses_pos_trunc, pg_losses_trunc) * (advantages > 0).float(),
116+
action_mask,
117+
)
118+
pg_losses_pos = torch.minimum(pg_losses_trunc, pg_losses_pos_trunc)
119+
pg_losses_trunc = torch.where(advantages > 0, pg_losses_pos, pg_losses_trunc)
120+
121+
# Add IS truncation for negative advantages
122+
if self.truncate_adv_neg_is:
123+
pg_losses_neg_trunc = -advantages * self.truncate_is_range_high
124+
pg_truncfrac_neg = masked_mean(
125+
torch.lt(pg_losses_neg_trunc, pg_losses_trunc) * (advantages < 0).float(),
126+
action_mask,
127+
)
128+
pg_losses_neg = torch.minimum(pg_losses_trunc, pg_losses_neg_trunc)
129+
pg_losses_trunc = torch.where(advantages < 0, pg_losses_neg, pg_losses_trunc)
53130

54131
pg_loss = masked_loss(
55-
torch.max(pg_losses, pg_losses2), action_mask, loss_agg_mode=self.loss_agg_mode
132+
pg_losses_trunc, action_mask, loss_agg_mode=self.loss_agg_mode
56133
)
57-
pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask)
58134
metrics = {
59135
"pg_clipfrac": pg_clipfrac.detach().item(),
60136
"ppo_kl": ppo_kl.detach().item(),
61137
"pg_loss": pg_loss.detach().item(),
62138
}
139+
140+
# Add IS truncation metrics if enabled
141+
if self.truncate_adv_pos_is:
142+
metrics["is_truncate_frac_pos"] = pg_truncfrac_pos.detach().item()
143+
if self.truncate_adv_neg_is:
144+
metrics["is_truncate_frac_neg"] = pg_truncfrac_neg.detach().item()
145+
63146
return pg_loss, metrics
64147

65148
@classmethod
def default_args(cls) -> Dict:
    """Return the default constructor arguments for the PPO policy loss.

    IS-ratio truncation is disabled by default; the [0.0, 2.0] range only
    takes effect once one of the ``truncate_adv_*_is`` flags is enabled.
    """
    defaults: Dict = {
        "clip_range": 0.2,
        "loss_agg_mode": "token-mean",
    }
    defaults.update(
        truncate_adv_pos_is=False,
        truncate_adv_neg_is=False,
        truncate_is_range_low=0.0,
        truncate_is_range_high=2.0,
    )
    return defaults

0 commit comments

Comments
 (0)