Commit 5cd6cb6

Add Policy Loss Functions (#62)
1 parent d7c43fe commit 5cd6cb6

16 files changed: +338 additions, -206 deletions

tests/template/config.yaml

Lines changed: 3 additions & 0 deletions
@@ -5,6 +5,9 @@ checkpoint_root_dir: ''
 algorithm:
   algorithm_type: ppo
   repeat_times: 1
+  policy_loss_fn: ppo
+  policy_loss_fn_args:
+    clip_range: 0.2
 model:
   model_path: ''
   max_prompt_tokens: 2048
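
The two new keys select one of the policy losses registered by this commit and pass constructor arguments through to it. Below is a minimal sketch of the lookup, reusing the `POLICY_LOSS_FN.get` call that the config check further down performs; the exact trainer wiring is not part of this diff, so treat it as an illustration only.

```python
# Sketch: turning the new config fields into a loss instance via the registry.
from trinity.algorithm import POLICY_LOSS_FN

policy_loss_fn = "ppo"                     # algorithm.policy_loss_fn
policy_loss_fn_args = {"clip_range": 0.2}  # algorithm.policy_loss_fn_args

loss_fn_cls = POLICY_LOSS_FN.get(policy_loss_fn)  # registered class, e.g. PPOPolicyLossFn
loss_fn = loss_fn_cls(**policy_loss_fn_args)      # PPOPolicyLossFn(clip_range=0.2)
```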

trinity/algorithm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn

 __all__ = [
     "AdvantageFn",

trinity/algorithm/policy_loss_fn/__init__.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn
+from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
+from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
+
+__all__ = [
+    "POLICY_LOSS_FN",
+    "PolicyLossFn",
+    "PPOPolicyLossFn",
+    "OPMDPolicyLossFn",
+    "DPOLossFn",
+    "SFTLossFn",
+]

trinity/algorithm/policy_loss_fn/dpo_loss.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+"""DPO loss function."""
+
+from typing import Any, Dict, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.utils import masked_sum
+
+
+@POLICY_LOSS_FN.register_module("dpo")
+class DPOLossFn(PolicyLossFn):
+    def __init__(
+        self,
+        beta: float = 0.1,
+        label_smoothing: float = 0.0,
+    ) -> None:
+        self.beta = beta
+        self.label_smoothing = label_smoothing
+
+    def __call__(
+        self,
+        logprob: torch.Tensor,
+        old_logprob: torch.Tensor,
+        action_mask: torch.Tensor,
+        advantages: torch.Tensor,
+        experiences: Any,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict]:
+        chosen_logprob = logprob[::2]
+        rejected_logprob = logprob[1::2]
+        chosen_mask = action_mask[::2]
+        rejected_mask = action_mask[1::2]
+        chosen_logprob_sum = masked_sum(chosen_logprob, chosen_mask)
+        rejected_logprob_sum = masked_sum(rejected_logprob, rejected_mask)
+
+        chosen_ref_logprob = old_logprob[::2]
+        rejected_ref_logprob = old_logprob[1::2]
+        chosen_ref_logprob_sum = masked_sum(chosen_ref_logprob, chosen_mask)
+        rejected_ref_logprob_sum = masked_sum(rejected_ref_logprob, rejected_mask)
+
+        chosen_ratios = chosen_logprob_sum - chosen_ref_logprob_sum
+        rejected_ratios = rejected_logprob_sum - rejected_ref_logprob_sum
+        logits = chosen_ratios - rejected_ratios
+        # TODO: support other loss functions
+        losses = (
+            -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+            - F.logsigmoid(-self.beta * logits) * self.label_smoothing
+        )
+        loss = losses.mean()
+        chosen_reward = self.beta * chosen_ratios.detach().mean().item()
+        rejected_reward = self.beta * rejected_ratios.detach().mean().item()
+        accuracy_mean = (chosen_ratios.detach() > rejected_ratios.detach()).float().mean().item()
+        return loss, {
+            "chosen_reward": chosen_reward,
+            "rejected_reward": rejected_reward,
+            "accuracy_mean": accuracy_mean,
+            "dpo_loss": loss.detach().item(),
+        }
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {
+            "beta": 0.1,
+            "label_smoothing": 0.0,
+        }
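
As read from the indexing in `__call__`, `DPOLossFn` expects chosen and rejected responses interleaved along the batch dimension (even rows chosen, odd rows rejected), with `old_logprob` holding the reference model's log-probabilities. A minimal usage sketch with dummy tensors; `advantages` and `experiences` are ignored by this loss and only passed to satisfy the shared signature.

```python
# Dummy-data sketch of calling DPOLossFn on two interleaved preference pairs.
import torch
from trinity.algorithm.policy_loss_fn import DPOLossFn

batch, seq_len = 4, 8                      # 2 pairs: rows 0/2 chosen, rows 1/3 rejected
logprob = torch.randn(batch, seq_len)      # log-probs under the policy being trained
old_logprob = torch.randn(batch, seq_len)  # log-probs under the reference model
action_mask = torch.ones(batch, seq_len)   # 1 for response tokens, 0 for padding

loss_fn = DPOLossFn(beta=0.1, label_smoothing=0.0)
loss, metrics = loss_fn(
    logprob=logprob,
    old_logprob=old_logprob,
    action_mask=action_mask,
    advantages=torch.zeros_like(logprob),  # unused by DPO
    experiences=None,                      # unused by DPO
)
print(metrics["dpo_loss"], metrics["accuracy_mean"])
```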

trinity/algorithm/policy_loss_fn/opmd_policy_loss.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+"""OPMD policy loss function.
+
+Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
+"""
+
+from typing import Any, Dict, Tuple
+
+import torch
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.utils import masked_mean
+
+
+@POLICY_LOSS_FN.register_module("opmd")
+class OPMDPolicyLossFn(PolicyLossFn):
+    def __init__(self, tau: float = 1.0) -> None:
+        self.tau = tau
+
+    def __call__(
+        self,
+        logprob: torch.Tensor,
+        old_logprob: torch.Tensor,
+        action_mask: torch.Tensor,
+        advantages: torch.Tensor,
+        experiences: Any,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict]:
+        pg_losses = -advantages * logprob
+        opmd_loss = masked_mean(pg_losses, action_mask)
+        opmd_loss = opmd_loss / (1.0 + self.tau)  # for regularization (w.r.t. current pi_theta)
+        return opmd_loss, {"opmd_loss": opmd_loss.detach().item()}
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {"tau": 1.0}

trinity/algorithm/policy_loss_fn/policy_loss_fn.py

Lines changed: 8 additions & 0 deletions
@@ -36,3 +36,11 @@ def __call__(
             `torch.Tensor`: Policy loss
             `Dict`: The metrics for logging.
         """
+
+    @classmethod
+    @abstractmethod
+    def default_args(cls) -> Dict:
+        """
+        Returns:
+            `Dict`: The default init arguments for the policy loss function.
+        """

trinity/algorithm/policy_loss_fn/ppo_policy_loss.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+"""PPO policy loss function.
+
+Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
+"""
+
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.utils import masked_mean
+
+
+@POLICY_LOSS_FN.register_module("ppo")
+class PPOPolicyLossFn(PolicyLossFn):
+    def __init__(
+        self,
+        clip_range: Optional[float] = None,
+        clip_range_low: Optional[float] = None,
+        clip_range_high: Optional[float] = None,
+    ) -> None:
+        if clip_range_low is None:
+            self.clip_range_low = clip_range
+        else:
+            self.clip_range_low = clip_range_low
+        if clip_range_high is None:
+            self.clip_range_high = clip_range
+        else:
+            self.clip_range_high = clip_range_high
+        assert self.clip_range_low is not None, "clip_range_low must be specified."
+        assert self.clip_range_high is not None, "clip_range_high must be specified."
+
+    def __call__(
+        self,
+        logprob: torch.Tensor,
+        old_logprob: torch.Tensor,
+        action_mask: torch.Tensor,
+        advantages: torch.Tensor,
+        experiences: Any,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict]:
+        negative_approx_kl = logprob - old_logprob
+        ratio = torch.exp(negative_approx_kl)
+        ppo_kl = masked_mean(-negative_approx_kl, action_mask)
+
+        pg_losses = -advantages * ratio
+        pg_losses2 = -advantages * torch.clamp(
+            ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high  # type: ignore
+        )
+
+        pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), action_mask)
+        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask)
+        metrics = {
+            "pg_clipfrac": pg_clipfrac.detach().item(),
+            "ppo_kl": ppo_kl.detach().item(),
+            "pg_loss": pg_loss.detach().item(),
+        }
+        return pg_loss, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {
+            "clip_range": 0.2,
+        }
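
A dummy-data sketch of the clipping behavior: with `clip_range=0.2` and positive advantages, a probability ratio above 1.2 makes the clipped term dominate, which is reflected in the `pg_clipfrac` metric.

```python
# With ratios exp(logprob - old_logprob) of ~[1.0, 1.35, 1.11, 0.90] and positive
# advantages, only the 1.35 ratio exceeds 1 + clip_range, so pg_clipfrac is ~0.25.
import torch
from trinity.algorithm.policy_loss_fn import PPOPolicyLossFn

loss_fn = PPOPolicyLossFn(clip_range=0.2)

old_logprob = torch.zeros(1, 4)
logprob = torch.tensor([[0.0, 0.3, 0.1, -0.1]])
advantages = torch.ones(1, 4)
action_mask = torch.ones(1, 4)

loss, metrics = loss_fn(
    logprob=logprob,
    old_logprob=old_logprob,
    action_mask=action_mask,
    advantages=advantages,
    experiences=None,
)
print(metrics["pg_loss"], metrics["pg_clipfrac"], metrics["ppo_kl"])
```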

trinity/algorithm/policy_loss_fn/sft_loss.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+"""SFT loss function."""
+
+from typing import Any, Dict, Tuple
+
+import torch
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.utils import masked_mean
+
+
+@POLICY_LOSS_FN.register_module("sft")
+class SFTLossFn(PolicyLossFn):
+    def __init__(self, use_token_level_loss: bool = True) -> None:
+        self.use_token_level_loss = use_token_level_loss
+
+    def __call__(
+        self,
+        logprob: torch.Tensor,
+        old_logprob: torch.Tensor,
+        action_mask: torch.Tensor,
+        advantages: torch.Tensor,
+        experiences: Any,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict]:
+        if self.use_token_level_loss:
+            sft_loss = masked_mean(-logprob, action_mask)
+        else:
+            sft_loss = masked_mean(-logprob, action_mask, axis=1).mean()
+        return sft_loss, {"sft_loss": sft_loss.detach().item()}
+
+    @classmethod
+    def default_args(cls):
+        return {
+            "use_token_level_loss": True,
+        }
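
The `use_token_level_loss` flag matters when responses have different lengths: token-level averaging weights every unmasked token equally, while the sequence-level variant averages within each sequence first and then across the batch. A dummy-data sketch:

```python
# Two responses of length 4 and 2; the two averaging modes give different losses.
import torch
from trinity.algorithm.policy_loss_fn import SFTLossFn

logprob = torch.tensor([[-1.0, -1.0, -1.0, -1.0],   # 4 response tokens
                        [-2.0, -2.0,  0.0,  0.0]])  # 2 response tokens + padding
action_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                            [1.0, 1.0, 0.0, 0.0]])
dummy = torch.zeros_like(logprob)  # old_logprob / advantages are unused by SFT

token_loss, _ = SFTLossFn(use_token_level_loss=True)(
    logprob=logprob, old_logprob=dummy, action_mask=action_mask,
    advantages=dummy, experiences=None,
)
seq_loss, _ = SFTLossFn(use_token_level_loss=False)(
    logprob=logprob, old_logprob=dummy, action_mask=action_mask,
    advantages=dummy, experiences=None,
)
print(token_loss.item())  # (4 * 1.0 + 2 * 2.0) / 6 tokens ~= 1.33
print(seq_loss.item())    # mean(1.0, 2.0) = 1.50
```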

trinity/algorithm/utils.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+"""Common utils for algorithm module.
+
+Modified from https://github.com/volcengine/verl/blob/main/verl/utils/torch_functional.py
+"""
+
+
+def masked_sum(values, mask, axis=None):
+    """Compute the sum of tensor values selected by the mask."""
+    return (values * mask).sum(axis=axis)
+
+
+def masked_mean(values, mask, axis=None):
+    """Compute the mean of tensor values selected by the mask."""
+    return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8)
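
A quick dummy-data demo of the two helpers on a padded batch; the `+ 1e-8` term keeps an all-zero mask from dividing by zero.

```python
import torch
from trinity.algorithm.utils import masked_mean, masked_sum

values = torch.tensor([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 1.0, 0.0],   # last token of each row is padding
                     [1.0, 0.0, 0.0]])

print(masked_sum(values, mask))           # 1 + 2 + 4 = 7.0
print(masked_mean(values, mask))          # 7 / 3 unmasked tokens ~= 2.33
print(masked_mean(values, mask, axis=1))  # per-sequence means: [1.5, 4.0]
print(masked_mean(values, torch.zeros_like(mask)))  # 0.0, guarded by the 1e-8
```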

trinity/common/config.py

Lines changed: 16 additions & 1 deletion
@@ -175,7 +175,10 @@ class AlgorithmConfig:
     repeat_times: int = 1
     gamma: Optional[float] = None
     lam: Optional[float] = None
-    # TODO: add more algorithm params here
+
+    policy_loss_fn: str = "ppo"
+    # If not set, use PolicyLossFn.default_args()
+    policy_loss_fn_args: Optional[dict] = None


 @dataclass
@@ -466,6 +469,15 @@ def _check_buffer(self) -> None:  # noqa: C901
             self.buffer.pad_token_id = 0
         self.buffer.tokenizer_path = self.model.model_path

+    def _check_algorithm(self) -> None:
+        from trinity.algorithm import POLICY_LOSS_FN
+
+        policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn)
+        if policy_fn_cls is None:
+            raise ValueError(f"Invalid policy_loss_fn: {self.algorithm.policy_loss_fn}")
+        if self.algorithm.policy_loss_fn_args is None:
+            self.algorithm.policy_loss_fn_args = policy_fn_cls.default_args()
+
     def check_and_update(self) -> None:  # noqa: C901
         """Check and update the config."""
         self._check_deprecated()
@@ -489,6 +501,9 @@ def check_and_update(self) -> None:  # noqa: C901
         if not self.model.critic_model_path:
             self.model.critic_model_path = self.model.model_path

+        # check algorithm
+        self._check_algorithm()
+
         # check explorer
         if (
             self.explorer.rollout_model.engine_type != "vllm_async"
