27 changes: 27 additions & 0 deletions tests/algorithm/policy_loss_test.py
@@ -116,6 +116,33 @@ def test_mix_policy_loss(self):
self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))

def test_ppo_policy_loss_with_sequence_masking(self):
policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
policy_loss_fn_args = policy_loss_fn_cls.default_args()
policy_loss_fn_args["enable_sequence_masking"] = True
policy_loss_fn_args["delta"] = 0.1
policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
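        # Expected values below are precomputed regression constants for this fixture batch.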
ppo_loss_masked = torch.tensor(0.22175675630569458)
pg_clipfrac = torch.tensor(0.3541666567325592)
ppo_kl = torch.tensor(-0.21663446724414825)
pg_clipfrac_lower = torch.tensor(0.0625)
masked_tokens = torch.tensor(0.16666666666631944)
mean_sequence_kl = torch.tensor(-0.21027061343193054)
self.assertTrue(torch.allclose(loss, ppo_loss_masked))
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_masked))
self.assertTrue(
torch.allclose(torch.tensor(metrics["pg_clipfrac_lower"]), pg_clipfrac_lower)
)
self.assertTrue(
torch.allclose(torch.tensor(metrics["seq_mask/masked_tokens"]), masked_tokens)
)
self.assertTrue(
torch.allclose(torch.tensor(metrics["seq_mask/mean_sequence_kl"]), mean_sequence_kl)
)

def test_sapo_policy_loss(self):
policy_loss_fn_cls = POLICY_LOSS_FN.get("sapo")
policy_loss_fn_args = policy_loss_fn_cls.default_args()
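For orientation before the implementation diff: the new option drops a token from the policy-gradient loss only when its advantage is negative AND its sequence's mean approximate KL exceeds delta. A minimal standalone sketch of that rule, using illustrative toy tensors rather than the repository's test fixture:

import torch

# Toy batch: 2 sequences x 4 tokens (values are illustrative).
advantages = torch.tensor([[0.5, -0.3, -0.1, 0.2],
                           [-0.4, -0.2, 0.1, 0.3]])
kl_per_token = torch.tensor([[0.30, 0.20, 0.25, 0.15],   # mean 0.225 > delta
                             [0.01, 0.02, 0.03, 0.00]])  # mean 0.015 <= delta
action_mask = torch.ones_like(advantages)
delta = 0.1

# Per-sequence mean KL, as in the diff (the epsilon guards empty sequences).
sequence_kl = (kl_per_token * action_mask).sum(-1) / (action_mask.sum(-1) + 1e-10)

# Mask only tokens with negative advantage inside a high-KL sequence.
should_mask = (advantages < 0) & (sequence_kl > delta).unsqueeze(-1).expand_as(advantages)
sequence_mask = (~should_mask).float()
print(sequence_mask)
# tensor([[1., 0., 0., 1.],   <- high-KL sequence: negative-advantage tokens dropped
#         [1., 1., 1., 1.]])  <- low-KL sequence: untouched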
39 changes: 39 additions & 0 deletions trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -21,6 +21,8 @@ def __init__(
clip_range_high: Optional[float] = None,
clip_ratio_c: float = 3.0,
loss_agg_mode: Optional[str] = "token-mean",
enable_sequence_masking: bool = False,
delta: float = 0.1,
) -> None:
super().__init__(backend=backend)
if clip_range_low is None:
@@ -36,6 +38,8 @@
assert self.clip_range_low is not None, "clip_range_low must be specified."
assert self.clip_range_high is not None, "clip_range_high must be specified."
self.loss_agg_mode = loss_agg_mode
self.enable_sequence_masking = enable_sequence_masking
self.delta = delta

def __call__( # type: ignore
self,
@@ -51,6 +55,33 @@ def __call__( # type: ignore
ratio = torch.exp(negative_approx_kl)
ppo_kl = masked_mean(-negative_approx_kl, action_mask)

# Compute sequence masking if enabled
sequence_mask = torch.ones_like(advantages)
if self.enable_sequence_masking:
# Compute sequence-level KL divergence: mean KL per sequence
# Shape: (batch_size, seq_len) -> (batch_size,)
kl_per_token = -negative_approx_kl # KL divergence per token
sequence_kl = (kl_per_token * action_mask).sum(dim=-1) / (
action_mask.sum(dim=-1) + 1e-10
)

# Create mask: mask out tokens with negative advantages when sequence KL is high
# Token-level advantage check: (batch_size, seq_len)
has_negative_advantage = advantages < 0
# Sequence-level KL check: (batch_size,) -> (batch_size, 1) -> (batch_size, seq_len)
exceeds_kl_threshold = (sequence_kl > self.delta).unsqueeze(-1).expand_as(advantages)
            # Mask tokens that have negative advantage AND sit in high-KL sequences
should_mask = has_negative_advantage & exceeds_kl_threshold
sequence_mask = (~should_mask).float()

metrics_seq_mask = {
"seq_mask/masked_tokens": should_mask.float().sum().item()
/ (action_mask.sum().item() + 1e-10),
"seq_mask/mean_sequence_kl": sequence_kl.mean().detach().item(),
}
else:
metrics_seq_mask = {}

pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(
ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high # type: ignore
@@ -66,13 +97,19 @@ def __call__( # type: ignore
torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
)
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

# Apply sequence mask to the losses
if self.enable_sequence_masking:
pg_losses = pg_losses * sequence_mask

pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
metrics = {
"pg_clipfrac": pg_clip_frac.detach().item(),
"pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
"ppo_kl": ppo_kl.detach().item(),
"pg_loss": pg_loss.detach().item(),
}
metrics.update(metrics_seq_mask)
return pg_loss, metrics

@classmethod
@@ -81,4 +118,6 @@ def default_args(cls) -> Dict:
"clip_range": 0.2,
"clip_ratio_c": 3.0,
"loss_agg_mode": "token-mean",
"enable_sequence_masking": False,
"delta": 0.1,
}
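
Putting the two pieces together, enabling the feature is a two-key change on top of default_args(). A minimal usage sketch; the registry import path and the batch keys are assumptions inferred from the test above, not shown in this diff:

from trinity.algorithm import POLICY_LOSS_FN  # import path is an assumption

policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
args = policy_loss_fn_cls.default_args()
args["enable_sequence_masking"] = True  # off by default
args["delta"] = 0.1                     # per-sequence KL threshold
policy_loss_fn = policy_loss_fn_cls(**args)

# loss, metrics = policy_loss_fn(log_prob=log_prob, **batch)
# When masking is enabled, metrics additionally reports:
#   "seq_mask/masked_tokens"    - fraction of action tokens masked out
#   "seq_mask/mean_sequence_kl" - batch mean of the per-sequence KL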