Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
8f5aabd
init dppo
LeonEricsson Feb 13, 2026
11b59ad
cleanup dead paths
LeonEricsson Feb 13, 2026
28c2ab0
exp dppo train script
LeonEricsson Feb 15, 2026
0dc628e
topk logprobs without recompute
LeonEricsson Feb 17, 2026
8078913
always return sampling token logps + transformer_paged not allowed
LeonEricsson Feb 17, 2026
b7d0029
tests + paper index + cleanup
LeonEricsson Feb 17, 2026
9a93fa5
Merge branch 'main' into feature/dppo
LeonEricsson Feb 17, 2026
de2826d
Update trl/experimental/dppo/dppo_config.py
LeonEricsson Feb 18, 2026
1b9e451
Update trl/experimental/dppo/dppo_trainer.py
LeonEricsson Feb 22, 2026
7ea365c
Ensure sampled token is in top-K set for topk divergence
LeonEricsson Feb 22, 2026
45ef405
Merge remote-tracking branch 'origin/main' into feature/dppo
LeonEricsson Feb 25, 2026
1149358
Merge branch 'main' into feature/dppo
LeonEricsson Feb 25, 2026
f1f953b
Merge branch 'main' into feature/dppo
LeonEricsson Feb 25, 2026
5ad2ab6
Merge branch 'main' into feature/dppo
LeonEricsson Feb 26, 2026
2830148
Update trl/experimental/dppo/dppo_trainer.py
LeonEricsson Feb 27, 2026
fca394c
Update trl/experimental/dppo/dppo_trainer.py
LeonEricsson Feb 27, 2026
8d2cb24
Update trl/experimental/dppo/dppo_trainer.py
LeonEricsson Feb 27, 2026
8b774be
Update trl/experimental/dppo/dppo_trainer.py
LeonEricsson Feb 27, 2026
899032a
fix: tests and missing init params
LeonEricsson Feb 27, 2026
6fd5906
better defaults
LeonEricsson Mar 1, 2026
3f3a0ec
clarify kl/tv defaults + add mm_token_type_ids throughout
LeonEricsson Mar 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions docs/source/paper_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,37 @@ training_args = GRPOConfig(
)
```


### Rethinking the Trust Region in LLM Reinforcement Learning

**📜 Paper**: https://huggingface.co/papers/2602.04879

DPPO replaces PPO/GRPO's heuristic ratio-clipping with a principled trust region based on direct policy divergence estimates. PPO-style clipping masks tokens based on the probability ratio π/μ, which over-penalizes low-probability tokens and under-penalizes high-probability ones. DPPO instead masks based on direct approximations of policy divergence (TV or KL), ensuring updates stay within a theoretically grounded trust region. Four divergence approximations are supported: `binary_tv`, `binary_kl`, `topk_tv`, and `topk_kl`.

```python
from trl.experimental.dppo import DPPOConfig, DPPOTrainer

training_args = DPPOConfig(
divergence_type="binary_kl", # divergence approximation (Section 3.2 of the paper)
divergence_topk=20, # K for top-K divergence modes (Section 3.2 of the paper)
epsilon=0.2, # δ_low threshold (Section 3.2 of the paper)
epsilon_high=0.28, # δ_high threshold (Section 3.2 of the paper)
clip_ratio_c=3.0, # IS ratio upper bound C (Section 3.2 of the paper)
beta=0.0, # KL regularization coefficient
use_vllm=True,
)

