
Commit 8f48cb1

feature(nyz&dcy): add LLM/VLM RLHF loss (PPO/GRPO/RLOO) (#857)
* test(nyz): polish ppo and add rlhf ppo loss test
* interface(nyz): add naive interface about grpo/rloo
* test&implement(dcy): add unit tests for GRPO and RLOO
  - Add test_grpo_rlhf.py for GRPO unit tests
  - Add test_rloo_rlhf.py for RLOO unit tests
  - Update GRPO implementation
  - Update RLOO implementation
* polish(dcy): polish grpo, rloo and unit tests
* (dcy) rloo and grpo
* (dcy) redesign adv computation from reward
* (dcy) polish style: use selective log-softmax to reduce peak VRAM consumption
* (dcy) small changes
* (dcy) add readme and typing
* (dcy) switch comments to English, rename files and functions
1 parent abcf972 · commit 8f48cb1

File tree: 11 files changed, +943 -14 lines changed


.gitignore

Lines changed: 5 additions & 0 deletions
@@ -1429,3 +1429,8 @@ collect_demo_data_config.py
 events.*

 evogym/*
+ding/example/*
+ding/framework/middleware/tests/wandb/
+ding/.style.yapf
+ding/format.sh
+ding/framework/middleware_v3/

ding/rl_utils/grpo.py

Lines changed: 75 additions & 0 deletions
from typing import Tuple
from collections import namedtuple
import torch
from .log_prob_utils import efficient_method, naive_method, less_efficient_method, LogProbFunction

grpo_policy_data = namedtuple('grpo_policy_data', ['logit_new', 'logit_old', 'logit_ref', 'action', 'adv', 'weight'])
grpo_info = namedtuple('grpo_info', ['approx_kl', 'clipfrac'])


def grpo_policy_error(
        data: namedtuple,
        log_prob_fn: LogProbFunction = efficient_method,  # Method to calculate the log probabilities
        clip_ratio: float = 0.2,
        beta: float = 0.1  # Weight coefficient for KL divergence
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Group Relative Policy Optimization (arXiv:2402.03300).
    Arguments:
        - data (:obj:`namedtuple`): the grpo input data with fields shown in ``grpo_policy_data``.
        - log_prob_fn (:obj:`LogProbFunction`): the method to calculate the log probabilities, \
            defaults to `efficient_method`.
        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2.
        - beta (:obj:`float`): weight coefficient for KL divergence regularization, defaults to 0.1.
    Returns:
        - loss (:obj:`torch.FloatTensor`): the grpo policy loss, a differentiable 0-dim tensor.
        - grpo_info (:obj:`namedtuple`): the grpo optim information for monitoring, all of them are Python scalars.
    Shapes:
        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length, \
            and V is vocabulary size.
        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
        - logit_ref (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
        - action (:obj:`torch.LongTensor`): :math:`(B, S)`.
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`.
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, S)`.
        - loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor.
        - approx_kl (:obj:`float`): approximate KL divergence between the old and new policy.
        - clipfrac (:obj:`float`): proportion of clipped probability ratios.
    """

    # Calculate log probabilities for the selected tokens
    per_token_logps = log_prob_fn(data.logit_new, data.action)
    per_token_ref_logps = log_prob_fn(data.logit_ref, data.action)
    per_token_old_logps = log_prob_fn(data.logit_old, data.action)

    # Calculate KL divergence: exp(q - p) - (q - p) - 1,
    # where p is the current policy and q is the reference policy
    per_token_kl = (torch.exp(per_token_ref_logps - per_token_logps) - (per_token_ref_logps - per_token_logps) - 1)

    # Calculate policy ratio
    ratio = torch.exp(per_token_logps - per_token_old_logps)
    ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)

    # Calculate loss for each token
    advantages = data.adv.unsqueeze(1)  # [B, 1]
    per_token_loss_unclipped = ratio * advantages
    per_token_loss_clipped = ratio_clipped * advantages
    per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

    # Add KL divergence regularization term
    per_token_loss = per_token_loss + beta * per_token_kl

    # Calculate average loss using the weight mask
    weight = data.weight if data.weight is not None else torch.ones_like(per_token_loss)
    loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

    # Calculate additional metrics
    with torch.no_grad():
        approx_kl = (per_token_old_logps - per_token_logps).mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = clipped.float().mean().item()

    return loss, grpo_info(approx_kl=approx_kl, clipfrac=clipfrac)
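
For reference, a minimal usage sketch of grpo_policy_error (not part of the commit; the dummy tensors, sizes and variable names below are illustrative only, and the import path simply mirrors the file location above):

import torch
from ding.rl_utils.grpo import grpo_policy_data, grpo_policy_error

B, S, V = 3, 8, 16  # batch size, sequence length, vocabulary size
data = grpo_policy_data(
    logit_new=torch.randn(B, S, V, requires_grad=True),  # current policy logits
    logit_old=torch.randn(B, S, V),                       # behaviour policy logits
    logit_ref=torch.randn(B, S, V),                       # frozen reference policy logits
    action=torch.randint(0, V, (B, S)),                   # sampled token ids
    adv=torch.randn(B),                                   # group-relative advantage per response
    weight=torch.ones(B, S),                              # action mask (1 = valid token)
)
loss, info = grpo_policy_error(data, clip_ratio=0.2, beta=0.1)
loss.backward()  # loss is a differentiable 0-dim tensor
print(info.approx_kl, info.clipfrac)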

ding/rl_utils/log_prob_utils.py

Lines changed: 87 additions & 0 deletions
from typing import List, Callable
import torch
from torch import Tensor

LogitsProcessor = Callable[[Tensor, Tensor], Tensor]


def naive_method(logits: Tensor, index: Tensor) -> Tensor:
    """Calculate per-token log probabilities using the naive method.

    Args:
        logits: Token logits of shape [B, S, V] or [S, V] where:
            B = batch size
            S = sequence length
            V = vocabulary size
        index: Selected token indices of shape [B, S] or [S]

    Returns:
        Tensor: Log probabilities for selected tokens of shape [B, S] or [S]
    """
    # Calculate log probabilities for each token
    log_prob_new: Tensor = torch.log_softmax(logits, dim=-1)
    # Get log probabilities for selected actions
    index = index.unsqueeze(-1)  # [B, S, 1] or [S, 1]
    per_token_logps: Tensor = torch.gather(log_prob_new, -1, index).squeeze(-1)
    return per_token_logps


def efficient_method(logits: Tensor, index: Tensor) -> Tensor:
    """Calculate per-token log probabilities efficiently (selective log-softmax).

    Args:
        logits: Token logits of shape [B, S, V] or [S, V] where:
            B = batch size
            S = sequence length
            V = vocabulary size
        index: Selected token indices of shape [B, S] or [S]

    Returns:
        Tensor: Log probabilities for selected tokens of shape [B, S] or [S]
    """
    if logits.dtype in [torch.float32, torch.float64]:
        selected_logits: Tensor = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)

        # Loop to reduce peak memory consumption
        logsumexp_values: Tensor = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])

        # log_softmax(x_i) = x_i - logsumexp(x)
        per_token_logps: Tensor = selected_logits - logsumexp_values
    else:
        # The logsumexp approach is unstable with bfloat16
        per_token_logps: List[Tensor] = []

        # Loop to reduce peak memory consumption
        for row_logits, row_labels in zip(logits, index):  # Iterate over the leading dimension
            row_logps: Tensor = torch.log_softmax(row_logits, dim=-1)
            row_per_token_logps: Tensor = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
            per_token_logps.append(row_per_token_logps)

        per_token_logps = torch.stack(per_token_logps)

    return per_token_logps


def less_efficient_method(logits: Tensor, index: Tensor) -> Tensor:
    """Calculate per-token log probabilities using a categorical distribution.

    Args:
        logits: Token logits of shape [B, S, V] or [S, V] where:
            B = batch size
            S = sequence length
            V = vocabulary size
        index: Selected token indices of shape [B, S] or [S]

    Returns:
        Tensor: Log probabilities for selected tokens of shape [B, S] or [S]
    """
    dist = torch.distributions.categorical.Categorical(logits=logits)
    logp: Tensor = dist.log_prob(index)
    return logp


# Unified type alias for the log-prob functions above
LogProbFunction = Callable[[Tensor, Tensor], Tensor]

# Export all methods
__all__ = ['naive_method', 'efficient_method', 'less_efficient_method', 'LogProbFunction']
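
A quick consistency check for the three helpers (illustrative only, not part of the commit): on float32 inputs they should agree up to numerical tolerance, which is what lets the cheaper selective log-softmax path stand in for the naive one:

import torch
from ding.rl_utils.log_prob_utils import naive_method, efficient_method, less_efficient_method

B, S, V = 2, 5, 11
logits = torch.randn(B, S, V)
index = torch.randint(0, V, (B, S))

logp_naive = naive_method(logits, index)          # full log_softmax, then gather
logp_eff = efficient_method(logits, index)        # gather, then per-row logsumexp
logp_dist = less_efficient_method(logits, index)  # torch.distributions.Categorical

assert logp_naive.shape == (B, S)
assert torch.allclose(logp_naive, logp_eff, atol=1e-5)
assert torch.allclose(logp_naive, logp_dist, atol=1e-5)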

ding/rl_utils/ppo.py

Lines changed: 29 additions & 14 deletions
@@ -104,17 +104,21 @@ def ppo_error(
     return ppo_loss(policy_output.policy_loss, value_loss, policy_output.entropy_loss), policy_info


-def ppo_policy_error(data: namedtuple,
-                     clip_ratio: float = 0.2,
-                     dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]:
-    '''
+def ppo_policy_error(
+        data: namedtuple,
+        clip_ratio: float = 0.2,
+        dual_clip: Optional[float] = None,
+        entropy_bonus: bool = True
+) -> Tuple[namedtuple, namedtuple]:
+    """
     Overview:
-        Get PPO policy loss
+        Get PPO policy loss (both for classical RL in control/video games and LLM/VLM RLHF).
     Arguments:
-        - data (:obj:`namedtuple`): ppo input data with fields shown in ``ppo_policy_data``
-        - clip_ratio (:obj:`float`): clip value for ratio
-        - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf), \
-            defaults to 5.0, if you don't want to use it, set this parameter to None
+        - data (:obj:`namedtuple`): PPO input data with fields shown in ``ppo_policy_data``.
+        - clip_ratio (:obj:`float`): Clip value for ratio, defaults to 0.2.
+        - dual_clip (:obj:`float`): A parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf), \
+            defaults to 5.0, if you don't want to use it, set this parameter to None.
+        - entropy_bonus (:obj:`bool`): Whether to use the entropy bonus, defaults to True. LLM RLHF usually does not use it.
     Returns:
         - ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable 0-dim tensor
         - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -136,18 +140,29 @@ def ppo_policy_error(data: namedtuple,
         >>> weight=torch.ones(3),
         >>> )
         >>> loss, info = ppo_policy_error(data)
-    '''
+
+    .. note::
+        This function can be extended from `B` to more parallel dimensions, like `(B, S)`, where `S` is the
+        sequence length in LLM/VLM.
+
+    .. note::
+        For the action mask often used in LLM/VLM, users can set the `weight` to the action mask.
+    """
     logit_new, logit_old, action, adv, weight = data
     if weight is None:
         weight = torch.ones_like(adv)
     dist_new = torch.distributions.categorical.Categorical(logits=logit_new)
     dist_old = torch.distributions.categorical.Categorical(logits=logit_old)
     logp_new = dist_new.log_prob(action)
     logp_old = dist_old.log_prob(action)
-    dist_new_entropy = dist_new.entropy()
-    if dist_new_entropy.shape != weight.shape:
-        dist_new_entropy = dist_new.entropy().mean(dim=1)
-    entropy_loss = (dist_new_entropy * weight).mean()
+
+    if entropy_bonus:
+        dist_new_entropy = dist_new.entropy()
+        if dist_new_entropy.shape != weight.shape:  # for the multi-agent RL case
+            dist_new_entropy = dist_new.entropy().mean(dim=1)
+        entropy_loss = (dist_new_entropy * weight).mean()
+    else:
+        entropy_loss = torch.tensor(0.0)
     # policy_loss
     ratio = torch.exp(logp_new - logp_old)
     if ratio.shape != adv.shape:
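
To illustrate the two notes added above, here is a hedged sketch (not part of the commit) of the LLM/VLM RLHF style of call: token-level (B, S) inputs, entropy bonus disabled, and weight used as the action mask; field names follow the ``ppo_policy_data`` unpacking shown in the diff, while the sizes and mask below are illustrative only:

import torch
from ding.rl_utils.ppo import ppo_policy_data, ppo_policy_error

B, S, V = 4, 16, 128  # batch size, sequence length, vocabulary size
data = ppo_policy_data(
    logit_new=torch.randn(B, S, V, requires_grad=True),
    logit_old=torch.randn(B, S, V),
    action=torch.randint(0, V, (B, S)),
    adv=torch.randn(B, S),                    # token-level advantages
    weight=(torch.rand(B, S) > 0.1).float(),  # action mask over prompt/padding tokens
)
loss, info = ppo_policy_error(data, clip_ratio=0.2, entropy_bonus=False)
loss.policy_loss.backward()  # with entropy_bonus=False the entropy term is a constant zero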

ding/rl_utils/rloo.py

Lines changed: 69 additions & 0 deletions
from typing import Tuple
from collections import namedtuple
import torch
from .log_prob_utils import efficient_method, naive_method, less_efficient_method, LogProbFunction

rloo_policy_data = namedtuple('rloo_policy_data', ['logit_new', 'logit_old', 'action', 'reward', 'weight'])
rloo_info = namedtuple('rloo_info', ['approx_kl', 'clipfrac'])


def rloo_policy_error(
        data: namedtuple,
        log_prob_fn: LogProbFunction = efficient_method,  # Method to calculate the log probabilities
        clip_ratio: float = 0.2,
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        REINFORCE Leave-One-Out (arXiv:2402.14740).
    Arguments:
        - data (:obj:`namedtuple`): the rloo input data with fields shown in ``rloo_policy_data``.
        - log_prob_fn (:obj:`LogProbFunction`): the method to calculate the log probabilities, \
            defaults to `efficient_method`.
        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2.
    Returns:
        - loss (:obj:`torch.FloatTensor`): the rloo policy loss, a differentiable 0-dim tensor.
        - rloo_info (:obj:`namedtuple`): the rloo optim information for monitoring, all of them are Python scalars.
    Shapes:
        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length, \
            and V is vocabulary size.
        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
        - action (:obj:`torch.LongTensor`): :math:`(B, S)`.
        - reward (:obj:`torch.FloatTensor`): :math:`(K, B)`, where K is the number of samples per prompt.
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, S)`.
        - loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor.
        - approx_kl (:obj:`float`): approximate KL divergence between the old and new policy.
        - clipfrac (:obj:`float`): proportion of clipped probability ratios.
    """

    # Calculate the leave-one-out advantage of each sampled response
    rloo_k = data.reward.size(0)
    baseline = (data.reward.sum(0) - data.reward) / (rloo_k - 1)
    adv = data.reward - baseline
    adv = adv.flatten()

    # Get log probabilities for selected actions
    per_token_logps = log_prob_fn(data.logit_new, data.action)
    per_token_old_logps = log_prob_fn(data.logit_old, data.action)

    # Calculate policy ratio
    ratio = torch.exp(per_token_logps - per_token_old_logps)
    ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)

    # Calculate loss for each token
    advantages = adv.unsqueeze(1)  # [B, 1]
    per_token_loss_unclipped = ratio * advantages
    per_token_loss_clipped = ratio_clipped * advantages
    per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

    # Calculate average loss using the weight mask
    weight = data.weight if data.weight is not None else torch.ones_like(per_token_loss)
    loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

    # Calculate additional metrics
    with torch.no_grad():
        approx_kl = (per_token_old_logps - per_token_logps).mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = clipped.float().mean().item()

    return loss, rloo_info(approx_kl=approx_kl, clipfrac=clipfrac)
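
A minimal usage sketch (illustrative only, not part of the commit): with K sampled responses per prompt, reward is (K, B_prompt) and the batch dimension of the logits holds the K * B_prompt flattened responses, matching how adv is flattened inside rloo_policy_error; the exact ordering of responses along the batch dimension is an assumption made for this sketch:

import torch
from ding.rl_utils.rloo import rloo_policy_data, rloo_policy_error

K, B_prompt, S, V = 4, 2, 8, 16  # samples per prompt, prompts, sequence length, vocab size
data = rloo_policy_data(
    logit_new=torch.randn(K * B_prompt, S, V, requires_grad=True),
    logit_old=torch.randn(K * B_prompt, S, V),
    action=torch.randint(0, V, (K * B_prompt, S)),
    reward=torch.randn(K, B_prompt),     # one scalar reward per sampled response
    weight=torch.ones(K * B_prompt, S),  # action mask
)
loss, info = rloo_policy_error(data, clip_ratio=0.2)
loss.backward()
print(info.approx_kl, info.clipfrac)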
