diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py
index cc5d95972f..d4ddbbf87c 100644
--- a/tests/algorithm/policy_loss_test.py
+++ b/tests/algorithm/policy_loss_test.py
@@ -116,6 +116,33 @@ def test_mix_policy_loss(self):
         self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
         self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
 
+    def test_ppo_policy_loss_with_sequence_masking(self):
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn_args["enable_sequence_masking"] = True
+        policy_loss_fn_args["delta_sequence_masking"] = 0.1
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+        ppo_loss_masked = torch.tensor(0.22175675630569458)
+        pg_clipfrac = torch.tensor(0.3541666567325592)
+        ppo_kl = torch.tensor(-0.21663446724414825)
+        pg_clipfrac_lower = torch.tensor(0.0625)
+        masked_tokens = torch.tensor(0.16666666666631944)
+        mean_sequence_kl = torch.tensor(-0.21027061343193054)
+        self.assertTrue(torch.allclose(loss, ppo_loss_masked))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_masked))
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["pg_clipfrac_lower"]), pg_clipfrac_lower)
+        )
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["seq_mask/masked_tokens"]), masked_tokens)
+        )
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["seq_mask/mean_sequence_kl"]), mean_sequence_kl)
+        )
+
     def test_sapo_policy_loss(self):
         policy_loss_fn_cls = POLICY_LOSS_FN.get("sapo")
         policy_loss_fn_args = policy_loss_fn_cls.default_args()
diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
index f2c812a0b5..d5e17a83e9 100644
--- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -21,6 +21,8 @@ def __init__(
         clip_range_high: Optional[float] = None,
         clip_ratio_c: float = 3.0,
         loss_agg_mode: Optional[str] = "token-mean",
+        enable_sequence_masking: bool = False,  # introduced in DeepSeek-V3.2
+        delta_sequence_masking: float = 0.1,
     ) -> None:
         super().__init__(backend=backend)
         if clip_range_low is None:
@@ -36,6 +38,8 @@ def __init__(
         assert self.clip_range_low is not None, "clip_range_low must be specified."
         assert self.clip_range_high is not None, "clip_range_high must be specified."
         self.loss_agg_mode = loss_agg_mode
+        self.enable_sequence_masking = enable_sequence_masking
+        self.delta_sequence_masking = delta_sequence_masking
 
     def __call__(  # type: ignore
         self,
@@ -66,6 +70,36 @@ def __call__(  # type: ignore
             torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
         )
         pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
+
+        # Apply sequence masking if enabled
+        if self.enable_sequence_masking:
+            # Compute the sequence-level KL divergence: mean KL per sequence
+            # Shape: (batch_size, seq_len) -> (batch_size,)
+            kl_per_token = -negative_approx_kl  # KL divergence per token
+            sequence_kl = (kl_per_token * action_mask).sum(dim=-1) / (
+                action_mask.sum(dim=-1) + 1e-10
+            )
+
+            # Build the mask: drop tokens with negative advantage when the sequence KL is high
+            # Token-level advantage check: (batch_size, seq_len)
+            has_negative_advantage = advantages < 0
+            # Sequence-level KL check: (batch_size,) -> (batch_size, 1) -> (batch_size, seq_len)
+            exceeds_kl_threshold = (
+                (sequence_kl > self.delta_sequence_masking).unsqueeze(-1).expand_as(advantages)
+            )
+            # Mask tokens that have a negative advantage AND sit in a high-KL sequence
+            should_mask = has_negative_advantage & exceeds_kl_threshold
+            sequence_mask = (~should_mask).float()
+
+            # Apply the sequence mask to the losses
+            pg_losses = pg_losses * sequence_mask
+
+            metrics_seq_mask = {
+                "seq_mask/masked_tokens": should_mask.float().sum().item()
+                / (action_mask.sum().item() + 1e-10),
+                "seq_mask/mean_sequence_kl": sequence_kl.mean().detach().item(),
+            }
+
         pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
         metrics = {
             "pg_clipfrac": pg_clip_frac.detach().item(),
@@ -73,6 +107,8 @@ def __call__(  # type: ignore
             "ppo_kl": ppo_kl.detach().item(),
             "pg_loss": pg_loss.detach().item(),
         }
+        if self.enable_sequence_masking:
+            metrics.update(metrics_seq_mask)
         return pg_loss, metrics
 
     @classmethod
@@ -81,4 +117,6 @@ def default_args(cls) -> Dict:
             "clip_range": 0.2,
             "clip_ratio_c": 3.0,
             "loss_agg_mode": "token-mean",
+            "enable_sequence_masking": False,
+            "delta_sequence_masking": 0.1,
         }
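
Note (reviewer sketch, not part of the patch): the masking rule added above can be sanity-checked in isolation. The snippet below re-implements it on toy tensors; the helper name `sequence_mask`, the toy values, and the `delta` parameter (standing in for `delta_sequence_masking`) are illustrative only.

import torch

def sequence_mask(negative_approx_kl, advantages, action_mask, delta=0.1):
    # Per-sequence mean KL over valid tokens (epsilon guards empty sequences).
    kl_per_token = -negative_approx_kl
    sequence_kl = (kl_per_token * action_mask).sum(-1) / (action_mask.sum(-1) + 1e-10)
    # Zero out tokens with negative advantage inside sequences whose mean KL exceeds delta.
    should_mask = (advantages < 0) & (sequence_kl > delta).unsqueeze(-1).expand_as(advantages)
    return (~should_mask).float()

kl = torch.tensor([[0.05, 0.30], [0.01, 0.02]])  # per-token KL (= -negative_approx_kl)
adv = torch.tensor([[-1.0, 0.5], [-1.0, 0.5]])
mask = torch.ones(2, 2)
print(sequence_mask(-kl, adv, mask, delta=0.1))
# Row 0 has mean KL 0.175 > 0.1, so its negative-advantage token is dropped:
# tensor([[0., 1.],
#         [1., 1.]])

With all advantages positive or all sequence KLs below delta, the mask is all ones and the loss reduces to plain PPO, which is why the existing test_ppo_policy_loss expectations are unaffected.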