diff --git a/trinity/algorithm/advantage_fn/grpo_advantage.py b/trinity/algorithm/advantage_fn/grpo_advantage.py index 37f824de4f..553af6d065 100644 --- a/trinity/algorithm/advantage_fn/grpo_advantage.py +++ b/trinity/algorithm/advantage_fn/grpo_advantage.py @@ -1,35 +1,74 @@ """GRPO advantage computation -Adapted from compute_advantage_ppo in original ray_trainer.py +Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py """ +from collections import defaultdict from typing import Dict, Tuple +import torch from verl import DataProto from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn -from trinity.trainer.verl import core_algos @ADVANTAGE_FN.register_module("grpo") class GRPOAdvantageFn(AdvantageFn): """GRPO advantage computation""" - def __init__(self) -> None: - pass + def __init__( + self, + epsilon: float = 1e-6, + ) -> None: + self.epsilon = epsilon def __call__( self, exps: DataProto, **kwargs, ) -> Tuple[DataProto, Dict]: - advantages, returns = core_algos.compute_grpo_outcome_advantage( - token_level_rewards=exps.batch["token_level_rewards"], - eos_mask=exps.batch["response_mask"], - index=exps.non_tensor_batch["uid"], - ) - exps.batch["advantages"] = advantages - exps.batch["returns"] = returns + """ + Compute advantage for GRPO, operating only on Outcome reward + (with only one scalar reward for each response). + + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + eos_mask: `(torch.Tensor)` + shape: (bs, response_length) + scores: `(torch.Tensor)` + shape: (bs, response_length) + """ + token_level_rewards = exps.batch["token_level_rewards"] + eos_mask = exps.batch["response_mask"] + index = exps.non_tensor_batch["uid"] + epsilon = self.epsilon + + response_length = token_level_rewards.shape[-1] + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) + id2std[idx] = torch.std(torch.tensor([id2score[idx]])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) + scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask + + exps.batch["advantages"] = scores + exps.batch["returns"] = scores metrics = { # TODO: add meaningful metrics @@ -39,4 +78,6 @@ def __call__( @classmethod def default_args(cls) -> Dict: - return {} + return { + "epsilon": 1e-6, + } diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py index e9e0eb090f..b27e2c9ab0 100644 --- a/trinity/algorithm/advantage_fn/opmd_advantage.py +++ b/trinity/algorithm/advantage_fn/opmd_advantage.py @@ -1,38 +1,84 @@ -"""OPMD advantage computation - -Adapted from compute_advantage_opmd in original ray_trainer.py -""" +"""OPMD advantage computation""" +from collections import defaultdict from typing import Dict, Tuple +import torch from verl import DataProto from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn -from trinity.trainer.verl import core_algos @ADVANTAGE_FN.register_module("opmd") class OPMDAdvantageFn(AdvantageFn): """OPMD advantage computation""" - def __init__(self) -> None: - pass + def __init__( + self, 
+ opmd_baseline: str = "mean", + tau: float = 1.0, + ) -> None: + self.opmd_baseline = opmd_baseline + self.tau = tau def __call__( self, exps: DataProto, **kwargs, ) -> Tuple[DataProto, Dict]: - advantages, returns = core_algos.compute_opmd_outcome_advantage( - token_level_rewards=exps.batch["token_level_rewards"], - eos_mask=exps.batch["response_mask"], - # TODO (yanxi): check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation - index=exps.non_tensor_batch["uid"], - opmd_baseline="mean", - tau=1.0, - ) - exps.batch["advantages"] = advantages - exps.batch["returns"] = returns + """Modified from compute_grpo_outcome_advantage + + Compute advantage for OPMD, operating only on Outcome reward + (with only one scalar reward for each response). + + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + eos_mask: `(torch.Tensor)` + shape: (bs, response_length) + scores: `(torch.Tensor)` + shape: (bs, response_length) + """ + token_level_rewards = exps.batch["token_level_rewards"] + eos_mask = exps.batch["response_mask"] + # TODO (yanxi): confirm consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation + index = exps.non_tensor_batch["uid"] + opmd_baseline = self.opmd_baseline + tau = self.tau + + response_length = token_level_rewards.shape[-1] + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2baseline = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2baseline[idx] = torch.tensor(0.0) + # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?) + elif len(id2score[idx]) > 1: + if opmd_baseline == "mean": + id2baseline[idx] = torch.mean(torch.tensor(id2score[idx])) + elif opmd_baseline == "logavgexp": + rewards_tensor = torch.tensor(id2score[idx]) + # here we use the fact that logavgexp(x) = logsumexp(x) - log(len(x)) + id2baseline[idx] = tau * ( + torch.logsumexp(rewards_tensor / tau, dim=-1) + - torch.log(torch.tensor(len(id2score[idx]))) + ) + else: + raise NotImplementedError + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = scores[i] - id2baseline[index[i]] + scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask + + exps.batch["advantages"] = scores + exps.batch["returns"] = scores metrics = { # TODO: add meaningful metrics @@ -42,4 +88,7 @@ def __call__( @classmethod def default_args(cls) -> Dict: - return {} + return { + "opmd_baseline": "mean", + "tau": 1.0, + } diff --git a/trinity/algorithm/advantage_fn/ppo_advantage.py b/trinity/algorithm/advantage_fn/ppo_advantage.py index 896deca116..31fda4454c 100644 --- a/trinity/algorithm/advantage_fn/ppo_advantage.py +++ b/trinity/algorithm/advantage_fn/ppo_advantage.py @@ -1,14 +1,15 @@ """PPO's GAE advantage computation -Adapted from compute_advantage_ppo in original ray_trainer.py +Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py """ from typing import Dict, Tuple +import torch from verl import DataProto from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn -from trinity.trainer.verl import core_algos +from trinity.algorithm.utils import masked_whiten @ADVANTAGE_FN.register_module("ppo") @@ -26,13 +27,48 @@ def __call__( exps: DataProto, **kwargs, ) -> Tuple[DataProto, Dict]: - advantages, returns = core_algos.compute_gae_advantage_return( - 
            token_level_rewards=exps.batch["token_level_rewards"],
-            values=exps.batch["values"],
-            eos_mask=exps.batch["response_mask"],
-            gamma=self.gamma,
-            lam=self.lam,
-        )
+        """
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        values: `(torch.Tensor)`
+            shape: (bs, response_length)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length). [EOS] mask; tokens after [EOS] have mask zero.
+        gamma: `(float)`
+            discount factor used in RL
+        lam: `(float)`
+            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
+        advantages: `(torch.Tensor)`
+            shape: (bs, response_length)
+        returns: `(torch.Tensor)`
+            shape: (bs, response_length)
+        """
+        token_level_rewards = exps.batch["token_level_rewards"]
+        values = exps.batch["values"]
+        eos_mask = exps.batch["response_mask"]
+        gamma = self.gamma
+        lam = self.lam
+
+        with torch.no_grad():
+            lastgaelam = 0
+            advantages_reversed = []
+            gen_len = token_level_rewards.shape[-1]
+
+            # values = values * eos_mask TODO: may use in multi-turn
+            for t in reversed(range(gen_len)):
+                nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
+                delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
+
+                lastgaelam = delta + gamma * lam * lastgaelam
+                # lastgaelam = torch.where(  # TODO: may use in multi-turn
+                #     eos_mask[:, t] == 1, delta + gamma * lam * lastgaelam, lastgaelam
+                # )
+                advantages_reversed.append(lastgaelam)
+            advantages = torch.stack(advantages_reversed[::-1], dim=1)
+
+            returns = advantages + values
+            advantages = masked_whiten(advantages, eos_mask)
+
         exps.batch["advantages"] = advantages
         exps.batch["returns"] = returns
diff --git a/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
index d53052c83f..eb63c3605b 100644
--- a/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
+++ b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
@@ -1,14 +1,15 @@
 """REINFORCE++ advantage computation
 
-Adapted from compute_advantage_ppo in original ray_trainer.py
+Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
 """
 
 from typing import Dict, Tuple
 
+import torch
 from verl import DataProto
 
 from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.trainer.verl import core_algos
+from trinity.algorithm.utils import masked_whiten
 
 
 @ADVANTAGE_FN.register_module("reinforceplusplus")
@@ -21,11 +22,34 @@ def __call__(
         exps: DataProto,
         **kwargs,
     ) -> Tuple[DataProto, Dict]:
-        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            gamma=self.gamma,
-        )
+        """
+        Compute advantage for REINFORCE++.
+ This implementation is based on the paper: https://arxiv.org/abs/2501.03262 + + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + eos_mask: `(torch.Tensor)` + shape: (bs, response_length) + advantages: `(torch.Tensor)` + shape: (bs, response_length) + returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + token_level_rewards = exps.batch["token_level_rewards"] + eos_mask = exps.batch["response_mask"] + gamma = self.gamma + + with torch.no_grad(): + returns = torch.zeros_like(token_level_rewards) + running_return = 0 + + for t in reversed(range(token_level_rewards.shape[1])): + running_return = token_level_rewards[:, t] + gamma * running_return + returns[:, t] = running_return + + advantages = masked_whiten(returns, eos_mask) + advantages = advantages * eos_mask + exps.batch["advantages"] = advantages exps.batch["returns"] = returns diff --git a/trinity/algorithm/advantage_fn/remax_advantage.py b/trinity/algorithm/advantage_fn/remax_advantage.py index 516213c0c2..07f92d91a0 100644 --- a/trinity/algorithm/advantage_fn/remax_advantage.py +++ b/trinity/algorithm/advantage_fn/remax_advantage.py @@ -1,14 +1,14 @@ """REMAX advantage computation -Adapted from compute_advantage_ppo in original ray_trainer.py +Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py """ from typing import Dict, Tuple +import torch from verl import DataProto from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn -from trinity.trainer.verl import core_algos @ADVANTAGE_FN.register_module("remax") @@ -21,11 +21,37 @@ def __call__( exps: DataProto, **kwargs, ) -> Tuple[DataProto, Dict]: - advantages, returns = core_algos.compute_remax_outcome_advantage( - token_level_rewards=exps.batch["token_level_rewards"], - reward_baselines=exps.batch["reward_baselines"], - eos_mask=exps.batch["response_mask"], - ) + """ + Compute advantage for ReMax, operating only on Outcome reward + (with only one scalar reward for each response). 
+        This implementation is based on the paper: https://arxiv.org/abs/2310.10505
+
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        reward_baselines: `(torch.Tensor)`
+            shape: (bs,)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+        advantages: `(torch.Tensor)`
+            shape: (bs, response_length)
+        returns: `(torch.Tensor)`
+            shape: (bs, response_length)
+        """
+        token_level_rewards = exps.batch["token_level_rewards"]
+        reward_baselines = exps.batch["reward_baselines"]
+        eos_mask = exps.batch["response_mask"]
+
+        response_length = token_level_rewards.shape[-1]
+
+        with torch.no_grad():
+            returns = (
+                (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+            )
+            advantages = (
+                returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
+            )
+
         exps.batch["advantages"] = advantages
         exps.batch["returns"] = returns
diff --git a/trinity/algorithm/advantage_fn/rloo_advantage.py b/trinity/algorithm/advantage_fn/rloo_advantage.py
index c88276e836..fb2680a68b 100644
--- a/trinity/algorithm/advantage_fn/rloo_advantage.py
+++ b/trinity/algorithm/advantage_fn/rloo_advantage.py
@@ -1,14 +1,15 @@
 """RLOO advantage computation
 
-Adapted from compute_advantage_ppo in original ray_trainer.py
+Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
 """
 
+from collections import defaultdict
 from typing import Dict, Tuple
 
+import torch
 from verl import DataProto
 
 from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.trainer.verl import core_algos
 
 
 @ADVANTAGE_FN.register_module("rloo")
@@ -21,13 +22,47 @@ def __call__(
         exps: DataProto,
         **kwargs,
     ) -> Tuple[DataProto, Dict]:
-        advantages, returns = core_algos.compute_rloo_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            index=exps.non_tensor_batch["uid"],
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
+        """
+        Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740
+
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+        scores: `(torch.Tensor)`
+            shape: (bs, response_length)
+        """
+        token_level_rewards = exps.batch["token_level_rewards"]
+        eos_mask = exps.batch["response_mask"]
+        index = exps.non_tensor_batch["uid"]
+
+        response_length = token_level_rewards.shape[-1]
+        scores = token_level_rewards.sum(dim=-1)
+
+        id2score = defaultdict(list)
+        id2mean = {}
+
+        with torch.no_grad():
+            bsz = scores.shape[0]
+            for i in range(bsz):
+                id2score[index[i]].append(scores[i])
+            for idx in id2score:
+                if len(id2score[idx]) == 1:
+                    id2mean[idx] = torch.tensor(0.0)
+                elif len(id2score[idx]) > 1:
+                    id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
+                else:
+                    raise ValueError(f"no score in prompt index: {idx}")
+            for i in range(bsz):
+                response_num = len(id2score[index[i]])
+                if response_num > 1:
+                    scores[i] = (scores[i] - id2mean[index[i]]) * (
+                        response_num / (response_num - 1)
+                    )
+            scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
+
+        exps.batch["advantages"] = scores
+        exps.batch["returns"] = scores
 
         metrics = {
             # TODO: add meaningful metrics
diff --git a/trinity/algorithm/kl_fn/kl_fn.py b/trinity/algorithm/kl_fn/kl_fn.py
index 3901ea7f3c..95d2915a84 100644
--- a/trinity/algorithm/kl_fn/kl_fn.py
+++ b/trinity/algorithm/kl_fn/kl_fn.py
@@ -1,3 +1,11 @@
+"""KL penalty and loss.
+ +Ref: +https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py +https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py +https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/utils.py +""" + from abc import ABC, abstractmethod from typing import Any, Dict, Optional, Tuple @@ -11,7 +19,7 @@ class KLFn(ABC): """ - KL controller. + KL penalty and loss. """ def __init__( diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py index e9457c55d1..042d26b341 100644 --- a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py @@ -1,7 +1,4 @@ -"""PPO policy loss function. - -Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py -""" +"""OPMD policy loss function.""" from typing import Dict, List, Tuple diff --git a/trinity/algorithm/utils.py b/trinity/algorithm/utils.py index 01356cc066..8660a6376c 100644 --- a/trinity/algorithm/utils.py +++ b/trinity/algorithm/utils.py @@ -3,6 +3,8 @@ Modified from https://github.com/volcengine/verl/blob/main/verl/utils/torch_functional.py """ +import torch + def masked_sum(values, mask, axis=None): """Compute mean of tensor with a masked values.""" @@ -14,6 +16,44 @@ def masked_mean(values, mask, axis=None): return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8) +def masked_var(values, mask, unbiased=True): + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError("At least one element in the mask has to be 1.") + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + if mask_sum == 1: + raise ValueError("The sum of the mask is one, which can cause a division by zero.") + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values, mask, shift_mean=True): + """ + Whiten `values` by normalizing with mean and variance computed over `mask`. + + Args: + values (torch.Tensor): Input tensor. + mask (torch.Tensor): Boolean tensor of same shape, selects elements for stats. + shift_mean (bool): If True (default), output is zero-mean; + if False, the original mean is re-added after scaling. + + Returns: + torch.Tensor: Whitened tensor of same shape as `values`. + """ + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + def prefix_metrics(src_metrics: dict, prefix: str, dst_metrics: dict = None) -> dict: if dst_metrics is None: dst_metrics = {}
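
Reviewer note (illustrative sketch, not part of the patch): the group-wise normalization that GRPOAdvantageFn now inlines can be sanity-checked in isolation with plain torch. The snippet below assumes the shapes documented in the docstrings above (token_level_rewards and response_mask of shape (bs, response_length)) and a per-sample group-id list standing in for exps.non_tensor_batch["uid"]; the function name and the toy inputs are hypothetical, while the singleton fallback (mean 0, std 1) mirrors the GRPO implementation.

from collections import defaultdict

import torch


def grpo_advantage_sketch(token_level_rewards, response_mask, index, epsilon=1e-6):
    # One scalar outcome reward per response.
    response_length = token_level_rewards.shape[-1]
    scores = token_level_rewards.sum(dim=-1)

    # Group responses that share the same prompt id ("uid").
    id2score = defaultdict(list)
    for i in range(scores.shape[0]):
        id2score[index[i]].append(scores[i])

    # Per-group mean/std; singleton groups fall back to mean 0 and std 1,
    # as in the GRPO implementation above.
    id2mean = {
        k: (torch.stack(v).mean() if len(v) > 1 else torch.tensor(0.0))
        for k, v in id2score.items()
    }
    id2std = {
        k: (torch.stack(v).std() if len(v) > 1 else torch.tensor(1.0))
        for k, v in id2score.items()
    }

    # Normalize each scalar score within its group.
    for i in range(scores.shape[0]):
        scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)

    # Broadcast the scalar advantage over response tokens and zero out padding.
    return scores.unsqueeze(-1).tile([1, response_length]) * response_mask


if __name__ == "__main__":
    rewards = torch.tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    mask = torch.ones_like(rewards)
    print(grpo_advantage_sketch(rewards, mask, index=["a", "a", "b"]))
    # Rows 0/1 (group "a") come out near +0.71 / -0.71; row 2 keeps its raw score
    # because singleton groups are not centered.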