Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions tests/algorithm/policy_loss_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,33 @@ def test_mix_policy_loss(self):
self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))

def test_ppo_policy_loss_with_sequence_masking(self):
    """PPO policy loss with sequence masking (DeepseekV3.2-style) enabled.

    Builds the "ppo" loss fn from its defaults with masking switched on,
    runs it over the shared test batch, and pins the loss plus every
    reported metric against precomputed reference values.
    """
    loss_fn_cls = POLICY_LOSS_FN.get("ppo")
    loss_fn_args = loss_fn_cls.default_args()
    loss_fn_args.update(enable_sequence_masking=True, delta_sequence_masking=0.1)
    loss_fn = loss_fn_cls(**loss_fn_args)

    loss, metrics = loss_fn(log_prob=self.logprob, **self.input_data.batch)

    # Reference values computed once for this fixture batch.
    expected_loss = torch.tensor(0.22175675630569458)
    expected_metrics = {
        "pg_clipfrac": 0.3541666567325592,
        "ppo_kl": -0.21663446724414825,
        "pg_loss": 0.22175675630569458,
        "pg_clipfrac_lower": 0.0625,
        "seq_mask/masked_tokens": 0.16666666666631944,
        "seq_mask/mean_sequence_kl": -0.21027061343193054,
    }

    self.assertTrue(torch.allclose(loss, expected_loss))
    for metric_name, reference in expected_metrics.items():
        self.assertTrue(
            torch.allclose(torch.tensor(metrics[metric_name]), torch.tensor(reference))
        )

def test_sapo_policy_loss(self):
policy_loss_fn_cls = POLICY_LOSS_FN.get("sapo")
policy_loss_fn_args = policy_loss_fn_cls.default_args()
Expand Down
38 changes: 38 additions & 0 deletions trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def __init__(
clip_range_high: Optional[float] = None,
clip_ratio_c: float = 3.0,
loss_agg_mode: Optional[str] = "token-mean",
enable_sequence_masking: bool = False, # introduced in DeepseekV3.2
delta_sequence_masking: float = 0.1,
) -> None:
super().__init__(backend=backend)
if clip_range_low is None:
Expand All @@ -36,6 +38,8 @@ def __init__(
assert self.clip_range_low is not None, "clip_range_low must be specified."
assert self.clip_range_high is not None, "clip_range_high must be specified."
self.loss_agg_mode = loss_agg_mode
self.enable_sequence_masking = enable_sequence_masking
self.delta_sequence_masking = delta_sequence_masking

def __call__( # type: ignore
self,
Expand Down Expand Up @@ -66,13 +70,45 @@ def __call__( # type: ignore
torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
)
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

# Apply sequence masking if enabled
if self.enable_sequence_masking:
# Compute sequence-level KL divergence: mean KL per sequence
# Shape: (batch_size, seq_len) -> (batch_size,)
kl_per_token = -negative_approx_kl # KL divergence per token
sequence_kl = (kl_per_token * action_mask).sum(dim=-1) / (
action_mask.sum(dim=-1) + 1e-10
)

# Create mask: mask out tokens with negative advantages when sequence KL is high
# Token-level advantage check: (batch_size, seq_len)
has_negative_advantage = advantages < 0
# Sequence-level KL check: (batch_size,) -> (batch_size, 1) -> (batch_size, seq_len)
exceeds_kl_threshold = (
(sequence_kl > self.delta_sequence_masking).unsqueeze(-1).expand_as(advantages)
)
# Mask tokens that are both negative advantage AND in high-KL sequences
should_mask = has_negative_advantage & exceeds_kl_threshold
sequence_mask = (~should_mask).float()

# Apply sequence mask to the losses
pg_losses = pg_losses * sequence_mask

metrics_seq_mask = {
"seq_mask/masked_tokens": should_mask.float().sum().item()
/ (action_mask.sum().item() + 1e-10),
"seq_mask/mean_sequence_kl": sequence_kl.mean().detach().item(),
}

pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
metrics = {
"pg_clipfrac": pg_clip_frac.detach().item(),
"pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
"ppo_kl": ppo_kl.detach().item(),
"pg_loss": pg_loss.detach().item(),
}
if self.enable_sequence_masking:
metrics.update(metrics_seq_mask)
return pg_loss, metrics

@classmethod
Expand All @@ -81,4 +117,6 @@ def default_args(cls) -> Dict:
"clip_range": 0.2,
"clip_ratio_c": 3.0,
"loss_agg_mode": "token-mean",
"enable_sequence_masking": False,
"delta_sequence_masking": 0.1,
}