3 changes: 3 additions & 0 deletions tests/template/config.yaml
@@ -5,6 +5,9 @@ checkpoint_root_dir: ''
algorithm:
algorithm_type: ppo
repeat_times: 1
policy_loss_fn: ppo
policy_loss_fn_args:
clip_range: 0.2
model:
model_path: ''
max_prompt_tokens: 2048
2 changes: 1 addition & 1 deletion trinity/algorithm/__init__.py
@@ -1,5 +1,5 @@
from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn

__all__ = [
"AdvantageFn",
14 changes: 14 additions & 0 deletions trinity/algorithm/policy_loss_fn/__init__.py
@@ -0,0 +1,14 @@
from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn
from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn

__all__ = [
"POLICY_LOSS_FN",
"PolicyLossFn",
"PPOPolicyLossFn",
"OPMDPolicyLossFn",
"DPOLossFn",
"SFTLossFn",
]
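
A note on how the new registry is intended to be consumed. The lookup and default-argument pattern below is a sketch inferred from the `@POLICY_LOSS_FN.register_module(...)` decorators and the `POLICY_LOSS_FN.get(...)` call added in `trinity/common/config.py`; the instantiation step itself is illustrative rather than copied from this PR.

```python
# Hypothetical usage sketch: resolve a registered policy loss by name and
# build it from its declared default arguments.
from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN

loss_cls = POLICY_LOSS_FN.get("ppo")           # registered via @POLICY_LOSS_FN.register_module("ppo")
loss_fn = loss_cls(**loss_cls.default_args())  # e.g. PPOPolicyLossFn(clip_range=0.2)
```
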
67 changes: 67 additions & 0 deletions trinity/algorithm/policy_loss_fn/dpo_loss.py
@@ -0,0 +1,67 @@
"""DPO loss function."""

from typing import Any, Dict, Tuple

import torch
import torch.nn.functional as F

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_sum


@POLICY_LOSS_FN.register_module("dpo")
class DPOLossFn(PolicyLossFn):
def __init__(
self,
beta: float = 0.1,
label_smoothing: float = 0.0,
) -> None:
self.beta = beta
self.label_smoothing = label_smoothing

def __call__(
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
chosen_logprob = logprob[::2]
rejected_logprob = logprob[1::2]
chosen_mask = action_mask[::2]
rejected_mask = action_mask[1::2]
chosen_logprob_sum = masked_sum(chosen_logprob, chosen_mask)
rejected_logprob_sum = masked_sum(rejected_logprob, rejected_mask)

chosen_ref_logprob = old_logprob[::2]
rejected_ref_logprob = old_logprob[1::2]
chosen_ref_logprob_sum = masked_sum(chosen_ref_logprob, chosen_mask)
rejected_ref_logprob_sum = masked_sum(rejected_ref_logprob, rejected_mask)

chosen_ratios = chosen_logprob_sum - chosen_ref_logprob_sum
rejected_ratios = rejected_logprob_sum - rejected_ref_logprob_sum
logits = chosen_ratios - rejected_ratios
# TODO: support other loss functions
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
loss = losses.mean()
chosen_reward = self.beta * chosen_ratios.detach().mean().item()
rejected_reward = self.beta * rejected_ratios.detach().mean().item()
accuracy_mean = (chosen_ratios.detach() > rejected_ratios.detach()).float().mean().item()
return loss, {
"chosen_reward": chosen_reward,
"rejected_reward": rejected_reward,
"accuracy_mean": accuracy_mean,
"dpo_loss": loss.detach().item(),
}

@classmethod
def default_args(cls) -> Dict:
return {
"beta": 0.1,
"label_smoothing": 0.0,
}
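
For reviewers: the loss above is the sigmoid DPO objective with label smoothing, computed from masked log-probability sums where `old_logprob` plays the role of the reference policy and chosen/rejected responses are interleaved along the batch dimension (even rows chosen, odd rows rejected). In the code's notation:

$$
z = \Big(\textstyle\sum \log\pi_\theta(y_c) - \sum\log\pi_{\mathrm{ref}}(y_c)\Big) - \Big(\textstyle\sum \log\pi_\theta(y_r) - \sum\log\pi_{\mathrm{ref}}(y_r)\Big)
$$

$$
\mathcal{L}_{\mathrm{DPO}} = -(1-\varepsilon)\,\log\sigma(\beta z) - \varepsilon\,\log\sigma(-\beta z)
$$

with β = `beta` and ε = `label_smoothing`; the reported `chosen_reward` and `rejected_reward` metrics are β times the corresponding reference-relative log-ratio terms.
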
35 changes: 35 additions & 0 deletions trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
@@ -0,0 +1,35 @@
"""PPO policy loss function.

Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""

from typing import Any, Dict, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_mean


@POLICY_LOSS_FN.register_module("opmd")
class OPMDPolicyLossFn(PolicyLossFn):
def __init__(self, tau: float = 1.0) -> None:
self.tau = tau

def __call__(
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
pg_losses = -advantages * logprob
opmd_loss = masked_mean(pg_losses, action_mask)
opmd_loss = opmd_loss / (1.0 + self.tau) # for regularization (w.r.t. current pi_theta)
return opmd_loss, {"opmd_loss": opmd_loss.detach().item()}

@classmethod
def default_args(cls) -> Dict:
return {"tau": 1.0}
8 changes: 8 additions & 0 deletions trinity/algorithm/policy_loss_fn/policy_loss_fn.py
@@ -36,3 +36,11 @@ def __call__(
`torch.Tensor`: Policy loss
`Dict`: The metrics for logging.
"""

@classmethod
@abstractmethod
def default_args(cls) -> Dict:
"""
Returns:
`Dict`: The default init arguments for the policy loss function.
"""
64 changes: 64 additions & 0 deletions trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -0,0 +1,64 @@
"""PPO policy loss function.

Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""

from typing import Any, Dict, Optional, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_mean


@POLICY_LOSS_FN.register_module("ppo")
class PPOPolicyLossFn(PolicyLossFn):
def __init__(
self,
clip_range: Optional[float] = None,
clip_range_low: Optional[float] = None,
clip_range_high: Optional[float] = None,
) -> None:
if clip_range_low is None:
self.clip_range_low = clip_range
else:
self.clip_range_low = clip_range_low
if clip_range_high is None:
self.clip_range_high = clip_range
else:
self.clip_range_high = clip_range_high
assert self.clip_range_low is not None, "clip_range_low must be specified."
assert self.clip_range_high is not None, "clip_range_high must be specified."

def __call__(
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
negative_approx_kl = logprob - old_logprob
ratio = torch.exp(negative_approx_kl)
ppo_kl = masked_mean(-negative_approx_kl, action_mask)

pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(
ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high # type: ignore
)

pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), action_mask)
pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask)
metrics = {
"pg_clipfrac": pg_clipfrac.detach().item(),
"ppo_kl": ppo_kl.detach().item(),
"pg_loss": pg_loss.detach().item(),
}
return pg_loss, metrics

@classmethod
def default_args(cls) -> Dict:
return {
"clip_range": 0.2,
}
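
This matches the usual clipped-surrogate objective; passing only `clip_range` (the default `{"clip_range": 0.2}`) sets both bounds, while `clip_range_low`/`clip_range_high` allow asymmetric clipping:

$$
r_t = \exp\big(\log\pi_\theta - \log\pi_{\mathrm{old}}\big),\qquad
\mathcal{L}_{\mathrm{PPO}} = \operatorname{masked\_mean}\Big(\max\big(-A_t\,r_t,\; -A_t\,\mathrm{clip}(r_t,\,1-\epsilon_{\mathrm{low}},\,1+\epsilon_{\mathrm{high}})\big)\Big)
$$

`pg_clipfrac` reports the fraction of unmasked tokens for which the clipped term is the active one.
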
35 changes: 35 additions & 0 deletions trinity/algorithm/policy_loss_fn/sft_loss.py
@@ -0,0 +1,35 @@
"""SFT loss function."""

from typing import Any, Dict, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_mean


@POLICY_LOSS_FN.register_module("sft")
class SFTLossFn(PolicyLossFn):
def __init__(self, use_token_level_loss: bool = True) -> None:
self.use_token_level_loss = use_token_level_loss

def __call__(
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
if self.use_token_level_loss:
sft_loss = masked_mean(-logprob, action_mask)
else:
sft_loss = masked_mean(-logprob, action_mask, axis=1).mean()
return sft_loss, {"sft_loss": sft_loss.detach().item()}

@classmethod
def default_args(cls):
return {
"use_token_level_loss": True,
}
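
The `use_token_level_loss` flag switches between a single average over all unmasked tokens in the batch and a mean of per-sequence averages (with m the `action_mask` and B the batch size):

$$
\mathcal{L}_{\mathrm{token}} = \frac{\sum_{i,t} m_{i,t}\,(-\log p_{i,t})}{\sum_{i,t} m_{i,t}},\qquad
\mathcal{L}_{\mathrm{seq}} = \frac{1}{B}\sum_{i=1}^{B}\frac{\sum_{t} m_{i,t}\,(-\log p_{i,t})}{\sum_{t} m_{i,t}}
$$

The two agree only when every sequence contributes the same number of unmasked tokens.
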
14 changes: 14 additions & 0 deletions trinity/algorithm/utils.py
@@ -0,0 +1,14 @@
"""Common utils for algorithm module.

Modified from https://github.com/volcengine/verl/blob/main/verl/utils/torch_functional.py
"""


def masked_sum(values, mask, axis=None):
"""Compute mean of tensor with a masked values."""
return (values * mask).sum(axis=axis)


def masked_mean(values, mask, axis=None):
"""Compute mean of tensor with a masked values."""
return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8)
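
A small, hypothetical sanity check of the two helpers (the tensor values are made up; it assumes the `trinity` package is importable and a recent PyTorch where `Tensor.sum(axis=None)` reduces over all dimensions, which the default `axis=None` path relies on):

```python
# Illustrative check of masked_sum / masked_mean behaviour.
import torch

from trinity.algorithm.utils import masked_mean, masked_sum

values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])

print(masked_sum(values, mask))           # 1 + 2 + 4 = 7.0
print(masked_mean(values, mask))          # 7 / 3 ≈ 2.3333
print(masked_mean(values, mask, axis=1))  # per-row means: [1.5, 4.0]
```
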
17 changes: 16 additions & 1 deletion trinity/common/config.py
@@ -175,7 +175,10 @@ class AlgorithmConfig:
repeat_times: int = 1
gamma: Optional[float] = None
lam: Optional[float] = None
# TODO: add more algorithm params here

policy_loss_fn: str = "ppo"
# If not set, use PolicyLossFn.default_args()
policy_loss_fn_args: Optional[dict] = None


@dataclass
@@ -466,6 +469,15 @@ def _check_buffer(self) -> None: # noqa: C901
self.buffer.pad_token_id = 0
self.buffer.tokenizer_path = self.model.model_path

def _check_algorithm(self) -> None:
from trinity.algorithm import POLICY_LOSS_FN

policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn)
if policy_fn_cls is None:
raise ValueError(f"Invalid policy_loss_fn: {self.algorithm.policy_loss_fn}")
if self.algorithm.policy_loss_fn_args is None:
self.algorithm.policy_loss_fn_args = policy_fn_cls.default_args()

def check_and_update(self) -> None: # noqa: C901
"""Check and update the config."""
self._check_deprecated()
@@ -489,6 +501,9 @@ def check_and_update(self) -> None: # noqa: C901
if not self.model.critic_model_path:
self.model.critic_model_path = self.model.model_path

# check algorithm
self._check_algorithm()

# check explorer
if (
self.explorer.rollout_model.engine_type != "vllm_async"
4 changes: 4 additions & 0 deletions trinity/common/verl_config.py
@@ -190,6 +190,10 @@ class Algorithm:
kl_penalty: str = "kl"
kl_ctrl: KL_Ctrl = field(default_factory=KL_Ctrl)

# ! DO NOT SET THE FOLLOWING PARAMETERS
policy_loss_fn: str = "ppo"
policy_loss_fn_args: Optional[dict] = None


@dataclass
class Trainer:
8 changes: 4 additions & 4 deletions trinity/trainer/trainer.py
@@ -13,7 +13,7 @@
import ray

from trinity.buffer import get_buffer_reader
from trinity.common.config import Config
from trinity.common.config import AlgorithmConfig, Config
from trinity.common.constants import AlgorithmType, SyncMethod
from trinity.common.experience import Experiences
from trinity.utils.log import get_logger
@@ -73,7 +73,7 @@ def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool
Returns:
bool: Whether to continue training.
"""
self.engine.set_mode(algo_type)
self.engine.set_algorithm(self.config.algorithm)
if algo_type.is_rft() and self.config.buffer.trainer_input.read_experience_strategy:
strategy = self.config.buffer.trainer_input.read_experience_strategy
else:
@@ -157,8 +157,8 @@ def sync_weight(self) -> None:
"""Sync the model weight."""

@abstractmethod
def set_mode(self, algo_type: AlgorithmType) -> None:
"""Set training mode."""
def set_algorithm(self, algorithm_config: AlgorithmConfig) -> None:
"""Set training algorithm config."""

@abstractmethod
def shutdown(self) -> None: