3 changes: 2 additions & 1 deletion docs/sphinx_doc/source/conf.py
@@ -22,12 +22,13 @@
"sphinx.ext.napoleon",
"sphinx.ext.autosectionlabel",
"myst_parser",
"sphinx.ext.mathjax",
]
source_suffix = {
".rst": "restructuredtext",
".md": "markdown",
}
myst_enable_extensions = ["colon_fence"]
myst_enable_extensions = ["colon_fence", "amsmath", "dollarmath"]

# Prefix document path to section labels, otherwise autogenerated labels would
# look like 'heading' rather than 'path/to/file:heading'
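For context: `sphinx.ext.mathjax` together with MyST's `dollarmath` and `amsmath` extensions is what lets the new tutorial page render `$$ ... $$` display math. The objective it typesets, reproduced here only as far as it is visible in the truncated hunk later in this diff (the bracketed expert term stays elided), has the shape:

```latex
$$
\mathcal{J}_{\text{Mix}}(\theta)
  = (1-\mu)\,\mathcal{J}_{\text{GRPO}}(\theta)
  + \mu \cdot \frac{1}{B'} \sum_{b=1}^{B'} \bigl[\,\cdots\,\bigr]
$$
```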
1 change: 1 addition & 0 deletions docs/sphinx_doc/source/index.rst
@@ -24,6 +24,7 @@ Welcome to Trinity-RFT's documentation!
tutorial/example_data_functionalities.md
tutorial/trinity_configs.md
tutorial/trinity_programming_guide.md
tutorial/example_mix_algo.md

.. toctree::
:maxdepth: 1
14 changes: 5 additions & 9 deletions docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -1,12 +1,12 @@
# Integrate An New Algorithm
# Integrate A New Algorithm


This guide introduces how to integrate a new algorithm into Trinity-RFT.
As an example, we incorporate some "expert" data generated by a more advanced LLM and propose an algorithm named MIX, which optimizes the following policy objective:

$$
\mathcal{J}_{\text{Mix}}(\theta) =
\mathcal{J}_{\text{GRPO}}(\theta)
(1-\mu) \mathcal{J}_{\text{GRPO}}(\theta)
+
\mu \cdot \underbrace{\frac{1}{B'} \sum_{b=1}^{B'}
\left[
@@ -170,6 +170,7 @@ We define a `MixPolicyLoss` class in `trinity/algorithm/policy_loss_fn/mix_polic
class MIXPolicyLossFn(PolicyLossFn):
def __init__(
self,
backend: str = "verl",
mu: float = 0.1,
clip_range: Optional[float] = None,
clip_range_low: Optional[float] = None,
@@ -183,6 +184,7 @@ class MIXPolicyLossFn(PolicyLossFn):
read_batch_size_expert: Optional[int] = None,
use_token_level_loss_in_sft: bool = True,
) -> None:
super().__init__(backend=backend)
self.mu = mu
self.use_dynamic_bsz = use_dynamic_bsz
self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore
@@ -204,11 +206,9 @@ class MIXPolicyLossFn(PolicyLossFn):
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
is_expert_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
is_expert_mask = kwargs.get("is_expert_mask", None)
if is_expert_mask is None:
raise ValueError("is_expert_mask is required in MIX")
assert (
len(is_expert_mask) == logprob.shape[0]
), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
@@ -271,10 +271,6 @@ class MIXPolicyLossFn(PolicyLossFn):
"mu": 0.1,
"clip_range": 0.2,
}

@property
def select_keys(self) -> List[str]:
return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]
```

## Step 4: Run the Experiment
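To make the objective above concrete, here is a minimal, self-contained sketch (not the actual `MIXPolicyLossFn` implementation) of how the two terms combine: non-expert samples receive a PPO/GRPO-style clipped surrogate, expert samples receive a token-level SFT negative log-likelihood, and the results are mixed with weight `mu`. The function name and the omission of the per-GPU bookkeeping (`read_batch_size_*`, gradient accumulation, dynamic batch sizes) are simplifications; `is_expert_mask` is assumed to be a boolean tensor, as in the test file later in this diff.

```python
import torch


def mix_policy_loss(
    logprob: torch.Tensor,         # (batch, seq) log-probs under the current policy
    old_logprob: torch.Tensor,     # (batch, seq) log-probs under the rollout policy
    advantages: torch.Tensor,      # (batch, seq) advantage estimates
    action_mask: torch.Tensor,     # (batch, seq) 1 for response tokens, 0 for padding
    is_expert_mask: torch.Tensor,  # (batch,) True where the sample comes from expert data
    mu: float = 0.1,
    clip_range: float = 0.2,
) -> torch.Tensor:
    usual = ~is_expert_mask

    # GRPO/PPO-style clipped surrogate on the explored (non-expert) samples.
    ratio = torch.exp(logprob[usual] - old_logprob[usual])
    surrogate = torch.minimum(
        ratio * advantages[usual],
        torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range) * advantages[usual],
    )
    mask_u = action_mask[usual].float()
    grpo_loss = -(surrogate * mask_u).sum() / mask_u.sum().clamp(min=1.0)

    # Token-level SFT (negative log-likelihood) on the expert samples.
    mask_e = action_mask[is_expert_mask].float()
    sft_loss = -(logprob[is_expert_mask] * mask_e).sum() / mask_e.sum().clamp(min=1.0)

    # Weighted mix of the two objectives.
    return (1.0 - mu) * grpo_loss + mu * sft_loss
```

With `mu=0` this reduces to the plain clipped surrogate; with `mu=1` it is pure SFT on the expert data.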
2 changes: 1 addition & 1 deletion examples/mix_math/README.md
@@ -4,4 +4,4 @@ This example shows the usage of a new algorithm MIX on the MATH dataset.

For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md).

The config files are located in [`mix_math.yaml`](mix.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml).
The config files are located in [`mix_math.yaml`](mix_math.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml).
3 changes: 1 addition & 2 deletions examples/mix_math/mix_math.yaml
@@ -27,7 +27,6 @@ cluster:
buffer:
total_epochs: 1
batch_size: 40
explore_batch_size: 36
max_retry_times: 3
max_retry_interval: 1
explorer_input:
@@ -82,7 +81,7 @@ synchronizer:
sync_timeout: 1200
trainer:
trainer_type: 'verl'
trainer_config_path: 'examples/mix_math/train_math.yaml'
trainer_config_path: 'examples/mix_math/train_mix_math.yaml'
save_interval: 50
monitor:
monitor_type: wandb
94 changes: 94 additions & 0 deletions tests/algorithm/policy_loss_test.py
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
"""Test for policy loss functions"""

import unittest

import torch
from verl import DataProto

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN


class VerlPolicyLossTest(unittest.TestCase):
    def setUp(self):
        seed = 42
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        shape = (5, 20)
        self.logprob = 2 * torch.rand(shape) - 1
        self.input_data = DataProto.from_dict(
            {
                "old_log_probs": 2 * torch.rand(shape) - 1,
                "ref_log_prob": 2 * torch.rand(shape) - 1,
                "response_mask": torch.rand(shape) > 0.5,
                "advantages": 2 * torch.rand(shape) - 1,
                "is_expert_mask": torch.rand(shape[0]) > 0.5,
            }
        )

    def test_ppo_policy_loss(self):
        policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
        policy_loss_fn_args = policy_loss_fn_cls.default_args()
        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
        ppo_loss = torch.tensor(0.28560468554496765)
        pg_clipfrac = torch.tensor(0.3541666567325592)
        ppo_kl = torch.tensor(-0.21663446724414825)
        self.assertTrue(torch.allclose(loss, ppo_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss))

    def test_sft_policy_loss(self):
        policy_loss_fn_cls = POLICY_LOSS_FN.get("sft")
        policy_loss_fn_args = policy_loss_fn_cls.default_args()
        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
        sft_loss = torch.tensor(-0.07560186833143234)
        self.assertTrue(torch.allclose(loss, sft_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["sft_loss"]), sft_loss))

    def test_dpo_policy_loss(self):
        policy_loss_fn_cls = POLICY_LOSS_FN.get("dpo")
        policy_loss_fn_args = policy_loss_fn_cls.default_args()
        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
        dpo_loss = torch.tensor(0.5406752228736877)
        chosen_reward = torch.tensor(0.7082431316375732)
        rejected_reward = torch.tensor(0.3757950782775879)
        accuracy_mean = torch.tensor(1.0)
        self.assertTrue(torch.allclose(loss, dpo_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["chosen_reward"]), chosen_reward))
        self.assertTrue(torch.allclose(torch.tensor(metrics["rejected_reward"]), rejected_reward))
        self.assertTrue(torch.allclose(torch.tensor(metrics["accuracy_mean"]), accuracy_mean))
        self.assertTrue(torch.allclose(torch.tensor(metrics["dpo_loss"]), dpo_loss))

    def test_opmd_policy_loss(self):
        policy_loss_fn_cls = POLICY_LOSS_FN.get("opmd")
        policy_loss_fn_args = policy_loss_fn_cls.default_args()
        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
        opmd_loss = torch.tensor(-0.009589947760105133)
        self.assertTrue(torch.allclose(loss, opmd_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["opmd_loss"]), opmd_loss))

    def test_mix_policy_loss(self):
        policy_loss_fn_cls = POLICY_LOSS_FN.get("mix")
        policy_loss_fn_args = policy_loss_fn_cls.default_args()
        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
        mix_loss = torch.tensor(0.6581965088844299)
        pg_clipfrac = torch.tensor(0.7777777910232544)
        ppo_kl = torch.tensor(-1.0737695693969727)
        pg_loss = torch.tensor(0.7236452102661133)
        sft_loss = torch.tensor(0.06915830634534359)
        self.assertTrue(torch.allclose(loss, mix_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_clipfrac"]), pg_clipfrac))
        self.assertTrue(torch.allclose(torch.tensor(metrics["usual/ppo_kl"]), ppo_kl))
        self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
1 change: 1 addition & 0 deletions tests/common/config_test.py
@@ -47,6 +47,7 @@ def test_all_examples_are_valid(self):
config_path = os.path.join(example_dir, example_name, filename)
try:
config = load_config(config_path)
config.checkpoint_root_dir = "./.cache/"
config.check_and_update()
except Exception as e:
print(f"Error loading config {config_path}: {e}")
29 changes: 29 additions & 0 deletions trinity/algorithm/key_mapper.py
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
"""Key Mapper"""

from typing import Dict


class KeyMapper:
    def __init__(self, to_trinity_map: Dict[str, str]):
        self.to_trinity_map = to_trinity_map
        self.from_trinity_map = {v: k for k, v in self.to_trinity_map.items()}

    def to_trinity(self, key: str) -> str:
        return self.to_trinity_map.get(key, key)

    def from_trinity(self, key: str) -> str:
        return self.from_trinity_map.get(key, key)


ALL_MAPPERS = {
    "verl": KeyMapper(
        {
            "log_prob": "logprob",
            "old_log_probs": "old_logprob",
            "ref_log_prob": "ref_logprob",
            "response_mask": "action_mask",
            "advantages": "advantages",
        }
    ),
}
11 changes: 3 additions & 8 deletions trinity/algorithm/policy_loss_fn/dpo_loss.py
@@ -1,6 +1,6 @@
"""DPO loss function."""

from typing import Dict, List, Tuple
from typing import Dict, Tuple

import torch
import torch.nn.functional as F
@@ -13,9 +13,11 @@
class DPOLossFn(PolicyLossFn):
def __init__(
self,
backend: str = "verl",
beta: float = 0.1,
label_smoothing: float = 0.0,
) -> None:
super().__init__(backend=backend)
self.beta = beta
self.label_smoothing = label_smoothing

@@ -63,10 +65,3 @@ def default_args(cls) -> Dict:
"beta": 0.1,
"label_smoothing": 0.0,
}

@property
def select_keys(self) -> List[str]:
return [
"ref_logprob",
"action_mask",
]
32 changes: 14 additions & 18 deletions trinity/algorithm/policy_loss_fn/mix_policy_loss.py
@@ -1,6 +1,6 @@
"""Mix policy loss function."""

from typing import Dict, List, Optional, Tuple
from typing import Dict, Optional, Tuple

import torch

@@ -26,27 +26,29 @@ class MIXPolicyLossFn(PolicyLossFn):

def __init__(
self,
backend: str = "verl",
mu: float = 0.1,
clip_range: Optional[float] = None,
clip_range_low: Optional[float] = None,
clip_range_high: Optional[float] = None,
use_dynamic_bsz: Optional[bool] = None,
repeat_times: Optional[int] = None,
ppo_mini_batch_size: Optional[int] = None,
ppo_micro_batch_size_per_gpu: Optional[int] = None,
ngpus_trainer: Optional[int] = None,
read_batch_size_usual: Optional[int] = None,
read_batch_size_expert: Optional[int] = None,
repeat_times: int = 1,
ppo_mini_batch_size: int = 1,
ppo_micro_batch_size_per_gpu: int = 1,
ngpus_trainer: int = 1,
read_batch_size_usual: int = 1,
read_batch_size_expert: int = 1,
use_token_level_loss_in_sft: bool = True,
) -> None:
super().__init__(backend=backend)
self.mu = mu
self.use_dynamic_bsz = use_dynamic_bsz
self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore
self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer
self.gradient_accumulation = (
ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu # type: ignore
ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu
)
self.read_batch_size_usual = read_batch_size_usual
self.read_batch_size_expert = read_batch_size_expert
self.read_batch_size_usual = read_batch_size_usual // ngpus_trainer
self.read_batch_size_expert = read_batch_size_expert // ngpus_trainer
self.grpo_loss_fn = PPOPolicyLossFn(
clip_range=clip_range,
clip_range_low=clip_range_low,
Expand All @@ -60,11 +62,9 @@ def __call__( # type: ignore
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
is_expert_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
is_expert_mask = kwargs.get("is_expert_mask", None)
if is_expert_mask is None:
raise ValueError("is_expert_mask is required in MIX")
assert (
len(is_expert_mask) == logprob.shape[0]
), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
@@ -127,7 +127,3 @@ def default_args(cls) -> Dict:
"mu": 0.1,
"clip_range": 0.2,
}

@property
def select_keys(self) -> List[str]:
return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]
14 changes: 3 additions & 11 deletions trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
@@ -1,6 +1,6 @@
"""OPMD policy loss function."""

from typing import Dict, List, Tuple
from typing import Dict, Tuple

import torch

@@ -10,13 +10,13 @@

@POLICY_LOSS_FN.register_module("opmd")
class OPMDPolicyLossFn(PolicyLossFn):
def __init__(self, tau: float = 1.0) -> None:
def __init__(self, backend: str = "verl", tau: float = 1.0) -> None:
super().__init__(backend=backend)
self.tau = tau

def __call__( # type: ignore
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor, # NOT USED!
action_mask: torch.Tensor,
advantages: torch.Tensor,
**kwargs,
@@ -29,11 +29,3 @@ def __call__( # type: ignore
@classmethod
def default_args(cls) -> Dict:
return {"tau": 1.0}

@property
def select_keys(self) -> List[str]:
return [
"old_logprob",
"action_mask",
"advantages",
]