from typing import Tuple
from collections import namedtuple

import torch

from .log_prob_utils import efficient_method, naive_method, less_efficient_method, LogProbFunction

grpo_policy_data = namedtuple('grpo_policy_data', ['logit_new', 'logit_old', 'logit_ref', 'action', 'adv', 'weight'])
grpo_info = namedtuple('grpo_info', ['approx_kl', 'clipfrac'])
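

# For reference, a minimal sketch of the ``LogProbFunction`` contract assumed below:
# it maps logits of shape (B, S, V) and actions of shape (B, S) to per-token log
# probabilities of shape (B, S). This is only an illustration of the interface;
# the imported ``efficient_method`` is the implementation actually used by default
# and may differ in how it avoids materializing the full log-softmax.
def _illustrative_log_prob(logits: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    log_probs = torch.log_softmax(logits, dim=-1)  # (B, S, V)
    return log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)  # (B, S)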


def grpo_policy_error(
        data: namedtuple,
        log_prob_fn: LogProbFunction = efficient_method,  # method used to compute per-token log probabilities
        clip_ratio: float = 0.2,
        beta: float = 0.1  # weight coefficient for KL divergence regularization
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Policy loss of Group Relative Policy Optimization (GRPO), proposed in
        `DeepSeekMath <https://arxiv.org/abs/2402.03300>`_.
    Arguments:
        - data (:obj:`namedtuple`): the grpo input data with fields shown in ``grpo_policy_data``.
        - log_prob_fn (:obj:`LogProbFunction`): the method used to compute per-token log probabilities, \
            defaults to ``efficient_method``.
        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2.
        - beta (:obj:`float`): weight coefficient for KL divergence regularization, defaults to 0.1.
    Returns:
        - loss (:obj:`torch.FloatTensor`): the grpo policy loss, a differentiable 0-dim tensor.
        - grpo_info (:obj:`namedtuple`): the grpo optim information for monitoring, all of them are Python scalars.
    Shapes:
        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length, \
            and V is vocabulary size.
        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
        - logit_ref (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
        - action (:obj:`torch.LongTensor`): :math:`(B, S)`.
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`.
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, S)`.
        - loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor.
        - approx_kl (:obj:`float`): mean approximate KL divergence between the old and current policy.
        - clipfrac (:obj:`float`): proportion of probability ratios clipped at ``clip_ratio``.
    """

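    # The per-token objective implemented below (summarizing the GRPO paper):
    #   loss_t = -min(r_t * A, clip(r_t, 1 - clip_ratio, 1 + clip_ratio) * A)
    #            + beta * KL[pi_theta || pi_ref]
    # where r_t = pi_theta(a_t | s_t) / pi_old(a_t | s_t) and A is the
    # group-relative advantage shared by all tokens of a sequence.
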
    # Compute per-token log probabilities of the selected tokens under the
    # current, reference, and old (behavior) policies
    per_token_logps = log_prob_fn(data.logit_new, data.action)
    per_token_ref_logps = log_prob_fn(data.logit_ref, data.action)
    per_token_old_logps = log_prob_fn(data.logit_old, data.action)

    # Calculate per-token KL divergence: exp(q - p) - (q - p) - 1, where p is the
    # current policy's log probability and q is the reference policy's
    per_token_kl = (torch.exp(per_token_ref_logps - per_token_logps) - (per_token_ref_logps - per_token_logps) - 1)
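    # (This is the non-negative "k3" KL estimator, exp(x) - x - 1 with x = q - p,
    # which is unbiased for KL(pi_theta || pi_ref) under samples from the current
    # policy and has lower variance than the naive estimator -x.)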

    # Calculate the policy ratio pi_theta / pi_old and its clipped counterpart
    ratio = torch.exp(per_token_logps - per_token_old_logps)
    ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)

    # PPO-style pessimistic surrogate: take the elementwise minimum of the
    # unclipped and clipped objectives, then negate to turn it into a loss
    advantages = data.adv.unsqueeze(1)  # (B, 1), broadcast over the sequence dimension
    per_token_loss_unclipped = ratio * advantages
    per_token_loss_clipped = ratio_clipped * advantages
    per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

    # Add the KL divergence regularization term
    per_token_loss = per_token_loss + beta * per_token_kl

    # Average over tokens with the weight mask (all-ones if no mask is given),
    # then average the per-sequence losses over the batch
    weight = data.weight if data.weight is not None else torch.ones_like(per_token_loss)
    loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

    # Calculate additional metrics for monitoring (no gradients required)
    with torch.no_grad():
        approx_kl = (per_token_old_logps - per_token_logps).mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = clipped.float().mean().item()

    return loss, grpo_info(approx_kl=approx_kl, clipfrac=clipfrac)
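

# A minimal smoke test sketching how ``grpo_policy_error`` is typically called.
# The tensor sizes and random inputs are illustrative assumptions, not part of the
# library's test suite; the illustrative log-prob backend defined above is passed
# explicitly to keep the demo self-contained.
if __name__ == "__main__":
    B, S, V = 4, 8, 16  # batch size, sequence length, vocabulary size
    data = grpo_policy_data(
        logit_new=torch.randn(B, S, V, requires_grad=True),
        logit_old=torch.randn(B, S, V),
        logit_ref=torch.randn(B, S, V),
        action=torch.randint(0, V, (B, S)),
        adv=torch.randn(B),
        weight=torch.ones(B, S),
    )
    loss, info = grpo_policy_error(data, log_prob_fn=_illustrative_log_prob, clip_ratio=0.2, beta=0.1)
    loss.backward()  # the loss is a differentiable 0-dim tensor
    print(f"loss={loss.item():.4f}, approx_kl={info.approx_kl:.4f}, clipfrac={info.clipfrac:.4f}")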