
Commit a171560

add fallback_to_policy_gradient option (#443)
1 parent 9693926 · commit a171560

File tree

1 file changed: +6 −0 lines changed


trinity/algorithm/policy_loss_fn/ppo_policy_loss.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -23,6 +23,7 @@ def __init__(
         loss_agg_mode: Optional[str] = "token-mean",
         enable_sequence_masking: bool = False,  # introduced in DeepseekV3.2
         delta_sequence_masking: float = 0.1,
+        fallback_to_policy_gradient: bool = False,
     ) -> None:
         super().__init__(backend=backend)
         if clip_range_low is None:
@@ -40,6 +41,7 @@ def __init__(
         self.loss_agg_mode = loss_agg_mode
         self.enable_sequence_masking = enable_sequence_masking
         self.delta_sequence_masking = delta_sequence_masking
+        self.fallback_to_policy_gradient = fallback_to_policy_gradient
 
     def __call__(  # type: ignore
         self,
@@ -50,6 +52,9 @@ def __call__(  # type: ignore
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         negative_approx_kl = logprob - old_logprob
+        if self.fallback_to_policy_gradient:
+            # ignore vllm logprob difference and use pure policy gradient loss
+            negative_approx_kl = logprob - logprob.detach()
         # Clamp negative_approx_kl for stability
         negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
         ratio = torch.exp(negative_approx_kl)
@@ -119,4 +124,5 @@ def default_args(cls) -> Dict:
             "loss_agg_mode": "token-mean",
             "enable_sequence_masking": False,
             "delta_sequence_masking": 0.1,
+            "fallback_to_policy_gradient": False,
         }
```
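Why this works: `logprob - logprob.detach()` is the standard stop-gradient identity. Its value is identically zero, so `ratio = exp(...)` evaluates to exactly 1 and PPO's clipping never activates, yet the expression still carries the gradient of `logprob`. The surrogate loss therefore degenerates to the vanilla policy-gradient objective, which sidesteps any numerical mismatch between the trainer's logprobs and the rollout engine's (e.g. vLLM's). Below is a minimal sketch of that effect using hypothetical tensors; it simplifies away the clamp and clip steps of the actual loss function, which are no-ops when the ratio is 1:

```python
import torch

# Hypothetical per-token log-probabilities from the trainer's forward pass.
logprob = torch.tensor([-1.2, -0.7, -2.3], requires_grad=True)
advantages = torch.tensor([0.5, -0.3, 1.0])

# fallback_to_policy_gradient=True path: the rollout engine's logprobs are ignored.
negative_approx_kl = logprob - logprob.detach()  # zero-valued; gradient still flows via logprob
ratio = torch.exp(negative_approx_kl)            # exactly 1.0 everywhere, so clipping is a no-op

# With ratio == 1 the surrogate's value is just -mean(advantages), a constant,
# but its gradient equals that of the plain policy-gradient loss -mean(advantages * logprob).
pg_loss = -(advantages * ratio).mean()
pg_loss.backward()

print(logprob.grad)  # tensor([-0.1667,  0.1000, -0.3333]) == -advantages / 3
```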

0 commit comments