-
Notifications
You must be signed in to change notification settings - Fork 2.5k
feat(experimental): Divergence Proximal Policy Optimization #5117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 14 commits
8f5aabd
11b59ad
28c2ab0
0dc628e
8078913
b7d0029
9a93fa5
de2826d
1b9e451
7ea365c
45ef405
1149358
f1f953b
5ad2ab6
2830148
fca394c
8d2cb24
8b774be
899032a
6fd5906
3f3a0ec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,188 @@ | ||
| # Copyright 2020-2026 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import pytest | ||
| import torch | ||
| from datasets import load_dataset | ||
|
|
||
| from trl.experimental.dppo import DPPOConfig, DPPOTrainer | ||
|
|
||
| from ..testing_utils import TrlTestCase | ||
|
|
||
|
|
||
| class TestDPPODivergenceMask: | ||
| """Unit tests for _compute_divergence_mask with synthetic inputs.""" | ||
|
|
||
| def _make_trainer(self, divergence_type="binary_tv", epsilon=0.2, epsilon_high=0.28): | ||
| """Create a minimal DPPOTrainer-like object with just the attributes needed for _compute_divergence_mask.""" | ||
|
|
||
| class Stub: | ||
| pass | ||
|
|
||
| stub = Stub() | ||
| stub.divergence_type = divergence_type | ||
| stub.epsilon_low = epsilon | ||
| stub.epsilon_high = epsilon_high | ||
| return stub | ||
|
||
|
|
||
| def test_binary_tv_no_masking_within_threshold(self): | ||
| stub = self._make_trainer("binary_tv", epsilon=0.2, epsilon_high=0.28) | ||
| # Policies are very close — no tokens should be masked | ||
| sampling_logps = torch.log(torch.tensor([[0.5, 0.3, 0.7]])) | ||
| current_logps = torch.log(torch.tensor([[0.51, 0.29, 0.71]])) | ||
| advantages = torch.tensor([[1.0]]) | ||
| completion_mask = torch.ones(1, 3) | ||
|
|
||
| mask = DPPOTrainer._compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask) | ||
| assert mask.shape == (1, 3) | ||
| assert (mask == 1.0).all() | ||
|
|
||
| def test_binary_tv_masks_positive_advantage_high_divergence(self): | ||
| stub = self._make_trainer("binary_tv", epsilon=0.01, epsilon_high=0.01) | ||
| # π much higher than μ, positive advantage → should be masked (invalid_pos) | ||
| sampling_logps = torch.log(torch.tensor([[0.1]])) | ||
| current_logps = torch.log(torch.tensor([[0.5]])) | ||
| advantages = torch.tensor([[1.0]]) | ||
| completion_mask = torch.ones(1, 1) | ||
|
|
||
| mask = DPPOTrainer._compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask) | ||
| assert mask.item() == 0.0 | ||
|
|
||
| def test_binary_tv_masks_negative_advantage_low_divergence(self): | ||
| stub = self._make_trainer("binary_tv", epsilon=0.01, epsilon_high=0.01) | ||
| # π much lower than μ, negative advantage → should be masked (invalid_neg) | ||
| sampling_logps = torch.log(torch.tensor([[0.5]])) | ||
| current_logps = torch.log(torch.tensor([[0.1]])) | ||
| advantages = torch.tensor([[-1.0]]) | ||
| completion_mask = torch.ones(1, 1) | ||
|
|
||
| mask = DPPOTrainer._compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask) | ||
| assert mask.item() == 0.0 | ||
|
|
||
| def test_binary_tv_respects_completion_mask(self): | ||
| stub = self._make_trainer("binary_tv", epsilon=0.01, epsilon_high=0.01) | ||
| # Even though divergence is huge, padding tokens stay 0 | ||
| sampling_logps = torch.log(torch.tensor([[0.1, 0.5]])) | ||
| current_logps = torch.log(torch.tensor([[0.9, 0.9]])) | ||
| advantages = torch.tensor([[1.0]]) | ||
| completion_mask = torch.tensor([[1.0, 0.0]]) | ||
|
|
||
| mask = DPPOTrainer._compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask) | ||
| assert mask[0, 1].item() == 0.0 | ||
|
|
||
| def test_topk_tv_requires_topk_inputs(self): | ||
| stub = self._make_trainer("topk_tv") | ||
| B, T, K = 1, 2, 4 | ||
| sampling_logps = torch.log(torch.full((B, T), 0.3)) | ||
| current_logps = torch.log(torch.full((B, T), 0.31)) | ||
| advantages = torch.tensor([[1.0]]) | ||
| completion_mask = torch.ones(B, T) | ||
|
|
||
| # Build top-K distributions that are nearly identical | ||
| topk_probs = torch.softmax(torch.randn(B, T, K), dim=-1) | ||
| sampling_topk_logps = torch.log(topk_probs) | ||
| current_topk_logps = torch.log(topk_probs + 0.001) | ||
|
|
||
| mask = DPPOTrainer._compute_divergence_mask( | ||
| stub, | ||
| current_logps, | ||
| sampling_logps, | ||
| advantages, | ||
| completion_mask, | ||
| current_topk_logps=current_topk_logps, | ||
| sampling_topk_logps=sampling_topk_logps, | ||
| ) | ||
| assert mask.shape == (B, T) | ||
| assert (mask == 1.0).all() | ||
|
|
||
|
|
||
@pytest.mark.low_priority
class TestDPPOTrainer(TrlTestCase):
    """End-to-end smoke tests: a few optimizer steps on tiny models must update parameters."""

    @pytest.mark.parametrize("divergence_type", ["binary_tv", "binary_kl"])
    def test_training_binary(self, divergence_type):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        training_args = DPPOConfig(
            output_dir=self.tmp_dir,
            learning_rate=0.1,
            per_device_train_batch_size=3,
            num_generations=3,
            max_completion_length=8,
            divergence_type=divergence_type,
            report_to="none",
        )
        trainer = DPPOTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
            args=training_args,
            train_dataset=dataset,
        )

        # Snapshot the weights so we can verify that training actually changed them.
        params_before = {name: p.clone() for name, p in trainer.model.named_parameters()}

        trainer.train()

        assert trainer.state.log_history[-1]["train_loss"] is not None

        for name, old_param in params_before.items():
            new_param = trainer.model.get_parameter(name)
            assert not torch.equal(old_param, new_param), f"Parameter {name} has not changed."

    def test_training_with_custom_reward_func(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        def dummy_reward(completions, **kwargs):
            # Reward longer completions; only used to drive the optimizer.
            return [float(len(completion)) for completion in completions]

        training_args = DPPOConfig(
            output_dir=self.tmp_dir,
            learning_rate=0.1,
            per_device_train_batch_size=3,
            num_generations=3,
            max_completion_length=8,
            report_to="none",
        )
        trainer = DPPOTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            reward_funcs=dummy_reward,
            args=training_args,
            train_dataset=dataset,
        )

        trainer.train()

        assert trainer.state.log_history[-1]["train_loss"] is not None

    @pytest.mark.parametrize("config_name", ["standard_prompt_only", "conversational_prompt_only"])
    def test_training_conversational(self, config_name):
        dataset = load_dataset("trl-internal-testing/zen", config_name, split="train")

        training_args = DPPOConfig(
            output_dir=self.tmp_dir,
            learning_rate=0.1,
            per_device_train_batch_size=3,
            num_generations=3,
            max_completion_length=8,
            report_to="none",
        )
        trainer = DPPOTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
            args=training_args,
            train_dataset=dataset,
        )

        trainer.train()

        assert trainer.state.log_history[-1]["train_loss"] is not None
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| # Copyright 2020-2026 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
|
|
||
| from .dppo_config import DPPOConfig | ||
| from .dppo_trainer import DPPOTrainer |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| # Copyright 2020-2026 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from dataclasses import dataclass, field | ||
| from typing import Literal | ||
|
|
||
| from ...trainer.grpo_config import GRPOConfig | ||
|
|
||
|
|
||
@dataclass
class DPPOConfig(GRPOConfig):
    """
    Configuration class for DPPOTrainer.

    DPPO (Divergence Proximal Policy Optimization) replaces PPO/GRPO's heuristic ratio-clipping with a principled
    trust region based on direct policy divergence estimates.

    Paper: "Rethinking the Trust Region in LLM Reinforcement Learning" (arXiv:2602.04879)

    Args:
        divergence_type (`Literal["binary_tv", "binary_kl", "topk_tv", "topk_kl"]`, *optional*, defaults to `"binary_tv"`):
            Divergence approximation used for the trust-region mask. Binary variants use only per-token log-probs;
            top-K variants require storing top-K token IDs and log-probs during rollout generation plus full logits
            during training.
        divergence_topk (`int`, *optional*, defaults to `20`):
            K for top-K divergence approximations. Only used when `divergence_type` starts with `"topk_"`.
        clip_ratio_c (`float`, *optional*, defaults to `10.0`):
            Upper bound on the importance-sampling ratio for stability. The IS ratio is clamped to [0, clip_ratio_c].
        epsilon (`float`, inherited from GRPOConfig, default overridden to `0.2`):
            Divergence threshold δ_low. Tokens whose divergence exceeds this when the policy moves in the
            advantage-decreasing direction are masked.
        epsilon_high (`float`, inherited from GRPOConfig, default overridden to `0.28`):
            Divergence threshold δ_high. Tokens whose divergence exceeds this when the policy moves in the
            advantage-increasing direction are masked. The paper recommends asymmetric thresholds.

    Raises:
        ValueError: in `__post_init__` when any option is out of range or combines with an inherited GRPO
            setting (`loss_type`, `top_entropy_quantile`, `off_policy_mask_threshold`,
            `use_transformers_paged`) that DPPO does not support.
    """

    divergence_type: Literal["binary_tv", "binary_kl", "topk_tv", "topk_kl"] = field(
        default="binary_tv",
        metadata={
            "help": "Divergence approximation for the trust-region mask. 'binary_tv': absolute probability "
            "difference. 'binary_kl': Bernoulli KL divergence. 'topk_tv': TV over top-K tokens. "
            "'topk_kl': KL over top-K tokens."
        },
    )
    divergence_topk: int = field(
        default=20,
        metadata={
            "help": "K for top-K divergence approximations. Only used when divergence_type starts with 'topk_'."
        },
    )
    clip_ratio_c: float = field(
        default=10.0,
        metadata={"help": "Upper bound on the importance-sampling ratio for stability."},
    )
    epsilon: float = field(
        default=0.2,
        metadata={"help": "Divergence threshold δ_low for the trust-region mask."},
    )
    epsilon_high: float = field(
        default=0.28,
        metadata={"help": "Divergence threshold δ_high for the trust-region mask (asymmetric)."},
    )

    def __post_init__(self):
        # Let GRPOConfig resolve its own defaults first, then validate DPPO-specific constraints.
        super().__post_init__()

        if self.divergence_type not in ("binary_tv", "binary_kl", "topk_tv", "topk_kl"):
            raise ValueError(
                f"divergence_type must be one of 'binary_tv', 'binary_kl', 'topk_tv', 'topk_kl', "
                f"got {self.divergence_type!r}"
            )

        if self.divergence_topk < 1:
            raise ValueError(f"divergence_topk must be >= 1, got {self.divergence_topk}")

        if self.clip_ratio_c <= 0:
            raise ValueError(f"clip_ratio_c must be > 0, got {self.clip_ratio_c}")

        # Fix: the original message lacked the f-prefix, so "{self.loss_type}" was
        # emitted literally instead of interpolating the offending value.
        if self.loss_type != "dapo":
            raise ValueError(f"loss_type {self.loss_type} is not supported for DPPO")

        if self.top_entropy_quantile != 1.0:
            raise ValueError("top_entropy_quantile is not supported for DPPO")

        if self.off_policy_mask_threshold is not None:
            raise ValueError("off_policy_mask_threshold is not supported for DPPO")

        if self.use_transformers_paged:
            raise ValueError(
                "DPPO requires sampled token logprobs from the generation backend. "
                "Transformers paged (`use_transformers_paged=True`) does not support logprob extraction."
            )
Uh oh!
There was an error while loading. Please reload this page.