Commit 613194d

Add sequence mask for grpo (agentscope-ai#420)
Co-authored-by: 问昊 <[email protected]>
1 parent 8648d4d commit 613194d

File tree

2 files changed: +65 −0 lines

tests/algorithm/policy_loss_test.py
Lines changed: 27 additions & 0 deletions

@@ -116,6 +116,33 @@ def test_mix_policy_loss(self):
         self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
         self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
 
+    def test_ppo_policy_loss_with_sequence_masking(self):
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn_args["enable_sequence_masking"] = True
+        policy_loss_fn_args["delta_sequence_masking"] = 0.1
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+        ppo_loss_masked = torch.tensor(0.22175675630569458)
+        pg_clipfrac = torch.tensor(0.3541666567325592)
+        ppo_kl = torch.tensor(-0.21663446724414825)
+        pg_clipfrac_lower = torch.tensor(0.0625)
+        masked_tokens = torch.tensor(0.16666666666631944)
+        mean_sequence_kl = torch.tensor(-0.21027061343193054)
+        self.assertTrue(torch.allclose(loss, ppo_loss_masked))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_masked))
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["pg_clipfrac_lower"]), pg_clipfrac_lower)
+        )
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["seq_mask/masked_tokens"]), masked_tokens)
+        )
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["seq_mask/mean_sequence_kl"]), mean_sequence_kl)
+        )
+
     def test_sapo_policy_loss(self):
         policy_loss_fn_cls = POLICY_LOSS_FN.get("sapo")
         policy_loss_fn_args = policy_loss_fn_cls.default_args()
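
Note: the new behavior is switched on through two constructor arguments. A minimal sketch of how a caller could enable it, following the same pattern as the test above (the import path below is an assumption, not taken from this diff):

# Sketch only: POLICY_LOSS_FN is the registry the test uses; its import path
# is assumed here and may differ in the repository.
from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN  # assumed path

ppo_cls = POLICY_LOSS_FN.get("ppo")
args = ppo_cls.default_args()
args["enable_sequence_masking"] = True  # disabled by default
args["delta_sequence_masking"] = 0.1    # per-sequence KL threshold
loss_fn = ppo_cls(**args)
# loss, metrics = loss_fn(log_prob=log_prob, **batch)
# metrics then also carries "seq_mask/masked_tokens" and "seq_mask/mean_sequence_kl"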

trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
Lines changed: 38 additions & 0 deletions

@@ -21,6 +21,8 @@ def __init__(
         clip_range_high: Optional[float] = None,
         clip_ratio_c: float = 3.0,
         loss_agg_mode: Optional[str] = "token-mean",
+        enable_sequence_masking: bool = False,  # introduced in DeepseekV3.2
+        delta_sequence_masking: float = 0.1,
     ) -> None:
         super().__init__(backend=backend)
         if clip_range_low is None:
@@ -36,6 +38,8 @@ def __init__(
         assert self.clip_range_low is not None, "clip_range_low must be specified."
         assert self.clip_range_high is not None, "clip_range_high must be specified."
         self.loss_agg_mode = loss_agg_mode
+        self.enable_sequence_masking = enable_sequence_masking
+        self.delta_sequence_masking = delta_sequence_masking
 
     def __call__(  # type: ignore
         self,
@@ -66,13 +70,45 @@ def __call__(  # type: ignore
             torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
         )
         pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
+
+        # Apply sequence masking if enabled
+        if self.enable_sequence_masking:
+            # Compute sequence-level KL divergence: mean KL per sequence
+            # Shape: (batch_size, seq_len) -> (batch_size,)
+            kl_per_token = -negative_approx_kl  # KL divergence per token
+            sequence_kl = (kl_per_token * action_mask).sum(dim=-1) / (
+                action_mask.sum(dim=-1) + 1e-10
+            )
+
+            # Create mask: mask out tokens with negative advantages when sequence KL is high
+            # Token-level advantage check: (batch_size, seq_len)
+            has_negative_advantage = advantages < 0
+            # Sequence-level KL check: (batch_size,) -> (batch_size, 1) -> (batch_size, seq_len)
+            exceeds_kl_threshold = (
+                (sequence_kl > self.delta_sequence_masking).unsqueeze(-1).expand_as(advantages)
+            )
+            # Mask tokens that are both negative advantage AND in high-KL sequences
+            should_mask = has_negative_advantage & exceeds_kl_threshold
+            sequence_mask = (~should_mask).float()
+
+            # Apply sequence mask to the losses
+            pg_losses = pg_losses * sequence_mask
+
+            metrics_seq_mask = {
+                "seq_mask/masked_tokens": should_mask.float().sum().item()
+                / (action_mask.sum().item() + 1e-10),
+                "seq_mask/mean_sequence_kl": sequence_kl.mean().detach().item(),
+            }
+
         pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
         metrics = {
             "pg_clipfrac": pg_clip_frac.detach().item(),
             "pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
             "ppo_kl": ppo_kl.detach().item(),
             "pg_loss": pg_loss.detach().item(),
         }
+        if self.enable_sequence_masking:
+            metrics.update(metrics_seq_mask)
         return pg_loss, metrics
 
     @classmethod
@@ -81,4 +117,6 @@ def default_args(cls) -> Dict:
             "clip_range": 0.2,
             "clip_ratio_c": 3.0,
             "loss_agg_mode": "token-mean",
+            "enable_sequence_masking": False,
+            "delta_sequence_masking": 0.1,
         }
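
The rule added in __call__ is: compute the mean approximate KL per sequence, then zero the loss on negative-advantage tokens belonging to sequences whose mean KL exceeds delta_sequence_masking. A minimal self-contained sketch of that rule on toy tensors (the values, shapes, and names below are invented for illustration and are not taken from the test data):

# Toy illustration of the sequence-masking rule above; not the Trinity module itself.
import torch

delta = 0.1  # plays the role of delta_sequence_masking

# Two sequences of four tokens each; per-token KL is taken as given here
# (in the diff it is -negative_approx_kl).
kl_per_token = torch.tensor([[0.3, 0.2, 0.1, 0.0],    # mean KL 0.15  > delta
                             [0.0, 0.0, 0.1, 0.0]])   # mean KL 0.025 <= delta
advantages = torch.tensor([[ 1.0, -1.0, -1.0,  1.0],
                           [-1.0,  1.0, -1.0,  1.0]])
action_mask = torch.ones_like(kl_per_token)

# Mean KL per sequence: (batch_size,)
sequence_kl = (kl_per_token * action_mask).sum(-1) / (action_mask.sum(-1) + 1e-10)

# Mask negative-advantage tokens inside high-KL sequences.
should_mask = (advantages < 0) & (sequence_kl > delta).unsqueeze(-1)
sequence_mask = (~should_mask).float()

print(sequence_mask)
# tensor([[1., 0., 0., 1.],    <- high-KL sequence: negative-advantage tokens zeroed
#         [1., 1., 1., 1.]])   <- low-KL sequence: left untouched
print(should_mask.float().sum() / action_mask.sum())  # fraction reported as seq_mask/masked_tokens: 0.25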
