Skip to content

Commit 4fea6d1

Browse files
Fix style in DPPO docstrings (#5326)
1 parent e8dcece commit 4fea6d1

File tree

2 files changed

+23
-24
lines changed

2 files changed

+23
-24
lines changed

trl/experimental/dppo/dppo_config.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ class DPPOConfig(GRPOConfig):
2323
"""
2424
Configuration class for DPPOTrainer.
2525
26-
DPPO (Divergence Proximal Policy Optimization) replaces PPO/GRPO's heuristic ratio-clipping with a principled
27-
trust region based on direct policy divergence estimates.
26+
DPPO (Divergence Proximal Policy Optimization) replaces PPO/GRPO's heuristic ratio-clipping with a principled trust
27+
region based on direct policy divergence estimates.
2828
2929
Paper: "Rethinking the Trust Region in LLM Reinforcement Learning" (arXiv:2602.04879)
3030
@@ -42,13 +42,13 @@ class DPPOConfig(GRPOConfig):
4242
4343
epsilon (`float`, inherited from GRPOConfig, default overridden to `0.15`):
4444
Divergence threshold δ_low. Tokens whose divergence exceeds this when the policy moves in the
45-
advantage-decreasing direction are masked. The paper recommends 0.15 for TV divergence
46-
and 0.05 for KL divergence.
45+
advantage-decreasing direction are masked. The paper recommends 0.15 for TV divergence and 0.05 for KL
46+
divergence.
4747
4848
epsilon_high (`float`, inherited from GRPOConfig, default overridden to `0.15`):
4949
Divergence threshold δ_high. Tokens whose divergence exceeds this when the policy moves in the
50-
advantage-increasing direction are masked. The paper recommends 0.15 for TV divergence
51-
and 0.05 for KL divergence.
50+
advantage-increasing direction are masked. The paper recommends 0.15 for TV divergence and 0.05 for KL
51+
divergence.
5252
"""
5353

5454
divergence_type: Literal["binary_tv", "binary_kl", "topk_tv", "topk_kl"] = field(

trl/experimental/dppo/dppo_trainer.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,10 @@ class DPPOTrainer(GRPOTrainer):
7272
"""
7373
Trainer for Divergence Proximal Policy Optimization (DPPO).
7474
75-
DPPO replaces PPO/GRPO's heuristic ratio-clipping with a principled trust region based on direct policy
76-
divergence estimates. PPO-style clipping masks tokens based on probability ratio π/μ, which over-penalizes
77-
low-probability tokens and under-penalizes high-probability tokens. In contrast, DPPO masks based on
78-
direct approximation of policy divergence (e.g TV or KL) ensuring updates stay within a theoretically
79-
grounded trust region.
75+
DPPO replaces PPO/GRPO's heuristic ratio-clipping with a principled trust region based on direct policy divergence
76+
estimates. PPO-style clipping masks tokens based on probability ratio π/μ, which over-penalizes low-probability
77+
tokens and under-penalizes high-probability tokens. In contrast, DPPO masks based on direct approximation of policy
78+
divergence (e.g., TV or KL), ensuring updates stay within a theoretically grounded trust region.
8079
8180
8281
Four divergence approximations are supported:
@@ -275,8 +274,8 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
275274
"""Generate completions, always extracting sampled token logprobs.
276275
277276
Returns:
278-
5-tuple of (prompt_ids, completion_ids, logprobs, topk_logprobs, topk_token_ids).
279-
topk_logprobs and topk_token_ids are None when divergence_type is not topk.
277+
5-tuple of (prompt_ids, completion_ids, logprobs, topk_logprobs, topk_token_ids). topk_logprobs and
278+
topk_token_ids are None when divergence_type is not topk.
280279
"""
281280
device = self.accelerator.device
282281
mode = "train" if self.model.training else "eval"
@@ -420,9 +419,9 @@ def _tool_call_loop(
420419
):
421420
"""Tool execution loop that also threads top-K logprob data alongside logprobs.
422421
423-
Mirrors GRPOTrainer._tool_call_loop but additionally concatenates topk_logprobs and topk_token_ids
424-
the same way logprobs is concatenated: real data for model-generated tokens, zero-padding for
425-
tool-result tokens. When topk data is None (binary divergence), behaves identically to the parent.
422+
Mirrors GRPOTrainer._tool_call_loop but additionally concatenates topk_logprobs and topk_token_ids the same way
423+
logprobs is concatenated: real data for model-generated tokens, zero-padding for tool-result tokens. When topk
424+
data is None (binary divergence), behaves identically to the parent.
426425
"""
427426
K = self.divergence_topk
428427
has_topk = topk_logprobs is not None
@@ -620,8 +619,8 @@ def _generate(self, prompts: list):
620619
"""Generate completions, handling tool calls, and thread top-K logprob data through the full pipeline.
621620
622621
Returns:
623-
9-tuple of (prompt_ids, completion_ids, tool_mask, completions, total_completion_tokens,
624-
logprobs, topk_logprobs, topk_token_ids, extra_fields).
622+
9-tuple of (prompt_ids, completion_ids, tool_mask, completions, total_completion_tokens, logprobs,
623+
topk_logprobs, topk_token_ids, extra_fields).
625624
"""
626625
device = self.accelerator.device
627626
mode = "train" if self.model.training else "eval"
@@ -768,8 +767,8 @@ def _get_per_token_logps_with_topk(
768767
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
769768
"""Compute per-token log-probs, (optionally) entropies, and top-K log-probs in one forward pass.
770769
771-
Evaluates the current policy's log-probs at the rollout's top-K token IDs from the same
772-
forward pass used for per_token_logps, avoiding an extra model call.
770+
Evaluates the current policy's log-probs at the rollout's top-K token IDs from the same forward pass used for
771+
per_token_logps, avoiding an extra model call.
773772
774773
Args:
775774
topk_token_ids: Rollout policy's top-K token IDs, shape (B, T, K). The current policy's
@@ -1207,11 +1206,11 @@ def _compute_divergence_mask(
12071206
completion_mask (`torch.Tensor`):
12081207
Binary mask of shape `(B, T)` where `1` indicates valid completion tokens and `0` padding.
12091208
current_topk_logps (`torch.Tensor` or `None`):
1210-
Log-probabilities of the current policy at the rollout's top-K token IDs, shape `(B, T, K)`.
1211-
Required when `divergence_type` is `"topk_tv"` or `"topk_kl"`.
1209+
Log-probabilities of the current policy at the rollout's top-K token IDs, shape `(B, T, K)`. Required
1210+
when `divergence_type` is `"topk_tv"` or `"topk_kl"`.
12121211
sampling_topk_logps (`torch.Tensor` or `None`):
1213-
Log-probabilities of the sampling policy at the rollout's top-K token IDs, shape `(B, T, K)`.
1214-
Required when `divergence_type` is `"topk_tv"` or `"topk_kl"`.
1212+
Log-probabilities of the sampling policy at the rollout's top-K token IDs, shape `(B, T, K)`. Required
1213+
when `divergence_type` is `"topk_tv"` or `"topk_kl"`.
12151214
12161215
Returns:
12171216
`torch.Tensor`:

0 commit comments

Comments
 (0)