65 changes: 53 additions & 12 deletions trinity/algorithm/advantage_fn/grpo_advantage.py
@@ -1,35 +1,74 @@
"""GRPO advantage computation

Adapted from compute_advantage_ppo in original ray_trainer.py
Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""

from collections import defaultdict
from typing import Dict, Tuple

import torch
from verl import DataProto

from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.trainer.verl import core_algos


@ADVANTAGE_FN.register_module("grpo")
class GRPOAdvantageFn(AdvantageFn):
"""GRPO advantage computation"""

def __init__(self) -> None:
pass
def __init__(
self,
epsilon: float = 1e-6,
) -> None:
self.epsilon = epsilon

def __call__(
self,
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
advantages, returns = core_algos.compute_grpo_outcome_advantage(
token_level_rewards=exps.batch["token_level_rewards"],
eos_mask=exps.batch["response_mask"],
index=exps.non_tensor_batch["uid"],
)
exps.batch["advantages"] = advantages
exps.batch["returns"] = returns
"""
Compute advantage for GRPO, operating only on Outcome reward
(with only one scalar reward for each response).

token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
scores: `(torch.Tensor)`
shape: (bs, response_length)
"""
token_level_rewards = exps.batch["token_level_rewards"]
eos_mask = exps.batch["response_mask"]
index = exps.non_tensor_batch["uid"]
epsilon = self.epsilon

response_length = token_level_rewards.shape[-1]
scores = token_level_rewards.sum(dim=-1)

id2score = defaultdict(list)
id2mean = {}
id2std = {}

with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask

exps.batch["advantages"] = scores
exps.batch["returns"] = scores

metrics = {
# TODO: add meaningful metrics
@@ -39,4 +78,6 @@ def __call__(

@classmethod
def default_args(cls) -> Dict:
return {}
return {
"epsilon": 1e-6,
}
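The loop above whitens each rollout's outcome reward against the other rollouts that share the same prompt uid. As a quick illustration (not part of the diff), below is a minimal standalone sketch of that group-wise normalization; the helper name grpo_advantages and the toy batch are made up for this example.

from collections import defaultdict

import torch

def grpo_advantages(scores: torch.Tensor, index: list, epsilon: float = 1e-6) -> torch.Tensor:
    # Group outcome rewards by prompt uid, then whiten each score within its group.
    id2score = defaultdict(list)
    for i, uid in enumerate(index):
        id2score[uid].append(scores[i])
    out = scores.clone()
    for i, uid in enumerate(index):
        group = torch.stack(id2score[uid])
        if group.numel() == 1:
            mean, std = torch.tensor(0.0), torch.tensor(1.0)  # singleton group: pass through
        else:
            mean, std = group.mean(), group.std()
        out[i] = (scores[i] - mean) / (std + epsilon)
    return out

# Toy batch: three rollouts for prompt "a", one rollout for prompt "b".
scores = torch.tensor([1.0, 0.0, 2.0, 5.0])
index = ["a", "a", "a", "b"]
print(grpo_advantages(scores, index))  # "a" rollouts whitened within the group; "b" passes through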
85 changes: 67 additions & 18 deletions trinity/algorithm/advantage_fn/opmd_advantage.py
@@ -1,38 +1,84 @@
"""OPMD advantage computation

Adapted from compute_advantage_opmd in original ray_trainer.py
"""
"""OPMD advantage computation"""

from collections import defaultdict
from typing import Dict, Tuple

import torch
from verl import DataProto

from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.trainer.verl import core_algos


@ADVANTAGE_FN.register_module("opmd")
class OPMDAdvantageFn(AdvantageFn):
"""OPMD advantage computation"""

def __init__(self) -> None:
pass
def __init__(
self,
opmd_baseline: str = "mean",
tau: float = 1.0,
) -> None:
self.opmd_baseline = opmd_baseline
self.tau = tau

def __call__(
self,
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
advantages, returns = core_algos.compute_opmd_outcome_advantage(
token_level_rewards=exps.batch["token_level_rewards"],
eos_mask=exps.batch["response_mask"],
# TODO (yanxi): check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
index=exps.non_tensor_batch["uid"],
opmd_baseline="mean",
tau=1.0,
)
exps.batch["advantages"] = advantages
exps.batch["returns"] = returns
"""Modified from compute_grpo_outcome_advantage

Compute advantage for OPMD, operating only on Outcome reward
(with only one scalar reward for each response).

token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
scores: `(torch.Tensor)`
shape: (bs, response_length)
"""
token_level_rewards = exps.batch["token_level_rewards"]
eos_mask = exps.batch["response_mask"]
# TODO (yanxi): confirm consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
index = exps.non_tensor_batch["uid"]
opmd_baseline = self.opmd_baseline
tau = self.tau

response_length = token_level_rewards.shape[-1]
scores = token_level_rewards.sum(dim=-1)

id2score = defaultdict(list)
id2baseline = {}

with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2baseline[idx] = torch.tensor(0.0)
# TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?)
elif len(id2score[idx]) > 1:
if opmd_baseline == "mean":
id2baseline[idx] = torch.mean(torch.tensor(id2score[idx]))
elif opmd_baseline == "logavgexp":
rewards_tensor = torch.tensor(id2score[idx])
# here we use the fact that logavgexp(x) = logsumexp(x) - log(len(x))
id2baseline[idx] = tau * (
torch.logsumexp(rewards_tensor / tau, dim=-1)
- torch.log(torch.tensor(len(id2score[idx])))
)
else:
raise NotImplementedError
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
scores[i] = scores[i] - id2baseline[index[i]]
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask

exps.batch["advantages"] = scores
exps.batch["returns"] = scores

metrics = {
# TODO: add meaningful metrics
@@ -42,4 +88,7 @@ def __call__(

@classmethod
def default_args(cls) -> Dict:
return {}
return {
"opmd_baseline": "mean",
"tau": 1.0,
}
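The "logavgexp" branch above relies on the identity logavgexp(x) = logsumexp(x) - log(len(x)). A small numeric check of that baseline follows; the helper name logavgexp_baseline is hypothetical and only mirrors the expression in the diff.

import torch

def logavgexp_baseline(rewards: torch.Tensor, tau: float) -> torch.Tensor:
    # tau * logavgexp(r / tau) == tau * (logsumexp(r / tau) - log(n))
    n = rewards.numel()
    return tau * (torch.logsumexp(rewards / tau, dim=-1) - torch.log(torch.tensor(float(n))))

rewards = torch.tensor([0.0, 1.0, 2.0])
for tau in (0.1, 1.0, 100.0):
    print(tau, logavgexp_baseline(rewards, tau).item())
# Small tau pushes the baseline toward max(rewards); large tau pushes it toward mean(rewards),
# which is the knob the opmd_baseline / tau arguments expose.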
54 changes: 45 additions & 9 deletions trinity/algorithm/advantage_fn/ppo_advantage.py
@@ -1,14 +1,15 @@
"""PPO's GAE advantage computation

Adapted from compute_advantage_ppo in original ray_trainer.py
Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""

from typing import Dict, Tuple

import torch
from verl import DataProto

from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.trainer.verl import core_algos
from trinity.algorithm.utils import masked_whiten


@ADVANTAGE_FN.register_module("ppo")
@@ -26,13 +27,48 @@ def __call__(
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
advantages, returns = core_algos.compute_gae_advantage_return(
token_level_rewards=exps.batch["token_level_rewards"],
values=exps.batch["values"],
eos_mask=exps.batch["response_mask"],
gamma=self.gamma,
lam=self.lam,
)
"""
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
values: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length). [EOS] mask; tokens after [EOS] have mask zero.
gamma: `(float)`
discount factor used in RL
lam: `(float)`
lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
advantages: `(torch.Tensor)`
shape: (bs, response_length)
returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
token_level_rewards = exps.batch["token_level_rewards"]
values = exps.batch["values"]
eos_mask = exps.batch["response_mask"]
gamma = self.gamma
lam = self.lam

with torch.no_grad():
lastgaelam = 0
advantages_reversed = []
gen_len = token_level_rewards.shape[-1]

# values = values * eos_mask TODO: may use in multi-turn
for t in reversed(range(gen_len)):
nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]

lastgaelam = delta + gamma * lam * lastgaelam
# lastgaelam = torch.where( # TODO: may use in multi-turn
# eos_mask[:, t] == 1, delta + gamma * lam * lastgaelam, lastgaelam
# )
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)

returns = advantages + values
advantages = masked_whiten(advantages, eos_mask)

exps.batch["advantages"] = advantages
exps.batch["returns"] = returns

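For reference, the reversed loop above is the standard GAE recursion delta_t = r_t + gamma * V_{t+1} - V_t, A_t = delta_t + gamma * lam * A_{t+1}. A minimal sketch (not part of the PR) with a toy sequence; the function name gae_sketch is made up, and masking/whitening are omitted here since the PR applies masked_whiten afterwards.

import torch

def gae_sketch(rewards: torch.Tensor, values: torch.Tensor,
               gamma: float = 1.0, lam: float = 0.95):
    # Plain GAE recursion over the response dimension, without masking or whitening.
    T = rewards.shape[-1]
    lastgaelam = torch.zeros(rewards.shape[0])
    advs = []
    for t in reversed(range(T)):
        nextvalues = values[:, t + 1] if t < T - 1 else torch.zeros_like(values[:, t])
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advs.append(lastgaelam)
    advantages = torch.stack(advs[::-1], dim=1)
    returns = advantages + values
    return advantages, returns  # the PR then whitens advantages with masked_whiten(advantages, eos_mask)

rewards = torch.tensor([[0.0, 0.0, 1.0]])   # outcome reward on the last response token
values = torch.tensor([[0.4, 0.6, 0.9]])
print(gae_sketch(rewards, values, gamma=1.0, lam=1.0))
# With gamma = lam = 1, each advantage reduces to (sum of future rewards) - V_t: [[0.6, 0.4, 0.1]].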
38 changes: 31 additions & 7 deletions trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
@@ -1,14 +1,15 @@
"""REINFORCE++ advantage computation

Adapted from compute_advantage_ppo in original ray_trainer.py
Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""

from typing import Dict, Tuple

import torch
from verl import DataProto

from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.trainer.verl import core_algos
from trinity.algorithm.utils import masked_whiten


@ADVANTAGE_FN.register_module("reinforceplusplus")
@@ -21,11 +22,34 @@ def __call__(
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards=exps.batch["token_level_rewards"],
eos_mask=exps.batch["response_mask"],
gamma=self.gamma,
)
"""
Compute advantage for REINFORCE++.
This implementation is based on the paper: https://arxiv.org/abs/2501.03262

token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
advantages: `(torch.Tensor)`
shape: (bs, response_length)
returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
token_level_rewards = exps.batch["token_level_rewards"]
eos_mask = exps.batch["response_mask"]
gamma = self.gamma

with torch.no_grad():
returns = torch.zeros_like(token_level_rewards)
running_return = 0

for t in reversed(range(token_level_rewards.shape[1])):
running_return = token_level_rewards[:, t] + gamma * running_return
returns[:, t] = running_return

advantages = masked_whiten(returns, eos_mask)
advantages = advantages * eos_mask

exps.batch["advantages"] = advantages
exps.batch["returns"] = returns

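The reversed loop above accumulates discounted returns R_t = r_t + gamma * R_{t+1}; the PR then whitens these returns with masked_whiten to obtain advantages. A short standalone sketch of just the accumulation step (the helper name discounted_returns is hypothetical):

import torch

def discounted_returns(token_level_rewards: torch.Tensor, gamma: float) -> torch.Tensor:
    # Right-to-left accumulation: R_t = r_t + gamma * R_{t+1}.
    returns = torch.zeros_like(token_level_rewards)
    running = torch.zeros(token_level_rewards.shape[0])
    for t in reversed(range(token_level_rewards.shape[1])):
        running = token_level_rewards[:, t] + gamma * running
        returns[:, t] = running
    return returns

rewards = torch.tensor([[0.0, 0.0, 1.0]])
print(discounted_returns(rewards, gamma=0.9))  # tensor([[0.81, 0.90, 1.00]])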
40 changes: 33 additions & 7 deletions trinity/algorithm/advantage_fn/remax_advantage.py
@@ -1,14 +1,14 @@
"""REMAX advantage computation

Adapted from compute_advantage_ppo in original ray_trainer.py
Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""

from typing import Dict, Tuple

import torch
from verl import DataProto

from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.trainer.verl import core_algos


@ADVANTAGE_FN.register_module("remax")
@@ -21,11 +21,37 @@ def __call__(
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
advantages, returns = core_algos.compute_remax_outcome_advantage(
token_level_rewards=exps.batch["token_level_rewards"],
reward_baselines=exps.batch["reward_baselines"],
eos_mask=exps.batch["response_mask"],
)
"""
Compute advantage for ReMax, operating only on Outcome reward
(with only one scalar reward for each response).
This implementation is based on the paper: https://arxiv.org/abs/2310.10505

token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
reward_baselines: `(torch.Tensor)`
shape: (bs,)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
advantages: `(torch.Tensor)`
shape: (bs, response_length)
returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
token_level_rewards = exps.batch["token_level_rewards"]
reward_baselines = exps.batch["reward_baselines"]
eos_mask = exps.batch["response_mask"]

response_length = token_level_rewards.shape[-1]
token_level_rewards.sum(dim=-1)  # NOTE: result is unused here

with torch.no_grad():
returns = (
(token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
)
advantages = (
returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
)

exps.batch["advantages"] = advantages
exps.batch["returns"] = returns

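ReMax forms token-level returns by a reverse cumulative sum of rewards and subtracts a per-sample baseline (e.g. the reward of a greedy rollout for the same prompt). A minimal sketch reproducing the two tensor expressions above on a toy batch; the function name remax_advantages and the toy values are made up for this example.

import torch

def remax_advantages(token_level_rewards: torch.Tensor,
                     reward_baselines: torch.Tensor,
                     eos_mask: torch.Tensor):
    # Reverse cumulative sum of masked rewards, minus a broadcast per-sample baseline.
    response_length = token_level_rewards.shape[-1]
    returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
    advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
    return advantages, returns

rewards = torch.tensor([[0.0, 0.0, 2.0]])
baseline = torch.tensor([0.5])          # e.g. outcome reward of a greedy rollout for the same prompt
mask = torch.ones_like(rewards)
adv, ret = remax_advantages(rewards, baseline, mask)
print(ret)   # tensor([[2., 2., 2.]])
print(adv)   # tensor([[1.5, 1.5, 1.5]])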