Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
loss_agg_mode: Optional[str] = "token-mean",
enable_sequence_masking: bool = False, # introduced in DeepseekV3.2
delta_sequence_masking: float = 0.1,
fallback_to_policy_gradient: bool = False,
) -> None:
super().__init__(backend=backend)
if clip_range_low is None:
Expand All @@ -40,6 +41,7 @@ def __init__(
self.loss_agg_mode = loss_agg_mode
self.enable_sequence_masking = enable_sequence_masking
self.delta_sequence_masking = delta_sequence_masking
self.fallback_to_policy_gradient = fallback_to_policy_gradient

def __call__( # type: ignore
self,
Expand All @@ -50,6 +52,11 @@ def __call__( # type: ignore
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
negative_approx_kl = logprob - old_logprob
if self.fallback_to_policy_gradient:
print('------')
print('logprob shape:', logprob.shape)
print('------')
negative_approx_kl = logprob - logprob.detach()
# Clamp negative_approx_kl for stability
negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
ratio = torch.exp(negative_approx_kl)
Expand Down Expand Up @@ -119,4 +126,5 @@ def default_args(cls) -> Dict:
"loss_agg_mode": "token-mean",
"enable_sequence_masking": False,
"delta_sequence_masking": 0.1,
"fallback_to_policy_gradient": False,
}
Loading