Merged
Changes from 7 commits
3 changes: 2 additions & 1 deletion docs/sphinx_doc/source/conf.py
@@ -22,12 +22,13 @@
"sphinx.ext.napoleon",
"sphinx.ext.autosectionlabel",
"myst_parser",
"sphinx.ext.mathjax",
]
source_suffix = {
".rst": "restructuredtext",
".md": "markdown",
}
myst_enable_extensions = ["colon_fence"]
myst_enable_extensions = ["colon_fence", "amsmath", "dollarmath"]

# Prefix document path to section labels, otherwise autogenerated labels would
# look like 'heading' rather than 'path/to/file:heading'
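With `sphinx.ext.mathjax` plus the `amsmath` and `dollarmath` MyST extensions enabled, the Markdown tutorials can embed LaTeX math directly. An illustrative snippet (the formula is a sketch of a mixed RL/SFT objective, not text from this PR):

```markdown
The mixed objective weights the RL and SFT terms by $\mu$:

$$
\mathcal{L}_{\text{mix}} = (1 - \mu)\,\mathcal{L}_{\text{RL}} + \mu\,\mathcal{L}_{\text{SFT}}
$$
```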
1 change: 1 addition & 0 deletions docs/sphinx_doc/source/index.rst
@@ -24,6 +24,7 @@ Welcome to Trinity-RFT's documentation!
tutorial/example_data_functionalities.md
tutorial/trinity_configs.md
tutorial/trinity_programming_guide.md
tutorial/example_mix_algo.md

.. toctree::
:maxdepth: 1
12 changes: 4 additions & 8 deletions docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -1,4 +1,4 @@
# Integrate An New Algorithm
# Integrate A New Algorithm


This guide introduces how to integrate a new algorithm into Trinity-RFT.
@@ -170,6 +170,7 @@ We define a `MIXPolicyLossFn` class in `trinity/algorithm/policy_loss_fn/mix_policy_loss.py`
class MIXPolicyLossFn(PolicyLossFn):
def __init__(
self,
backend: str = "verl",
mu: float = 0.1,
clip_range: Optional[float] = None,
clip_range_low: Optional[float] = None,
@@ -183,6 +184,7 @@ class MIXPolicyLossFn(PolicyLossFn):
read_batch_size_expert: Optional[int] = None,
use_token_level_loss_in_sft: bool = True,
) -> None:
super().__init__(backend=backend)
self.mu = mu
self.use_dynamic_bsz = use_dynamic_bsz
self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore
@@ -204,11 +206,9 @@ class MIXPolicyLossFn(PolicyLossFn):
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
is_expert_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
is_expert_mask = kwargs.get("is_expert_mask", None)
if is_expert_mask is None:
raise ValueError("is_expert_mask is required in MIX")
assert (
len(is_expert_mask) == logprob.shape[0]
), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
@@ -271,10 +271,6 @@ class MIXPolicyLossFn(PolicyLossFn):
"mu": 0.1,
"clip_range": 0.2,
}

@property
def select_keys(self) -> List[str]:
return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]
```

## Step 4: Run the Experiment
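The `select_keys` property deleted above is no longer needed because `PolicyLossFnMeta` (see `policy_loss_fn.py` below) now derives it from the `__call__` signature. Under the new scheme, a minimal algorithm only declares its inputs as named parameters; the class name, registry tag, and loss formula in this sketch are hypothetical:

```python
from typing import Dict, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn


@POLICY_LOSS_FN.register_module("my_algo")  # hypothetical tag
class MyPolicyLossFn(PolicyLossFn):
    def __init__(self, backend: str = "verl") -> None:
        super().__init__(backend=backend)

    # Every named tensor parameter except `logprob` is collected into
    # `select_keys` automatically; no separate property is required.
    def __call__(
        self,
        logprob: torch.Tensor,
        old_logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        ratio = torch.exp(logprob - old_logprob)
        loss = -(ratio * advantages * action_mask).sum() / action_mask.sum()
        return loss, {"policy_loss": loss.item()}

    @classmethod
    def default_args(cls) -> Dict:
        return {}
```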
2 changes: 1 addition & 1 deletion examples/mix_math/mix_math.yaml
@@ -82,7 +82,7 @@ synchronizer:
sync_timeout: 1200
trainer:
trainer_type: 'verl'
trainer_config_path: 'examples/mix_math/train_math.yaml'
trainer_config_path: 'examples/mix_math/train_mix_math.yaml'
save_interval: 50
monitor:
monitor_type: wandb
28 changes: 28 additions & 0 deletions trinity/algorithm/key_mapper.py
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
"""Key Mapper"""

from typing import Dict


class KeyMapper:
def __init__(self, to_trinity_map: Dict[str, str]):
self.to_trinity_map = to_trinity_map
self.from_trinity_map = {v: k for k, v in self.to_trinity_map.items()}

def to_trinity(self, key: str) -> str:
return self.to_trinity_map.get(key, key)

def from_trinity(self, key: str) -> str:
return self.from_trinity_map.get(key, key)


ALL_MAPPERS = {
"verl": KeyMapper(
{
"old_log_probs": "old_logprob",
"ref_log_prob": "ref_logprob",
"response_mask": "action_mask",
"advantages": "advantages",
}
),
}
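A quick usage sketch of the mapper (illustrative, not part of the diff). Unmapped keys fall through unchanged, so algorithm-specific tensors such as `is_expert_mask` need no entry:

```python
from trinity.algorithm.key_mapper import ALL_MAPPERS

mapper = ALL_MAPPERS["verl"]

assert mapper.to_trinity("old_log_probs") == "old_logprob"
assert mapper.from_trinity("action_mask") == "response_mask"
# No entry for this key, so it maps to itself in both directions.
assert mapper.to_trinity("is_expert_mask") == "is_expert_mask"
```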
11 changes: 3 additions & 8 deletions trinity/algorithm/policy_loss_fn/dpo_loss.py
@@ -1,6 +1,6 @@
"""DPO loss function."""

from typing import Dict, List, Tuple
from typing import Dict, Tuple

import torch
import torch.nn.functional as F
@@ -13,9 +13,11 @@
class DPOLossFn(PolicyLossFn):
def __init__(
self,
backend: str = "verl",
beta: float = 0.1,
label_smoothing: float = 0.0,
) -> None:
super().__init__(backend=backend)
self.beta = beta
self.label_smoothing = label_smoothing

@@ -63,10 +65,3 @@ def default_args(cls) -> Dict:
"beta": 0.1,
"label_smoothing": 0.0,
}

@property
def select_keys(self) -> List[str]:
return [
"ref_logprob",
"action_mask",
]
12 changes: 4 additions & 8 deletions trinity/algorithm/policy_loss_fn/mix_policy_loss.py
@@ -1,6 +1,6 @@
"""Mix policy loss function."""

from typing import Dict, List, Optional, Tuple
from typing import Dict, Optional, Tuple

import torch

@@ -26,6 +26,7 @@ class MIXPolicyLossFn(PolicyLossFn):

def __init__(
self,
backend: str = "verl",
mu: float = 0.1,
clip_range: Optional[float] = None,
clip_range_low: Optional[float] = None,
@@ -39,6 +40,7 @@ def __init__(
read_batch_size_expert: Optional[int] = None,
use_token_level_loss_in_sft: bool = True,
) -> None:
super().__init__(backend=backend)
self.mu = mu
self.use_dynamic_bsz = use_dynamic_bsz
self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore
@@ -60,11 +62,9 @@ def __call__(  # type: ignore
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
is_expert_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
is_expert_mask = kwargs.get("is_expert_mask", None)
if is_expert_mask is None:
raise ValueError("is_expert_mask is required in MIX")
assert (
len(is_expert_mask) == logprob.shape[0]
), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
@@ -127,7 +127,3 @@ def default_args(cls) -> Dict:
"mu": 0.1,
"clip_range": 0.2,
}

@property
def select_keys(self) -> List[str]:
return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]
13 changes: 3 additions & 10 deletions trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
@@ -1,6 +1,6 @@
"""OPMD policy loss function."""

from typing import Dict, List, Tuple
from typing import Dict, Tuple

import torch

@@ -10,7 +10,8 @@

@POLICY_LOSS_FN.register_module("opmd")
class OPMDPolicyLossFn(PolicyLossFn):
def __init__(self, tau: float = 1.0) -> None:
def __init__(self, backend: str = "verl", tau: float = 1.0) -> None:
super().__init__(backend=backend)
self.tau = tau

def __call__( # type: ignore
@@ -29,11 +30,3 @@ def __call__(  # type: ignore
@classmethod
def default_args(cls) -> Dict:
return {"tau": 1.0}

@property
def select_keys(self) -> List[str]:
return [
"old_logprob",
"action_mask",
"advantages",
]
53 changes: 42 additions & 11 deletions trinity/algorithm/policy_loss_fn/policy_loss_fn.py
@@ -1,18 +1,57 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
import inspect
from abc import ABC, ABCMeta, abstractmethod
from typing import Dict, Tuple

import torch

from trinity.algorithm.key_mapper import ALL_MAPPERS
from trinity.utils.registry import Registry

POLICY_LOSS_FN = Registry("policy_loss_fn")


class PolicyLossFn(ABC):
class PolicyLossFnMeta(ABCMeta):
"""Meta class for policy loss function."""

ignore_keys = {"self", "kwargs", "logprob"}

def __new__(cls, name, bases, dct):
signature = inspect.signature(dct["__call__"])
param_names = [
key for key in signature.parameters.keys() if key not in PolicyLossFnMeta.ignore_keys
]
dct["_select_keys"] = param_names

def select_keys(self):
keys = [self.mapper.from_trinity(key) for key in self._select_keys]
return keys

def decorator(func):
    def wrapper(self, *args, **kwargs):
        new_kwargs = {}
        for key, value in kwargs.items():
            key = self.mapper.to_trinity(key)
            # Keep the keys this loss declares, plus `logprob`, which the
            # trainer always passes explicitly; drop unused keys.
            if key in self._select_keys or key == "logprob":
                new_kwargs[key] = value
        return func(self, *args, **new_kwargs)

    return wrapper

dct["select_keys"] = property(select_keys)
dct["__call__"] = decorator(dct["__call__"])
return super().__new__(cls, name, bases, dct)


class PolicyLossFn(ABC, metaclass=PolicyLossFnMeta):
"""
Policy Loss Function
"""

def __init__(self, backend: str = "verl"):
self.backend = backend
self.mapper = ALL_MAPPERS[self.backend]

@abstractmethod
def __call__(
self,
@@ -39,11 +78,3 @@ def default_args(cls) -> Dict:
Returns:
`Dict`: The default init arguments for the policy loss function.
"""

@property
@abstractmethod
def select_keys(self) -> List[str]:
"""
Returns:
`List[str]`: The keys to select from input data.
"""
12 changes: 3 additions & 9 deletions trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -3,7 +3,7 @@
Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""

from typing import Dict, List, Optional, Tuple
from typing import Dict, Optional, Tuple

import torch

@@ -15,10 +15,12 @@
class PPOPolicyLossFn(PolicyLossFn):
def __init__(
self,
backend: str = "verl",
clip_range: Optional[float] = None,
clip_range_low: Optional[float] = None,
clip_range_high: Optional[float] = None,
) -> None:
super().__init__(backend=backend)
if clip_range_low is None:
self.clip_range_low = clip_range
else:
@@ -61,11 +63,3 @@ def default_args(cls) -> Dict:
return {
"clip_range": 0.2,
}

@property
def select_keys(self) -> List[str]:
return [
"old_logprob",
"action_mask",
"advantages",
]
9 changes: 3 additions & 6 deletions trinity/algorithm/policy_loss_fn/sft_loss.py
@@ -1,6 +1,6 @@
"""SFT loss function."""

from typing import Dict, List, Tuple
from typing import Dict, Tuple

import torch

@@ -10,7 +10,8 @@

@POLICY_LOSS_FN.register_module("sft")
class SFTLossFn(PolicyLossFn):
def __init__(self, use_token_level_loss: bool = True) -> None:
def __init__(self, backend: str = "verl", use_token_level_loss: bool = True) -> None:
super().__init__(backend=backend)
self.use_token_level_loss = use_token_level_loss

def __call__( # type: ignore
@@ -30,7 +31,3 @@ def default_args(cls):
return {
"use_token_level_loss": True,
}

@property
def select_keys(self) -> List[str]:
return ["action_mask"]
26 changes: 3 additions & 23 deletions trinity/trainer/verl/dp_actor.py
@@ -56,7 +56,7 @@ def __init__(

def set_algorithm(self, algorithm_config: AlgorithmConfig):
self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)(
**algorithm_config.policy_loss_fn_args
backend="verl", **algorithm_config.policy_loss_fn_args
)
self.kl_loss_fn = KL_FN.get(algorithm_config.kl_loss_fn)(**algorithm_config.kl_loss_fn_args)
self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)(
@@ -152,21 +152,7 @@ def update_policy(self, data: DataProto):
"responses",
"response_mask",
]
select_keys_verl2trinity = {
"old_log_probs": "old_logprob",
"ref_log_prob": "ref_logprob",
"response_mask": "action_mask",
"advantages": "advantages",
}
select_keys_trinity2verl = {value: key for key, value in select_keys_verl2trinity.items()}
for trinity_key in self.policy_loss_fn.select_keys:
if trinity_key in select_keys_trinity2verl:
verl_key = select_keys_trinity2verl[trinity_key]
else:
verl_key = trinity_key
select_keys_verl2trinity.update({verl_key: trinity_key})
select_keys_trinity2verl.update({trinity_key: verl_key})
select_keys.append(verl_key)
select_keys.extend(self.policy_loss_fn.select_keys)
if not isinstance(self.kl_loss_fn, DummyKLFn):
select_keys.append("ref_log_prob")
select_keys = list(set(select_keys))
@@ -240,14 +226,8 @@ def update_policy(self, data: DataProto):
calculate_entropy=calculate_entropy,
)

kwargs = {
select_keys_verl2trinity[verl_key]: value
for verl_key, value in data.items()
if verl_key in select_keys_verl2trinity
}
pg_loss, pg_loss_metrics = self.policy_loss_fn( # type: ignore
logprob=log_prob,
**kwargs,
logprob=log_prob, **data
)
prefix_metrics(
src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics
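End to end, the wrapper lets `update_policy` forward the whole batch without the old hand-maintained translation tables. A hedged sketch of the wrapped call with a verl-named batch (shapes and values are placeholders):

```python
import torch

from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn

fn = PPOPolicyLossFn(backend="verl", clip_range=0.2)

batch = {
    "old_log_probs": torch.zeros(2, 4),
    "response_mask": torch.ones(2, 4),
    "advantages": torch.ones(2, 4),
    "responses": torch.zeros(2, 4),  # undeclared key: dropped by the wrapper
}

# The metaclass wrapper renames old_log_probs -> old_logprob and
# response_mask -> action_mask, drops `responses`, and forwards `logprob`.
loss, metrics = fn(logprob=torch.zeros(2, 4), **batch)
```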