Skip to content

Commit 221b075

Browse files
committed
Add sequence masking for GRPO
1 parent 3861859 commit 221b075

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

tests/algorithm/policy_loss_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,28 @@ def test_mix_policy_loss(self):
115115
)
116116
self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
117117
self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
118+
119+
def test_ppo_policy_loss_with_sequence_masking(self):
    """PPO policy loss reports masking metrics and changes the loss when sequence masking is on."""
    loss_cls = POLICY_LOSS_FN.get("ppo")

    # Build a loss function with sequence masking explicitly enabled.
    masked_args = loss_cls.default_args()
    masked_args.update(enable_sequence_masking=True, delta=0.1)
    masked_fn = loss_cls(**masked_args)
    masked_loss, metrics = masked_fn(log_prob=self.logprob, **self.input_data.batch)

    # Both sequence-masking metrics must be present in the returned metrics dict.
    for metric_key in ("seq_mask/masked_tokens", "seq_mask/mean_sequence_kl"):
        self.assertIn(metric_key, metrics)

    # The masked-token fraction is a ratio, so it must lie within [0, 1].
    masked_fraction = metrics["seq_mask/masked_tokens"]
    self.assertGreaterEqual(masked_fraction, 0.0)
    self.assertLessEqual(masked_fraction, 1.0)

    # Recompute the loss with the default (unmasked) configuration for comparison.
    plain_fn = loss_cls(**loss_cls.default_args())
    plain_loss, _ = plain_fn(log_prob=self.logprob, **self.input_data.batch)

    # Whenever at least one token was actually masked, the loss must differ.
    if masked_fraction > 0:
        self.assertFalse(torch.allclose(masked_loss, plain_loss))

trinity/algorithm/policy_loss_fn/ppo_policy_loss.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def __init__(
2121
clip_range_high: Optional[float] = None,
2222
clip_ratio_c: float = 3.0,
2323
loss_agg_mode: Optional[str] = "token-mean",
24+
enable_sequence_masking: bool = False,
25+
delta: float = 0.1,
2426
) -> None:
2527
super().__init__(backend=backend)
2628
if clip_range_low is None:
@@ -36,6 +38,8 @@ def __init__(
3638
assert self.clip_range_low is not None, "clip_range_low must be specified."
3739
assert self.clip_range_high is not None, "clip_range_high must be specified."
3840
self.loss_agg_mode = loss_agg_mode
41+
self.enable_sequence_masking = enable_sequence_masking
42+
self.delta = delta
3943

4044
def __call__( # type: ignore
4145
self,
@@ -51,6 +55,33 @@ def __call__( # type: ignore
5155
ratio = torch.exp(negative_approx_kl)
5256
ppo_kl = masked_mean(-negative_approx_kl, action_mask)
5357

58+
# Compute sequence masking if enabled
59+
sequence_mask = torch.ones_like(advantages)
60+
if self.enable_sequence_masking:
61+
# Compute sequence-level KL divergence: mean KL per sequence
62+
# Shape: (batch_size, seq_len) -> (batch_size,)
63+
kl_per_token = -negative_approx_kl # KL divergence per token
64+
sequence_kl = (kl_per_token * action_mask).sum(dim=-1) / (
65+
action_mask.sum(dim=-1) + 1e-10
66+
)
67+
68+
# Create mask: mask out tokens with negative advantages when sequence KL is high
69+
# Token-level advantage check: (batch_size, seq_len)
70+
has_negative_advantage = advantages < 0
71+
# Sequence-level KL check: (batch_size,) -> (batch_size, 1) -> (batch_size, seq_len)
72+
exceeds_kl_threshold = (sequence_kl > self.delta).unsqueeze(-1).expand_as(advantages)
73+
# Mask tokens that are both negative advantage AND in high-KL sequences
74+
should_mask = has_negative_advantage & exceeds_kl_threshold
75+
sequence_mask = (~should_mask).float()
76+
77+
metrics_seq_mask = {
78+
"seq_mask/masked_tokens": should_mask.float().sum().item()
79+
/ (action_mask.sum().item() + 1e-10),
80+
"seq_mask/mean_sequence_kl": sequence_kl.mean().detach().item(),
81+
}
82+
else:
83+
metrics_seq_mask = {}
84+
5485
pg_losses1 = -advantages * ratio
5586
pg_losses2 = -advantages * torch.clamp(
5687
ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high # type: ignore
@@ -66,13 +97,18 @@ def __call__( # type: ignore
6697
torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
6798
)
6899
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
100+
101+
# Apply sequence mask to the losses
102+
pg_losses = pg_losses * sequence_mask
103+
69104
pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
70105
metrics = {
71106
"pg_clipfrac": pg_clip_frac.detach().item(),
72107
"pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
73108
"ppo_kl": ppo_kl.detach().item(),
74109
"pg_loss": pg_loss.detach().item(),
75110
}
111+
metrics.update(metrics_seq_mask)
76112
return pg_loss, metrics
77113

78114
@classmethod
@@ -81,4 +117,6 @@ def default_args(cls) -> Dict:
81117
"clip_range": 0.2,
82118
"clip_ratio_c": 3.0,
83119
"loss_agg_mode": "token-mean",
120+
"enable_sequence_masking": False,
121+
"delta": 0.1,
84122
}

0 commit comments

Comments
 (0)