@@ -72,11 +72,10 @@ class DPPOTrainer(GRPOTrainer):
7272 """
7373 Trainer for Divergence Proximal Policy Optimization (DPPO).
7474
75- DPPO replaces PPO/GRPO's heuristic ratio-clipping with a principled trust region based on direct policy
76- divergence estimates. PPO-style clipping masks tokens based on probability ratio π/μ, which over-penalizes
77- low-probability tokens and under-penalizes high-probability tokens. In contrast, DPPO masks based on
78- direct approximation of policy divergence (e.g., TV or KL), ensuring updates stay within a theoretically
79- grounded trust region.
75+ DPPO replaces PPO/GRPO's heuristic ratio-clipping with a principled trust region based on direct policy divergence
76+ estimates. PPO-style clipping masks tokens based on probability ratio π/μ, which over-penalizes low-probability
77+ tokens and under-penalizes high-probability tokens. In contrast, DPPO masks based on direct approximation of policy
78+ divergence (e.g., TV or KL), ensuring updates stay within a theoretically grounded trust region.
8079
8180
8281 Four divergence approximations are supported:
@@ -275,8 +274,8 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
275274 """Generate completions, always extracting sampled token logprobs.
276275
277276 Returns:
278- 5-tuple of (prompt_ids, completion_ids, logprobs, topk_logprobs, topk_token_ids).
279- topk_logprobs and topk_token_ids are None when divergence_type is not topk.
277+ 5-tuple of (prompt_ids, completion_ids, logprobs, topk_logprobs, topk_token_ids). topk_logprobs and
278+ topk_token_ids are None when divergence_type is not topk.
280279 """
281280 device = self .accelerator .device
282281 mode = "train" if self .model .training else "eval"
@@ -420,9 +419,9 @@ def _tool_call_loop(
420419 ):
421420 """Tool execution loop that also threads top-K logprob data alongside logprobs.
422421
423- Mirrors GRPOTrainer._tool_call_loop but additionally concatenates topk_logprobs and topk_token_ids
424- the same way logprobs is concatenated: real data for model-generated tokens, zero-padding for
425- tool-result tokens. When topk data is None (binary divergence), behaves identically to the parent.
422+ Mirrors GRPOTrainer._tool_call_loop but additionally concatenates topk_logprobs and topk_token_ids the same way
423+ logprobs is concatenated: real data for model-generated tokens, zero-padding for tool-result tokens. When topk
424+ data is None (binary divergence), behaves identically to the parent.
426425 """
427426 K = self .divergence_topk
428427 has_topk = topk_logprobs is not None
@@ -620,8 +619,8 @@ def _generate(self, prompts: list):
620619 """Generate completions, handling tool calls, and thread top-K logprob data through the full pipeline.
621620
622621 Returns:
623- 9-tuple of (prompt_ids, completion_ids, tool_mask, completions, total_completion_tokens,
624- logprobs, topk_logprobs, topk_token_ids, extra_fields).
622+ 9-tuple of (prompt_ids, completion_ids, tool_mask, completions, total_completion_tokens, logprobs,
623+ topk_logprobs, topk_token_ids, extra_fields).
625624 """
626625 device = self .accelerator .device
627626 mode = "train" if self .model .training else "eval"
@@ -768,8 +767,8 @@ def _get_per_token_logps_with_topk(
768767 ) -> tuple [torch .Tensor , torch .Tensor | None , torch .Tensor ]:
769768 """Compute per-token log-probs, (optionally) entropies, and top-K log-probs in one forward pass.
770769
771- Evaluates the current policy's log-probs at the rollout's top-K token IDs from the same
772- forward pass used for per_token_logps, avoiding an extra model call.
770+ Evaluates the current policy's log-probs at the rollout's top-K token IDs from the same forward pass used for
771+ per_token_logps, avoiding an extra model call.
773772
774773 Args:
775774 topk_token_ids: Rollout policy's top-K token IDs, shape (B, T, K). The current policy's
@@ -1207,11 +1206,11 @@ def _compute_divergence_mask(
12071206 completion_mask (`torch.Tensor`):
12081207 Binary mask of shape `(B, T)` where `1` indicates valid completion tokens and `0` padding.
12091208 current_topk_logps (`torch.Tensor` or `None`):
1210- Log-probabilities of the current policy at the rollout's top-K token IDs, shape `(B, T, K)`.
1211- Required when `divergence_type` is `"topk_tv"` or `"topk_kl"`.
1209+ Log-probabilities of the current policy at the rollout's top-K token IDs, shape `(B, T, K)`. Required
1210+ when `divergence_type` is `"topk_tv"` or `"topk_kl"`.
12121211 sampling_topk_logps (`torch.Tensor` or `None`):
1213- Log-probabilities of the sampling policy at the rollout's top-K token IDs, shape `(B, T, K)`.
1214- Required when `divergence_type` is `"topk_tv"` or `"topk_kl"`.
1212+ Log-probabilities of the sampling policy at the rollout's top-K token IDs, shape `(B, T, K)`. Required
1213+ when `divergence_type` is `"topk_tv"` or `"topk_kl"`.
12151214
12161215 Returns:
12171216 `torch.Tensor`:
0 commit comments