Commit 2d8f0c1

Refactor advantage computation (cont.) (#68)
1 parent 732d801 commit 2d8f0c1

9 files changed: +323, -67 lines
trinity/algorithm/advantage_fn/grpo_advantage.py

Lines changed: 53 additions & 12 deletions

@@ -1,35 +1,74 @@
 """GRPO advantage computation

-Adapted from compute_advantage_ppo in original ray_trainer.py
+Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
 """

+from collections import defaultdict
 from typing import Dict, Tuple

+import torch
 from verl import DataProto

 from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.trainer.verl import core_algos


 @ADVANTAGE_FN.register_module("grpo")
 class GRPOAdvantageFn(AdvantageFn):
     """GRPO advantage computation"""

-    def __init__(self) -> None:
-        pass
+    def __init__(
+        self,
+        epsilon: float = 1e-6,
+    ) -> None:
+        self.epsilon = epsilon

     def __call__(
         self,
         exps: DataProto,
         **kwargs,
     ) -> Tuple[DataProto, Dict]:
-        advantages, returns = core_algos.compute_grpo_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            index=exps.non_tensor_batch["uid"],
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
+        """
+        Compute advantage for GRPO, operating only on Outcome reward
+        (with only one scalar reward for each response).
+
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+        scores: `(torch.Tensor)`
+            shape: (bs, response_length)
+        """
+        token_level_rewards = exps.batch["token_level_rewards"]
+        eos_mask = exps.batch["response_mask"]
+        index = exps.non_tensor_batch["uid"]
+        epsilon = self.epsilon
+
+        response_length = token_level_rewards.shape[-1]
+        scores = token_level_rewards.sum(dim=-1)
+
+        id2score = defaultdict(list)
+        id2mean = {}
+        id2std = {}
+
+        with torch.no_grad():
+            bsz = scores.shape[0]
+            for i in range(bsz):
+                id2score[index[i]].append(scores[i])
+            for idx in id2score:
+                if len(id2score[idx]) == 1:
+                    id2mean[idx] = torch.tensor(0.0)
+                    id2std[idx] = torch.tensor(1.0)
+                elif len(id2score[idx]) > 1:
+                    id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
+                    id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
+                else:
+                    raise ValueError(f"no score in prompt index: {idx}")
+            for i in range(bsz):
+                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
+            scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
+
+        exps.batch["advantages"] = scores
+        exps.batch["returns"] = scores

         metrics = {
             # TODO: add meaningful metrics
@@ -39,4 +78,6 @@ def __call__(

     @classmethod
     def default_args(cls) -> Dict:
-        return {}
+        return {
+            "epsilon": 1e-6,
+        }
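
Note: the group normalization that GRPOAdvantageFn now inlines can be exercised in isolation. Below is a minimal sketch on dummy tensors (the shapes, reward values, and the single-prompt grouping are made up for illustration; the actual __call__ reads these fields from a verl DataProto and groups responses by uid):

import torch

# Dummy batch: three responses to the same prompt, outcome reward on one token each.
token_level_rewards = torch.tensor([[0.0, 0.0, 1.0],
                                    [0.0, 0.0, 0.0],
                                    [0.0, 1.0, 0.0]])
response_mask = torch.tensor([[1.0, 1.0, 1.0],
                              [1.0, 1.0, 0.0],
                              [1.0, 1.0, 1.0]])

scores = token_level_rewards.sum(dim=-1)                # (bs,) one scalar score per response
mean, std = scores.mean(), scores.std()                 # group statistics (single uid group here)
advantages = (scores - mean) / (std + 1e-6)             # z-score within the group, epsilon for stability
advantages = advantages.unsqueeze(-1) * response_mask   # broadcast to every response token
print(advantages)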
trinity/algorithm/advantage_fn/opmd_advantage.py

Lines changed: 67 additions & 18 deletions

@@ -1,38 +1,84 @@
-"""OPMD advantage computation
-
-Adapted from compute_advantage_opmd in original ray_trainer.py
-"""
+"""OPMD advantage computation"""

+from collections import defaultdict
 from typing import Dict, Tuple

+import torch
 from verl import DataProto

 from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.trainer.verl import core_algos


 @ADVANTAGE_FN.register_module("opmd")
 class OPMDAdvantageFn(AdvantageFn):
     """OPMD advantage computation"""

-    def __init__(self) -> None:
-        pass
+    def __init__(
+        self,
+        opmd_baseline: str = "mean",
+        tau: float = 1.0,
+    ) -> None:
+        self.opmd_baseline = opmd_baseline
+        self.tau = tau

     def __call__(
         self,
         exps: DataProto,
         **kwargs,
     ) -> Tuple[DataProto, Dict]:
-        advantages, returns = core_algos.compute_opmd_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            # TODO (yanxi): check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
-            index=exps.non_tensor_batch["uid"],
-            opmd_baseline="mean",
-            tau=1.0,
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
+        """Modified from compute_grpo_outcome_advantage
+
+        Compute advantage for OPMD, operating only on Outcome reward
+        (with only one scalar reward for each response).
+
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+        scores: `(torch.Tensor)`
+            shape: (bs, response_length)
+        """
+        token_level_rewards = exps.batch["token_level_rewards"]
+        eos_mask = exps.batch["response_mask"]
+        # TODO (yanxi): confirm consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
+        index = exps.non_tensor_batch["uid"]
+        opmd_baseline = self.opmd_baseline
+        tau = self.tau
+
+        response_length = token_level_rewards.shape[-1]
+        scores = token_level_rewards.sum(dim=-1)
+
+        id2score = defaultdict(list)
+        id2baseline = {}
+
+        with torch.no_grad():
+            bsz = scores.shape[0]
+            for i in range(bsz):
+                id2score[index[i]].append(scores[i])
+            for idx in id2score:
+                if len(id2score[idx]) == 1:
+                    id2baseline[idx] = torch.tensor(0.0)
+                    # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?)
+                elif len(id2score[idx]) > 1:
+                    if opmd_baseline == "mean":
+                        id2baseline[idx] = torch.mean(torch.tensor(id2score[idx]))
+                    elif opmd_baseline == "logavgexp":
+                        rewards_tensor = torch.tensor(id2score[idx])
+                        # here we use the fact that logavgexp(x) = logsumexp(x) - log(len(x))
+                        id2baseline[idx] = tau * (
+                            torch.logsumexp(rewards_tensor / tau, dim=-1)
+                            - torch.log(torch.tensor(len(id2score[idx])))
+                        )
+                    else:
+                        raise NotImplementedError
+                else:
+                    raise ValueError(f"no score in prompt index: {idx}")
+            for i in range(bsz):
+                scores[i] = scores[i] - id2baseline[index[i]]
+            scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
+
+        exps.batch["advantages"] = scores
+        exps.batch["returns"] = scores

         metrics = {
             # TODO: add meaningful metrics
@@ -42,4 +88,7 @@ def __call__(

     @classmethod
     def default_args(cls) -> Dict:
-        return {}
+        return {
+            "opmd_baseline": "mean",
+            "tau": 1.0,
+        }
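
Note: the two opmd_baseline choices ("mean" and "logavgexp") differ only in how the per-prompt baseline is formed. A small sketch with made-up scores and tau, assuming a single group of responses:

import torch

scores = torch.tensor([1.0, 0.0, 0.5])  # outcome rewards of one prompt's responses (dummy values)
tau = 1.0

mean_baseline = scores.mean()
# logavgexp(scores / tau) scaled by tau, using logavgexp(x) = logsumexp(x) - log(len(x))
logavgexp_baseline = tau * (
    torch.logsumexp(scores / tau, dim=-1) - torch.log(torch.tensor(float(len(scores))))
)

print(scores - mean_baseline)       # mean-centered advantages
print(scores - logavgexp_baseline)  # logavgexp >= mean, so these advantages are shifted down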

trinity/algorithm/advantage_fn/ppo_advantage.py

Lines changed: 45 additions & 9 deletions

@@ -1,14 +1,15 @@
 """PPO's GAE advantage computation

-Adapted from compute_advantage_ppo in original ray_trainer.py
+Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
 """

 from typing import Dict, Tuple

+import torch
 from verl import DataProto

 from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.trainer.verl import core_algos
+from trinity.algorithm.utils import masked_whiten


 @ADVANTAGE_FN.register_module("ppo")
@@ -26,13 +27,48 @@ def __call__(
         exps: DataProto,
         **kwargs,
     ) -> Tuple[DataProto, Dict]:
-        advantages, returns = core_algos.compute_gae_advantage_return(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            values=exps.batch["values"],
-            eos_mask=exps.batch["response_mask"],
-            gamma=self.gamma,
-            lam=self.lam,
-        )
+        """
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        values: `(torch.Tensor)`
+            shape: (bs, response_length)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero.
+        gamma: `(float)`
+            discounted factor used in RL
+        lam: `(float)`
+            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
+        advantages: `(torch.Tensor)`
+            shape: (bs, response_length)
+        returns: `(torch.Tensor)`
+            shape: (bs, response_length)
+        """
+        token_level_rewards = exps.batch["token_level_rewards"]
+        values = exps.batch["values"]
+        eos_mask = exps.batch["response_mask"]
+        gamma = self.gamma
+        lam = self.lam
+
+        with torch.no_grad():
+            lastgaelam = 0
+            advantages_reversed = []
+            gen_len = token_level_rewards.shape[-1]
+
+            # values = values * eos_mask TODO: may use in multi-turn
+            for t in reversed(range(gen_len)):
+                nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
+                delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
+
+                lastgaelam = delta + gamma * lam * lastgaelam
+                # lastgaelam = torch.where( # TODO: may use in multi-turn
+                #     eos_mask[:, t] == 1, delta + gamma * lam * lastgaelam, lastgaelam
+                # )
+                advantages_reversed.append(lastgaelam)
+            advantages = torch.stack(advantages_reversed[::-1], dim=1)
+
+            returns = advantages + values
+            advantages = masked_whiten(advantages, eos_mask)
+
         exps.batch["advantages"] = advantages
         exps.batch["returns"] = returns
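
Note: the GAE recursion above can be reproduced standalone. The sketch below uses dummy tensors and an inline whitening step as a simple stand-in for trinity.algorithm.utils.masked_whiten; gamma, lam, and all values are illustrative:

import torch

def gae(rewards: torch.Tensor, values: torch.Tensor, mask: torch.Tensor,
        gamma: float = 1.0, lam: float = 1.0):
    lastgaelam, adv_reversed = 0.0, []
    gen_len = rewards.shape[-1]
    for t in reversed(range(gen_len)):
        nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]  # TD residual
        lastgaelam = delta + gamma * lam * lastgaelam              # GAE accumulator
        adv_reversed.append(lastgaelam)
    advantages = torch.stack(adv_reversed[::-1], dim=1)
    returns = advantages + values
    # whiten over masked positions (stand-in for masked_whiten)
    mean = (advantages * mask).sum() / mask.sum()
    var = ((advantages - mean) ** 2 * mask).sum() / mask.sum()
    return (advantages - mean) * torch.rsqrt(var + 1e-8), returns

rewards = torch.tensor([[0.0, 0.0, 1.0]])
values = torch.tensor([[0.2, 0.5, 0.9]])
mask = torch.ones_like(rewards)
adv, ret = gae(rewards, values, mask)
print(adv, ret)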

trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py

Lines changed: 31 additions & 7 deletions

@@ -1,14 +1,15 @@
 """REINFORCE++ advantage computation

-Adapted from compute_advantage_ppo in original ray_trainer.py
+Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
 """

 from typing import Dict, Tuple

+import torch
 from verl import DataProto

 from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.trainer.verl import core_algos
+from trinity.algorithm.utils import masked_whiten


 @ADVANTAGE_FN.register_module("reinforceplusplus")
@@ -21,11 +22,34 @@ def __call__(
         exps: DataProto,
         **kwargs,
     ) -> Tuple[DataProto, Dict]:
-        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            gamma=self.gamma,
-        )
+        """
+        Compute advantage for REINFORCE++.
+        This implementation is based on the paper: https://arxiv.org/abs/2501.03262
+
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+        advantages: `(torch.Tensor)`
+            shape: (bs, response_length)
+        returns: `(torch.Tensor)`
+            shape: (bs, response_length)
+        """
+        token_level_rewards = exps.batch["token_level_rewards"]
+        eos_mask = exps.batch["response_mask"]
+        gamma = self.gamma
+
+        with torch.no_grad():
+            returns = torch.zeros_like(token_level_rewards)
+            running_return = 0
+
+            for t in reversed(range(token_level_rewards.shape[1])):
+                running_return = token_level_rewards[:, t] + gamma * running_return
+                returns[:, t] = running_return
+
+            advantages = masked_whiten(returns, eos_mask)
+            advantages = advantages * eos_mask
+
         exps.batch["advantages"] = advantages
         exps.batch["returns"] = returns
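
Note: the discounted-return recursion above, run on a dummy reward sequence (gamma chosen only for illustration):

import torch

gamma = 0.9
token_level_rewards = torch.tensor([[0.0, 0.0, 1.0]])  # (bs=1, response_length=3)

returns = torch.zeros_like(token_level_rewards)
running_return = 0.0
for t in reversed(range(token_level_rewards.shape[1])):
    running_return = token_level_rewards[:, t] + gamma * running_return
    returns[:, t] = running_return

print(returns)  # tensor([[0.8100, 0.9000, 1.0000]]): reward discounted back to earlier tokens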

trinity/algorithm/advantage_fn/remax_advantage.py

Lines changed: 33 additions & 7 deletions

@@ -1,14 +1,14 @@
 """REMAX advantage computation

-Adapted from compute_advantage_ppo in original ray_trainer.py
+Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
 """

 from typing import Dict, Tuple

+import torch
 from verl import DataProto

 from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.trainer.verl import core_algos


 @ADVANTAGE_FN.register_module("remax")
@@ -21,11 +21,37 @@ def __call__(
         exps: DataProto,
         **kwargs,
     ) -> Tuple[DataProto, Dict]:
-        advantages, returns = core_algos.compute_remax_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            reward_baselines=exps.batch["reward_baselines"],
-            eos_mask=exps.batch["response_mask"],
-        )
+        """
+        Compute advantage for ReMax, operating only on Outcome reward
+        (with only one scalar reward for each response).
+        This implementation is based on the paper: https://arxiv.org/abs/2310.10505
+
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        reward_baselines: `(torch.Tensor)`
+            shape: (bs,)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+        advantages: `(torch.Tensor)`
+            shape: (bs, response_length)
+        returns: `(torch.Tensor)`
+            shape: (bs, response_length)
+        """
+        token_level_rewards = exps.batch["token_level_rewards"]
+        reward_baselines = exps.batch["reward_baselines"]
+        eos_mask = exps.batch["response_mask"]
+
+        response_length = token_level_rewards.shape[-1]
+        token_level_rewards.sum(dim=-1)
+
+        with torch.no_grad():
+            returns = (
+                (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+            )
+            advantages = (
+                returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
+            )
+
         exps.batch["advantages"] = advantages
         exps.batch["returns"] = returns
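
Note: an illustrative sketch of the ReMax computation above, i.e. reward-to-go via flip/cumsum/flip minus a per-sample greedy baseline (all values made up):

import torch

token_level_rewards = torch.tensor([[0.0, 0.0, 1.0]])   # (bs=1, response_length=3)
response_mask = torch.ones_like(token_level_rewards)
reward_baselines = torch.tensor([0.4])                   # reward of the greedy rollout (dummy value)

# reward-to-go: cumulative future reward at each token position
returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
advantages = returns - reward_baselines.unsqueeze(-1) * response_mask

print(returns)     # tensor([[1., 1., 1.]])
print(advantages)  # tensor([[0.6000, 0.6000, 0.6000]])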
