
Commit 7763bdd

clean code as suggested

1 parent e221c8a

2 files changed, +30 -31 lines changed


tests/algorithm/policy_loss_test.py

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ def test_ppo_policy_loss_with_sequence_masking(self):
         policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
         policy_loss_fn_args = policy_loss_fn_cls.default_args()
         policy_loss_fn_args["enable_sequence_masking"] = True
-        policy_loss_fn_args["delta"] = 0.1
+        policy_loss_fn_args["delta_sequence_masking"] = 0.1
         policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
         loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
         ppo_loss_masked = torch.tensor(0.22175675630569458)
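
For callers, the rename is config-only. A minimal sketch of the updated configuration path follows; the import line is an assumption, since the test file's own imports are not part of this hunk:

# Hypothetical import path; not shown in this diff.
from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN

policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
args = policy_loss_fn_cls.default_args()
args["enable_sequence_masking"] = True
args["delta_sequence_masking"] = 0.1  # renamed from "delta" in this commit
policy_loss_fn = policy_loss_fn_cls(**args)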

trinity/algorithm/policy_loss_fn/ppo_policy_loss.py

Lines changed: 29 additions & 30 deletions
@@ -21,8 +21,8 @@ def __init__(
         clip_range_high: Optional[float] = None,
         clip_ratio_c: float = 3.0,
         loss_agg_mode: Optional[str] = "token-mean",
-        enable_sequence_masking: bool = False,
-        delta: float = 0.1,
+        enable_sequence_masking: bool = False,  # introduced in DeepseekV3.2
+        delta_sequence_masking: float = 0.1,
     ) -> None:
         super().__init__(backend=backend)
         if clip_range_low is None:
@@ -39,7 +39,7 @@ def __init__(
         assert self.clip_range_high is not None, "clip_range_high must be specified."
         self.loss_agg_mode = loss_agg_mode
         self.enable_sequence_masking = enable_sequence_masking
-        self.delta = delta
+        self.delta_sequence_masking = delta_sequence_masking

     def __call__(  # type: ignore
         self,
@@ -55,8 +55,23 @@ def __call__(  # type: ignore
         ratio = torch.exp(negative_approx_kl)
         ppo_kl = masked_mean(-negative_approx_kl, action_mask)

-        # Compute sequence masking if enabled
-        sequence_mask = torch.ones_like(advantages)
+        pg_losses1 = -advantages * ratio
+        pg_losses2 = -advantages * torch.clamp(
+            ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high  # type: ignore
+        )
+
+        clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
+
+        pg_clip_frac = masked_mean(torch.gt(pg_losses2, pg_losses1).float(), action_mask)
+
+        pg_losses3 = -advantages * self.clip_ratio_c
+        clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
+        pg_clipfrac_lower = masked_mean(
+            torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
+        )
+        pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
+
+        # Apply sequence masking if enabled
         if self.enable_sequence_masking:
             # Compute sequence-level KL divergence: mean KL per sequence
             # Shape: (batch_size, seq_len) -> (batch_size,)
@@ -69,38 +84,21 @@ def __call__(  # type: ignore
             # Token-level advantage check: (batch_size, seq_len)
             has_negative_advantage = advantages < 0
             # Sequence-level KL check: (batch_size,) -> (batch_size, 1) -> (batch_size, seq_len)
-            exceeds_kl_threshold = (sequence_kl > self.delta).unsqueeze(-1).expand_as(advantages)
+            exceeds_kl_threshold = (
+                (sequence_kl > self.delta_sequence_masking).unsqueeze(-1).expand_as(advantages)
+            )
             # Mask tokens that are both negative advantage AND in high-KL sequences
             should_mask = has_negative_advantage & exceeds_kl_threshold
             sequence_mask = (~should_mask).float()

+            # Apply sequence mask to the losses
+            pg_losses = pg_losses * sequence_mask
+
             metrics_seq_mask = {
                 "seq_mask/masked_tokens": should_mask.float().sum().item()
                 / (action_mask.sum().item() + 1e-10),
                 "seq_mask/mean_sequence_kl": sequence_kl.mean().detach().item(),
             }
-        else:
-            metrics_seq_mask = {}
-
-        pg_losses1 = -advantages * ratio
-        pg_losses2 = -advantages * torch.clamp(
-            ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high  # type: ignore
-        )
-
-        clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
-
-        pg_clip_frac = masked_mean(torch.gt(pg_losses2, pg_losses1).float(), action_mask)
-
-        pg_losses3 = -advantages * self.clip_ratio_c
-        clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
-        pg_clipfrac_lower = masked_mean(
-            torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
-        )
-        pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
-
-        # Apply sequence mask to the losses
-        if self.enable_sequence_masking:
-            pg_losses = pg_losses * sequence_mask

         pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
         metrics = {
@@ -109,7 +107,8 @@ def __call__(  # type: ignore
             "ppo_kl": ppo_kl.detach().item(),
             "pg_loss": pg_loss.detach().item(),
         }
-        metrics.update(metrics_seq_mask)
+        if self.enable_sequence_masking:
+            metrics.update(metrics_seq_mask)
         return pg_loss, metrics

     @classmethod
@@ -119,5 +118,5 @@ def default_args(cls) -> Dict:
             "clip_ratio_c": 3.0,
             "loss_agg_mode": "token-mean",
             "enable_sequence_masking": False,
-            "delta": 0.1,
+            "delta_sequence_masking": 0.1,
         }
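
Taken together, the refactor computes the dual-clip PPO loss unconditionally and only applies the DeepseekV3.2-style sequence mask on top of it. Below is a self-contained toy sketch of that flow, not the repository's exact code: the per-sequence KL reduction is assumed from the hunk's comments (a masked mean of -negative_approx_kl per sequence, elided from the diff context), and masked_mean plus the final "token-mean" aggregation are simplified stand-ins for the library helpers.

import torch

def masked_mean(x, mask):
    # Mean over unmasked positions (simplified stand-in for the helper).
    return (x * mask).sum() / (mask.sum() + 1e-10)

# Toy batch: 2 sequences x 3 tokens.
log_prob     = torch.tensor([[-1.0, -1.2, -0.9], [-2.0, -0.5, -1.1]])
old_log_prob = torch.tensor([[-1.1, -1.0, -1.0], [-1.0, -0.6, -1.0]])
advantages   = torch.tensor([[0.5, -0.3, 0.2], [-0.7, 0.4, -0.2]])
action_mask  = torch.ones_like(advantages)
clip_low, clip_high, clip_ratio_c = 0.2, 0.2, 3.0
delta_sequence_masking = 0.1

negative_approx_kl = log_prob - old_log_prob
ratio = torch.exp(negative_approx_kl)

# Dual-clip PPO: the usual clipped surrogate, plus a second clip
# (clip_ratio_c) that bounds the loss for negative-advantage tokens.
pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - clip_low, 1.0 + clip_high)
clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
pg_losses3 = -advantages * clip_ratio_c
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

# Sequence masking (assumed KL reduction): zero out negative-advantage
# tokens in sequences whose mean KL to the old policy exceeds the threshold.
sequence_kl = (-negative_approx_kl * action_mask).sum(-1) / (action_mask.sum(-1) + 1e-10)
exceeds = (sequence_kl > delta_sequence_masking).unsqueeze(-1).expand_as(advantages)
should_mask = (advantages < 0) & exceeds
pg_losses = pg_losses * (~should_mask).float()

pg_loss = masked_mean(pg_losses, action_mask)  # "token-mean" aggregation
print(pg_loss.item())

Because the dual-clip terms are now computed before the masking branch, the mask multiplies the already-clipped losses inside the branch, which is what lets the commit drop the sequence_mask = torch.ones_like(advantages) default and the trailing second if self.enable_sequence_masking block.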