trainer = DPPOTrainer(
model="your-model",
reward_funcs=[...],
args=training_args,
train_dataset=dataset,
)
trainer.train()
```

The official code [sail-sg/Stable-RL](https://github.com/sail-sg/Stable-RL)

## Direct Policy Optimization

Papers relating to the [`DPOTrainer`]
Expand Down
188 changes: 188 additions & 0 deletions tests/experimental/test_dppo_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from datasets import load_dataset

from trl.experimental.dppo import DPPOConfig, DPPOTrainer

from ..testing_utils import TrlTestCase


class TestDPPODivergenceMask:
"""Unit tests for _compute_divergence_mask with synthetic inputs."""

def _make_trainer(self, divergence_type="binary_tv", epsilon=0.2, epsilon_high=0.28):
"""Create a minimal DPPOTrainer-like object with just the attributes needed for _compute_divergence_mask."""

class Stub:
pass

stub = Stub()
stub.divergence_type = divergence_type
stub.epsilon_low = epsilon
stub.epsilon_high = epsilon_high
return stub
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefer using a public static method instead

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of DPPOTrainer._compute_divergence_mask

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you have a look at the proposed fix, wasn't quite sure I understood what you meant


def test_binary_tv_no_masking_within_threshold(self):
stub = self._make_trainer("binary_tv", epsilon=0.2, epsilon_high=0.28)
# Policies are very close — no tokens should be masked
sampling_logps = torch.log(torch.tensor([[0.5, 0.3, 0.7]]))
current_logps = torch.log(torch.tensor([[0.51, 0.29, 0.71]]))
advantages = torch.tensor([[1.0]])
completion_mask = torch.ones(1, 3)

mask = DPPOTrainer._compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask)
assert mask.shape == (1, 3)
assert (mask == 1.0).all()

def test_binary_tv_masks_positive_advantage_high_divergence(self):
stub = self._make_trainer("binary_tv", epsilon=0.01, epsilon_high=0.01)
# π much higher than μ, positive advantage → should be masked (invalid_pos)
sampling_logps = torch.log(torch.tensor([[0.1]]))
current_logps = torch.log(torch.tensor([[0.5]]))
advantages = torch.tensor([[1.0]])
completion_mask = torch.ones(1, 1)

mask = DPPOTrainer._compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask)
assert mask.item() == 0.0

def test_binary_tv_masks_negative_advantage_low_divergence(self):
stub = self._make_trainer("binary_tv", epsilon=0.01, epsilon_high=0.01)
# π much lower than μ, negative advantage → should be masked (invalid_neg)
sampling_logps = torch.log(torch.tensor([[0.5]]))
current_logps = torch.log(torch.tensor([[0.1]]))
advantages = torch.tensor([[-1.0]])
completion_mask = torch.ones(1, 1)

mask = DPPOTrainer._compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask)
assert mask.item() == 0.0

def test_binary_tv_respects_completion_mask(self):
stub = self._make_trainer("binary_tv", epsilon=0.01, epsilon_high=0.01)
# Even though divergence is huge, padding tokens stay 0
sampling_logps = torch.log(torch.tensor([[0.1, 0.5]]))
current_logps = torch.log(torch.tensor([[0.9, 0.9]]))
advantages = torch.tensor([[1.0]])
completion_mask = torch.tensor([[1.0, 0.0]])

mask = DPPOTrainer._compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask)
assert mask[0, 1].item() == 0.0

def test_topk_tv_requires_topk_inputs(self):
stub = self._make_trainer("topk_tv")
B, T, K = 1, 2, 4
sampling_logps = torch.log(torch.full((B, T), 0.3))
current_logps = torch.log(torch.full((B, T), 0.31))
advantages = torch.tensor([[1.0]])
completion_mask = torch.ones(B, T)

# Build top-K distributions that are nearly identical
topk_probs = torch.softmax(torch.randn(B, T, K), dim=-1)
sampling_topk_logps = torch.log(topk_probs)
current_topk_logps = torch.log(topk_probs + 0.001)

mask = DPPOTrainer._compute_divergence_mask(
stub,
current_logps,
sampling_logps,
advantages,
completion_mask,
current_topk_logps=current_topk_logps,
sampling_topk_logps=sampling_topk_logps,
)
assert mask.shape == (B, T)
assert (mask == 1.0).all()


@pytest.mark.low_priority
class TestDPPOTrainer(TrlTestCase):
@pytest.mark.parametrize("divergence_type", ["binary_tv", "binary_kl"])
def test_training_binary(self, divergence_type):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

config = DPPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1,
per_device_train_batch_size=3,
num_generations=3,
max_completion_length=8,
divergence_type=divergence_type,
report_to="none",
)
trainer = DPPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=config,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None

for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_training_with_custom_reward_func(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

def dummy_reward(completions, **kwargs):
return [float(len(c)) for c in completions]

config = DPPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1,
per_device_train_batch_size=3,
num_generations=3,
max_completion_length=8,
report_to="none",
)
trainer = DPPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs=dummy_reward,
args=config,
train_dataset=dataset,
)

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None

@pytest.mark.parametrize("config_name", ["standard_prompt_only", "conversational_prompt_only"])
def test_training_conversational(self, config_name):
dataset = load_dataset("trl-internal-testing/zen", config_name, split="train")

config = DPPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1,
per_device_train_batch_size=3,
num_generations=3,
max_completion_length=8,
report_to="none",
)
trainer = DPPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=config,
train_dataset=dataset,
)

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None
17 changes: 17 additions & 0 deletions trl/experimental/dppo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .dppo_config import DPPOConfig
from .dppo_trainer import DPPOTrainer
107 changes: 107 additions & 0 deletions trl/experimental/dppo/dppo_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Literal

from ...trainer.grpo_config import GRPOConfig


@dataclass
class DPPOConfig(GRPOConfig):
"""
Configuration class for DPPOTrainer.

DPPO (Divergence Proximal Policy Optimization) replaces PPO/GRPO's heuristic ratio-clipping with a principled
trust region based on direct policy divergence estimates.

Paper: "Rethinking the Trust Region in LLM Reinforcement Learning" (arXiv:2602.04879)

Args:
divergence_type (`Literal["binary_tv", "binary_kl", "topk_tv", "topk_kl"]`, *optional*, defaults to `"binary_tv"`):
Divergence approximation used for the trust-region mask. Binary variants use only per-token log-probs;
top-K variants require storing top-K token IDs and log-probs during rollout generation plus full logits
during training.

divergence_topk (`int`, *optional*, defaults to `20`):
K for top-K divergence approximations. Only used when `divergence_type` starts with `"topk_"`.

clip_ratio_c (`float`, *optional*, defaults to `10.0`):
Upper bound on the importance-sampling ratio for stability. The IS ratio is clamped to [0, clip_ratio_c].

epsilon (`float`, inherited from GRPOConfig, default overridden to `0.2`):
Divergence threshold δ_low. Tokens whose divergence exceeds this when the policy moves in the
advantage-decreasing direction are masked.

epsilon_high (`float`, inherited from GRPOConfig, default overridden to `0.28`):
Divergence threshold δ_high. Tokens whose divergence exceeds this when the policy moves in the
advantage-increasing direction are masked. The paper recommends asymmetric thresholds.
"""

divergence_type: Literal["binary_tv", "binary_kl", "topk_tv", "topk_kl"] = field(
default="binary_tv",
metadata={
"help": "Divergence approximation for the trust-region mask. 'binary_tv': absolute probability "
"difference. 'binary_kl': Bernoulli KL divergence. 'topk_tv': TV over top-K tokens. "
"'topk_kl': KL over top-K tokens."
},
)
divergence_topk: int = field(
default=20,
metadata={
"help": "K for top-K divergence approximations. Only used when divergence_type starts with 'topk_'."
},
)
clip_ratio_c: float = field(
default=10.0,
metadata={"help": "Upper bound on the importance-sampling ratio for stability."},
)
epsilon: float = field(
default=0.2,
metadata={"help": "Divergence threshold δ_low for the trust-region mask."},
)
epsilon_high: float = field(
default=0.28,
metadata={"help": "Divergence threshold δ_high for the trust-region mask (asymmetric)."},
)

def __post_init__(self):
super().__post_init__()

if self.divergence_type not in ("binary_tv", "binary_kl", "topk_tv", "topk_kl"):
raise ValueError(
f"divergence_type must be one of 'binary_tv', 'binary_kl', 'topk_tv', 'topk_kl', "
f"got {self.divergence_type!r}"
)

if self.divergence_topk < 1:
raise ValueError(f"divergence_topk must be >= 1, got {self.divergence_topk}")

if self.clip_ratio_c <= 0:
raise ValueError(f"clip_ratio_c must be > 0, got {self.clip_ratio_c}")

if self.loss_type != "dapo":
raise ValueError("loss_type {self.loss_type} is not supported for DPPO")

if self.top_entropy_quantile != 1.0:
raise ValueError("top_entropy_quantile is not supported for DPPO")

if self.off_policy_mask_threshold is not None:
raise ValueError("off_policy_mask_threshold is not supported for DPPO")

if self.use_transformers_paged:
raise ValueError(
"DPPO requires sampled token logprobs from the generation backend. "
"Transformers paged (`use_transformers_paged=True`) does not support logprob extraction."
)
Loading
Loading