generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 2.5k
feat(experimental): Divergence Proximal Policy Optimization #5117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
LeonEricsson
wants to merge
21
commits into
huggingface:main
Choose a base branch
from
LeonEricsson:feature/dppo
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
8f5aabd
init dppo
LeonEricsson 11b59ad
cleanup dead paths
LeonEricsson 28c2ab0
exp dppo train script
LeonEricsson 0dc628e
topk logprobs without recompute
LeonEricsson 8078913
always return sampling token logps + transformer_paged not allowed
LeonEricsson b7d0029
tests + paper index + cleanup
LeonEricsson 9a93fa5
Merge branch 'main' into feature/dppo
LeonEricsson de2826d
Update trl/experimental/dppo/dppo_config.py
LeonEricsson 1b9e451
Update trl/experimental/dppo/dppo_trainer.py
LeonEricsson 7ea365c
Ensure sampled token is in top-K set for topk divergence
LeonEricsson 45ef405
Merge remote-tracking branch 'origin/main' into feature/dppo
LeonEricsson 1149358
Merge branch 'main' into feature/dppo
LeonEricsson f1f953b
Merge branch 'main' into feature/dppo
LeonEricsson 5ad2ab6
Merge branch 'main' into feature/dppo
LeonEricsson 2830148
Update trl/experimental/dppo/dppo_trainer.py
LeonEricsson fca394c
Update trl/experimental/dppo/dppo_trainer.py
LeonEricsson 8d2cb24
Update trl/experimental/dppo/dppo_trainer.py
LeonEricsson 8b774be
Update trl/experimental/dppo/dppo_trainer.py
LeonEricsson 899032a
fix: tests and missing init params
LeonEricsson 6fd5906
better defaults
LeonEricsson 3f3a0ec
clarify kl/tv defaults + add mm_token_type_ids throughout
LeonEricsson File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,184 @@ | ||
| # Copyright 2020-2026 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import pytest | ||
| import torch | ||
| from datasets import load_dataset | ||
|
|
||
| from trl.experimental.dppo import DPPOConfig, DPPOTrainer | ||
|
|
||
| from ..testing_utils import TrlTestCase | ||
|
|
||
|
|
||
| class TestDPPODivergenceMask: | ||
| """Unit tests for _compute_divergence_mask with synthetic inputs.""" | ||
|
|
||
| @staticmethod | ||
| def make_trainer(divergence_type="binary_tv", epsilon=0.2, epsilon_high=0.28): | ||
| """Create a minimal DPPOTrainer-like object with just the attributes needed for _compute_divergence_mask.""" | ||
|
|
||
| class Stub: | ||
| pass | ||
|
|
||
| stub = Stub() | ||
| stub.divergence_type = divergence_type | ||
| stub.epsilon_low = epsilon | ||
| stub.epsilon_high = epsilon_high | ||
| return stub | ||
|
|
||
| @staticmethod | ||
| def compute_divergence_mask( | ||
| trainer_stub, | ||
| current_logps, | ||
| sampling_logps, | ||
| advantages, | ||
| completion_mask, | ||
| current_topk_logps=None, | ||
| sampling_topk_logps=None, | ||
| ): | ||
| return DPPOTrainer._compute_divergence_mask( | ||
| trainer_stub, | ||
| current_logps, | ||
| sampling_logps, | ||
| advantages, | ||
| completion_mask, | ||
| current_topk_logps=current_topk_logps, | ||
| sampling_topk_logps=sampling_topk_logps, | ||
| ) | ||
|
|
||
| def test_binary_tv_no_masking_within_threshold(self): | ||
| stub = self.make_trainer("binary_tv", epsilon=0.2, epsilon_high=0.28) | ||
| # Policies are very close — no tokens should be masked | ||
| sampling_logps = torch.log(torch.tensor([[0.5, 0.3, 0.7]])) | ||
| current_logps = torch.log(torch.tensor([[0.51, 0.29, 0.71]])) | ||
| advantages = torch.tensor([[1.0]]) | ||
| completion_mask = torch.ones(1, 3) | ||
|
|
||
| mask = self.compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask) | ||
| assert mask.shape == (1, 3) | ||
| assert (mask == 1.0).all() | ||
|
|
||
| def test_binary_tv_masks_positive_advantage_high_divergence(self): | ||
| stub = self.make_trainer("binary_tv", epsilon=0.01, epsilon_high=0.01) | ||
| # π much higher than μ, positive advantage → should be masked (invalid_pos) | ||
| sampling_logps = torch.log(torch.tensor([[0.1]])) | ||
| current_logps = torch.log(torch.tensor([[0.5]])) | ||
| advantages = torch.tensor([[1.0]]) | ||
| completion_mask = torch.ones(1, 1) | ||
|
|
||
| mask = self.compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask) | ||
| assert mask.item() == 0.0 | ||
|
|
||
| def test_binary_tv_masks_negative_advantage_low_divergence(self): | ||
| stub = self.make_trainer("binary_tv", epsilon=0.01, epsilon_high=0.01) | ||
| # π much lower than μ, negative advantage → should be masked (invalid_neg) | ||
| sampling_logps = torch.log(torch.tensor([[0.5]])) | ||
| current_logps = torch.log(torch.tensor([[0.1]])) | ||
| advantages = torch.tensor([[-1.0]]) | ||
| completion_mask = torch.ones(1, 1) | ||
|
|
||
| mask = self.compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask) | ||
| assert mask.item() == 0.0 | ||
|
|
||
| def test_binary_tv_respects_completion_mask(self): | ||
| stub = self.make_trainer("binary_tv", epsilon=0.01, epsilon_high=0.01) | ||
| # Even though divergence is huge, padding tokens stay 0 | ||
| sampling_logps = torch.log(torch.tensor([[0.1, 0.5]])) | ||
| current_logps = torch.log(torch.tensor([[0.9, 0.9]])) | ||
| advantages = torch.tensor([[1.0]]) | ||
| completion_mask = torch.tensor([[1.0, 0.0]]) | ||
|
|
||
| mask = self.compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask) | ||
| assert mask[0, 1].item() == 0.0 | ||
|
|
||
| def test_topk_tv_requires_topk_inputs(self): | ||
| stub = self.make_trainer("topk_tv") | ||
| B, T, K = 1, 2, 4 | ||
| sampling_logps = torch.log(torch.full((B, T), 0.3)) | ||
| current_logps = torch.log(torch.full((B, T), 0.31)) | ||
| advantages = torch.tensor([[1.0]]) | ||
| completion_mask = torch.ones(B, T) | ||
|
|
||
| # Build top-K distributions that are nearly identical | ||
| topk_probs = torch.softmax(torch.randn(B, T, K), dim=-1) | ||
| sampling_topk_logps = torch.log(topk_probs) | ||
| current_topk_logps = torch.log(topk_probs + 0.001) | ||
|
|
||
| mask = self.compute_divergence_mask( | ||
| stub, | ||
| current_logps, | ||
| sampling_logps, | ||
| advantages, | ||
| completion_mask, | ||
| current_topk_logps=current_topk_logps, | ||
| sampling_topk_logps=sampling_topk_logps, | ||
| ) | ||
| assert mask.shape == (B, T) | ||
| assert (mask == 1.0).all() | ||
|
|
||
|
|
||
| @pytest.mark.low_priority | ||
| class TestDPPOTrainer(TrlTestCase): | ||
| @pytest.mark.parametrize("divergence_type", ["binary_tv", "binary_kl"]) | ||
| def test_training_binary(self, divergence_type): | ||
| dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") | ||
|
|
||
| training_args = DPPOConfig( | ||
| output_dir=self.tmp_dir, | ||
| learning_rate=0.1, | ||
| per_device_train_batch_size=3, | ||
| num_generations=3, | ||
| max_completion_length=8, | ||
| divergence_type=divergence_type, | ||
| report_to="none", | ||
| ) | ||
| trainer = DPPOTrainer( | ||
| model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", | ||
| reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", | ||
| args=training_args, | ||
| train_dataset=dataset, | ||
| ) | ||
|
|
||
| previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} | ||
|
|
||
| trainer.train() | ||
|
|
||
| assert trainer.state.log_history[-1]["train_loss"] is not None | ||
|
|
||
| for n, param in previous_trainable_params.items(): | ||
| new_param = trainer.model.get_parameter(n) | ||
| assert not torch.equal(param, new_param), f"Parameter {n} has not changed." | ||
|
|
||
| @pytest.mark.parametrize("config_name", ["standard_prompt_only", "conversational_prompt_only"]) | ||
| def test_training_conversational(self, config_name): | ||
| dataset = load_dataset("trl-internal-testing/zen", config_name, split="train") | ||
|
|
||
| training_args = DPPOConfig( | ||
| output_dir=self.tmp_dir, | ||
| learning_rate=0.1, | ||
| per_device_train_batch_size=3, | ||
| num_generations=3, | ||
| max_completion_length=8, | ||
| report_to="none", | ||
| ) | ||
| trainer = DPPOTrainer( | ||
| model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", | ||
| reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", | ||
| args=training_args, | ||
| train_dataset=dataset, | ||
| ) | ||
|
|
||
| trainer.train() | ||
|
|
||
| assert trainer.state.log_history[-1]["train_loss"] is not None |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| # Copyright 2020-2026 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
|
|
||
| from .dppo_config import DPPOConfig | ||
| from .dppo_trainer import DPPOTrainer |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,121 @@ | ||
| # Copyright 2020-2026 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from dataclasses import dataclass, field | ||
| from typing import Literal | ||
|
|
||
| from ...trainer.grpo_config import GRPOConfig | ||
|
|
||
|
|
||
| @dataclass | ||
| class DPPOConfig(GRPOConfig): | ||
| """ | ||
| Configuration class for DPPOTrainer. | ||
|
|
||
| DPPO (Divergence Proximal Policy Optimization) replaces PPO/GRPO's heuristic ratio-clipping with a principled | ||
| trust region based on direct policy divergence estimates. | ||
|
|
||
| Paper: "Rethinking the Trust Region in LLM Reinforcement Learning" (arXiv:2602.04879) | ||
|
|
||
| Args: | ||
| divergence_type (`Literal["binary_tv", "binary_kl", "topk_tv", "topk_kl"]`, *optional*, defaults to `"binary_tv"`): | ||
| Divergence approximation used for the trust-region mask. Binary variants use only per-token log-probs; | ||
| top-K variants require storing top-K token IDs and log-probs during rollout generation plus full logits | ||
| during training. | ||
|
|
||
| divergence_topk (`int`, *optional*, defaults to `20`): | ||
| K for top-K divergence approximations. Only used when `divergence_type` is `"topk_tv"` or `"topk_kl"`. | ||
|
|
||
| clip_ratio_c (`float`, *optional*, defaults to `20.0`): | ||
| Upper bound on the importance-sampling ratio for stability. The IS ratio is clamped to [0, clip_ratio_c]. | ||
|
|
||
| epsilon (`float`, inherited from GRPOConfig, default overridden to `0.15`): | ||
| Divergence threshold δ_low. Tokens whose divergence exceeds this when the policy moves in the | ||
| advantage-decreasing direction are masked. The paper recommends 0.15 for TV divergence | ||
| and 0.05 for KL divergence. | ||
|
|
||
| epsilon_high (`float`, inherited from GRPOConfig, default overridden to `0.15`): | ||
| Divergence threshold δ_high. Tokens whose divergence exceeds this when the policy moves in the | ||
| advantage-increasing direction are masked. The paper recommends 0.15 for TV divergence | ||
| and 0.05 for KL divergence. | ||
| """ | ||
|
|
||
| divergence_type: Literal["binary_tv", "binary_kl", "topk_tv", "topk_kl"] = field( | ||
| default="binary_tv", | ||
| metadata={ | ||
| "help": "Divergence approximation used for the trust-region mask. Binary variants use only per-token " | ||
| "log-probs; top-K variants require storing top-K token IDs and log-probs during rollout generation plus " | ||
| "full logits during training." | ||
| }, | ||
| ) | ||
| divergence_topk: int = field( | ||
| default=20, | ||
| metadata={ | ||
| "help": "K for top-K divergence approximations. Only used when `divergence_type` is `'topk_tv'` or " | ||
| "`'topk_kl'`." | ||
| }, | ||
| ) | ||
| clip_ratio_c: float = field( | ||
| default=20.0, | ||
| metadata={ | ||
| "help": "Upper bound on the importance-sampling ratio for stability. The IS ratio is clamped to " | ||
| "[0, clip_ratio_c]." | ||
| }, | ||
| ) | ||
| epsilon: float = field( | ||
| default=0.15, | ||
| metadata={ | ||
| "help": "Divergence threshold δ_low. Tokens whose divergence exceeds this when the policy moves in the " | ||
| "advantage-decreasing direction are masked. The paper recommends 0.15 for TV divergence and 0.05 for KL " | ||
| "divergence." | ||
| }, | ||
| ) | ||
| epsilon_high: float = field( | ||
| default=0.15, | ||
| metadata={ | ||
| "help": "Divergence threshold δ_high. Tokens whose divergence exceeds this when the policy moves in the " | ||
| "advantage-increasing direction are masked. The paper recommends 0.15 for TV divergence and 0.05 for KL " | ||
| "divergence." | ||
| }, | ||
| ) | ||
|
|
||
| def __post_init__(self): | ||
| super().__post_init__() | ||
|
|
||
| if self.divergence_type not in ("binary_tv", "binary_kl", "topk_tv", "topk_kl"): | ||
| raise ValueError( | ||
| f"divergence_type must be one of 'binary_tv', 'binary_kl', 'topk_tv', 'topk_kl', " | ||
| f"got {self.divergence_type!r}" | ||
| ) | ||
|
|
||
| if self.divergence_topk < 1: | ||
| raise ValueError(f"divergence_topk must be >= 1, got {self.divergence_topk}") | ||
|
|
||
| if self.clip_ratio_c <= 0: | ||
| raise ValueError(f"clip_ratio_c must be > 0, got {self.clip_ratio_c}") | ||
|
|
||
| if self.loss_type != "dapo": | ||
| raise ValueError("loss_type {self.loss_type} is not supported for DPPO") | ||
|
|
||
| if self.top_entropy_quantile != 1.0: | ||
| raise ValueError("top_entropy_quantile is not supported for DPPO") | ||
|
|
||
| if self.off_policy_mask_threshold is not None: | ||
| raise ValueError("off_policy_mask_threshold is not supported for DPPO") | ||
|
|
||
| if self.use_transformers_paged: | ||
| raise ValueError( | ||
| "DPPO requires sampled token logprobs from the generation backend. " | ||
| "Transformers paged (`use_transformers_paged=True`) does not support logprob extraction." | ||
| ) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.