
Commit 8ca2469

Merge advantage computation logic into __call__ function
1 parent e797cd5 commit 8ca2469

File tree

6 files changed, +235 -311 lines changed


trinity/algorithm/advantage_fn/grpo_advantage.py

Lines changed: 42 additions & 56 deletions
@@ -24,13 +24,48 @@ def __call__(
         exps: DataProto,
         **kwargs,
     ) -> Tuple[DataProto, Dict]:
-        advantages, returns = compute_grpo_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            index=exps.non_tensor_batch["uid"],
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
+        """
+        Compute advantage for GRPO, operating only on Outcome reward
+        (with only one scalar reward for each response).
+
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+        scores: `(torch.Tensor)`
+            shape: (bs, response_length)
+        """
+        token_level_rewards = exps.batch["token_level_rewards"]
+        eos_mask = exps.batch["response_mask"]
+        index = exps.non_tensor_batch["uid"]
+        epsilon = 1e-6
+
+        response_length = token_level_rewards.shape[-1]
+        scores = token_level_rewards.sum(dim=-1)
+
+        id2score = defaultdict(list)
+        id2mean = {}
+        id2std = {}
+
+        with torch.no_grad():
+            bsz = scores.shape[0]
+            for i in range(bsz):
+                id2score[index[i]].append(scores[i])
+            for idx in id2score:
+                if len(id2score[idx]) == 1:
+                    id2mean[idx] = torch.tensor(0.0)
+                    id2std[idx] = torch.tensor(1.0)
+                elif len(id2score[idx]) > 1:
+                    id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
+                    id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
+                else:
+                    raise ValueError(f"no score in prompt index: {idx}")
+            for i in range(bsz):
+                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
+            scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
+
+        exps.batch["advantages"] = scores
+        exps.batch["returns"] = scores

         metrics = {
             # TODO: add meaningful metrics
@@ -41,52 +76,3 @@ def __call__(
     @classmethod
     def default_args(cls) -> Dict:
         return {}
-
-
-# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
-def compute_grpo_outcome_advantage(
-    token_level_rewards: torch.Tensor,
-    eos_mask: torch.Tensor,
-    index: torch.Tensor,
-    epsilon: float = 1e-6,
-):
-    """
-    Compute advantage for GRPO, operating only on Outcome reward
-    (with only one scalar reward for each response).
-    Args:
-        token_level_rewards: `(torch.Tensor)`
-            shape: (bs, response_length)
-        eos_mask: `(torch.Tensor)`
-            shape: (bs, response_length)
-
-    Returns:
-        advantages: `(torch.Tensor)`
-            shape: (bs, response_length)
-        Returns: `(torch.Tensor)`
-            shape: (bs, response_length)
-    """
-    response_length = token_level_rewards.shape[-1]
-    scores = token_level_rewards.sum(dim=-1)
-
-    id2score = defaultdict(list)
-    id2mean = {}
-    id2std = {}
-
-    with torch.no_grad():
-        bsz = scores.shape[0]
-        for i in range(bsz):
-            id2score[index[i]].append(scores[i])
-        for idx in id2score:
-            if len(id2score[idx]) == 1:
-                id2mean[idx] = torch.tensor(0.0)
-                id2std[idx] = torch.tensor(1.0)
-            elif len(id2score[idx]) > 1:
-                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
-                id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
-            else:
-                raise ValueError(f"no score in prompt index: {idx}")
-        for i in range(bsz):
-            scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
-        scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
-
-    return scores, scores
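
Note for readers comparing the two sides of this diff: the GRPO logic is unchanged, only moved inline into __call__. Below is a minimal, self-contained sketch of the same group-normalized outcome advantage on plain tensors, without the DataProto wrapper; the function name, inputs, and example values are illustrative and not part of the repository.

from collections import defaultdict

import torch


def grpo_outcome_advantage(token_level_rewards, response_mask, uid, epsilon=1e-6):
    # Outcome score of each response = sum of its token-level rewards.
    scores = token_level_rewards.sum(dim=-1)

    # Group scores by prompt uid, as the diff does via exps.non_tensor_batch["uid"].
    groups = defaultdict(list)
    for i, g in enumerate(uid):
        groups[g].append(scores[i])

    advantages = torch.zeros_like(scores)
    for i, g in enumerate(uid):
        vals = torch.stack(groups[g])
        # Singleton groups fall back to mean 0 / std 1, matching the diff's special case.
        mean = vals.mean() if vals.numel() > 1 else torch.tensor(0.0)
        std = vals.std() if vals.numel() > 1 else torch.tensor(1.0)
        advantages[i] = (scores[i] - mean) / (std + epsilon)

    # Broadcast the per-response scalar over the response tokens and mask out padding.
    return advantages.unsqueeze(-1) * response_mask


# Illustrative usage: two prompts ("a", "b") with two sampled responses each.
rewards = torch.tensor([[0.0, 1.0], [0.0, 0.0], [0.0, 2.0], [0.0, 1.0]])
mask = torch.ones_like(rewards)
adv = grpo_outcome_advantage(rewards, mask, uid=["a", "a", "b", "b"])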

trinity/algorithm/advantage_fn/opmd_advantage.py

Lines changed: 53 additions & 70 deletions
@@ -21,16 +21,59 @@ def __call__(
         exps: DataProto,
         **kwargs,
     ) -> Tuple[DataProto, Dict]:
-        advantages, returns = compute_opmd_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            # TODO (yanxi): check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
-            index=exps.non_tensor_batch["uid"],
-            opmd_baseline="mean",
-            tau=1.0,
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
+        """Modified from compute_grpo_outcome_advantage
+
+        Compute advantage for OPMD, operating only on Outcome reward
+        (with only one scalar reward for each response).
+
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+        scores: `(torch.Tensor)`
+            shape: (bs, response_length)
+        """
+        token_level_rewards = exps.batch["token_level_rewards"]
+        eos_mask = exps.batch["response_mask"]
+        # TODO (yanxi): confirm consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
+        index = exps.non_tensor_batch["uid"]
+        opmd_baseline = "mean"
+        tau = 1.0
+
+        response_length = token_level_rewards.shape[-1]
+        scores = token_level_rewards.sum(dim=-1)
+
+        id2score = defaultdict(list)
+        id2baseline = {}
+
+        with torch.no_grad():
+            bsz = scores.shape[0]
+            for i in range(bsz):
+                id2score[index[i]].append(scores[i])
+            for idx in id2score:
+                if len(id2score[idx]) == 1:
+                    id2baseline[idx] = torch.tensor(0.0)
+                    # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?)
+                elif len(id2score[idx]) > 1:
+                    if opmd_baseline == "mean":
+                        id2baseline[idx] = torch.mean(torch.tensor(id2score[idx]))
+                    elif opmd_baseline == "logavgexp":
+                        rewards_tensor = torch.tensor(id2score[idx])
+                        # here we use the fact that logavgexp(x) = logsumexp(x) - log(len(x))
+                        id2baseline[idx] = tau * (
+                            torch.logsumexp(rewards_tensor / tau, dim=-1)
+                            - torch.log(torch.tensor(len(id2score[idx])))
+                        )
+                    else:
+                        raise NotImplementedError
+                else:
+                    raise ValueError(f"no score in prompt index: {idx}")
+            for i in range(bsz):
+                scores[i] = scores[i] - id2baseline[index[i]]
+            scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
+
+        exps.batch["advantages"] = scores
+        exps.batch["returns"] = scores

         metrics = {
             # TODO: add meaningful metrics
@@ -41,63 +84,3 @@ def __call__(
     @classmethod
     def default_args(cls) -> Dict:
         return {}
-
-
-def compute_opmd_outcome_advantage(
-    token_level_rewards: torch.Tensor,
-    eos_mask: torch.Tensor,
-    index: torch.Tensor,
-    opmd_baseline: str = "mean",
-    tau: float = 1.0,
-):
-    """Modified from compute_grpo_outcome_advantage
-
-    Compute advantage for OPMD, operating only on Outcome reward
-    (with only one scalar reward for each response).
-    Args:
-        token_level_rewards: `(torch.Tensor)`
-            shape: (bs, response_length)
-        eos_mask: `(torch.Tensor)`
-            shape: (bs, response_length)
-
-    Returns:
-        advantages: `(torch.Tensor)`
-            shape: (bs, response_length)
-        Returns: `(torch.Tensor)`
-            shape: (bs, response_length)
-    """
-    response_length = token_level_rewards.shape[-1]
-    scores = token_level_rewards.sum(dim=-1)
-
-    id2score = defaultdict(list)
-    id2baseline = {}
-
-    with torch.no_grad():
-        bsz = scores.shape[0]
-        for i in range(bsz):
-            id2score[index[i]].append(scores[i])
-        for idx in id2score:
-            if len(id2score[idx]) == 1:
-                id2baseline[idx] = torch.tensor(0.0)
-                # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?)
-            elif len(id2score[idx]) > 1:
-                if opmd_baseline == "mean":
-                    id2baseline[idx] = torch.mean(torch.tensor(id2score[idx]))
-                elif opmd_baseline == "logavgexp":
-                    rewards_tensor = torch.tensor(id2score[idx])
-                    # NOTE: we use the fact that logavgexp(x) = logsumexp(x) - log(len(x)).
-                    # Hopefully the logsumexp calculation is numerically stable (as claimed by PyTorch's doc)
-                    # in cases where tau is small...
-                    id2baseline[idx] = tau * (
-                        torch.logsumexp(rewards_tensor / tau, dim=-1)
-                        - torch.log(torch.tensor(len(id2score[idx])))
-                    )
-                else:
-                    raise NotImplementedError
-            else:
-                raise ValueError(f"no score in prompt index: {idx}")
-        for i in range(bsz):
-            scores[i] = scores[i] - id2baseline[index[i]]
-        scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
-
-    return scores, scores
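
The only piece here that differs from the GRPO file above is the baseline subtracted from each group's scores. A hedged sketch of the two baselines the merged __call__ supports follows; the helper name and example values are illustrative and not part of the diff.

import torch


def opmd_baseline_value(group_scores: torch.Tensor, opmd_baseline: str = "mean", tau: float = 1.0) -> torch.Tensor:
    if opmd_baseline == "mean":
        return group_scores.mean()
    if opmd_baseline == "logavgexp":
        # logavgexp(x) = logsumexp(x) - log(len(x)); logsumexp keeps this
        # reasonably stable even when tau is small.
        n = torch.tensor(float(group_scores.numel()))
        return tau * (torch.logsumexp(group_scores / tau, dim=-1) - torch.log(n))
    raise NotImplementedError(f"unknown opmd_baseline: {opmd_baseline}")


# As tau shrinks, the logavgexp baseline approaches max(group_scores);
# as tau grows, it approaches the mean.
scores = torch.tensor([1.0, 2.0, 4.0])
print(opmd_baseline_value(scores, "mean"))                # ~2.33
print(opmd_baseline_value(scores, "logavgexp", tau=0.1))  # close to 4.0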

trinity/algorithm/advantage_fn/ppo_advantage.py

Lines changed: 40 additions & 55 deletions
@@ -27,40 +27,7 @@ def __call__(
         exps: DataProto,
         **kwargs,
     ) -> Tuple[DataProto, Dict]:
-        advantages, returns = compute_gae_advantage_return(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            values=exps.batch["values"],
-            eos_mask=exps.batch["response_mask"],
-            gamma=self.gamma,
-            lam=self.lam,
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
-
-        metrics = {
-            # TODO: add meaningful metrics
-        }
-
-        return exps, metrics
-
-    @classmethod
-    def default_args(cls) -> Dict:
-        return {
-            "gamma": 1.0,
-            "lam": 1.0,
-        }
-
-
-def compute_gae_advantage_return(
-    token_level_rewards: torch.Tensor,
-    values: torch.Tensor,
-    eos_mask: torch.Tensor,
-    gamma: float,
-    lam: float,
-):
-    """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
-
-    Args:
+        """
         token_level_rewards: `(torch.Tensor)`
             shape: (bs, response_length)
         values: `(torch.Tensor)`
@@ -71,31 +38,49 @@ def compute_gae_advantage_return(
             discounted factor used in RL
         lam: `(float)`
             lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
-
-    Returns:
         advantages: `(torch.Tensor)`
             shape: (bs, response_length)
-        Returns: `(torch.Tensor)`
+        returns: `(torch.Tensor)`
             shape: (bs, response_length)
+        """
+        token_level_rewards = exps.batch["token_level_rewards"]
+        values = exps.batch["values"]
+        eos_mask = exps.batch["response_mask"]
+        gamma = self.gamma
+        lam = self.lam
+
+        with torch.no_grad():
+            lastgaelam = 0
+            advantages_reversed = []
+            gen_len = token_level_rewards.shape[-1]
+
+            # values = values * eos_mask TODO: may use in multi-turn
+            for t in reversed(range(gen_len)):
+                nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
+                delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
+
+                lastgaelam = delta + gamma * lam * lastgaelam
+                # lastgaelam = torch.where( # TODO: may use in multi-turn
+                #     eos_mask[:, t] == 1, delta + gamma * lam * lastgaelam, lastgaelam
+                # )
+                advantages_reversed.append(lastgaelam)
+            advantages = torch.stack(advantages_reversed[::-1], dim=1)
+
+            returns = advantages + values
+            advantages = masked_whiten(advantages, eos_mask)

-    """
-    with torch.no_grad():
-        lastgaelam = 0
-        advantages_reversed = []
-        gen_len = token_level_rewards.shape[-1]
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns

-        # values = values * eos_mask TODO: may use in multi-turn
-        for t in reversed(range(gen_len)):
-            nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
-            delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
+        metrics = {
+            # TODO: add meaningful metrics
+        }

-            lastgaelam = delta + gamma * lam * lastgaelam
-            # lastgaelam = torch.where( # TODO: may use in multi-turn
-            #     eos_mask[:, t] == 1, delta + gamma * lam * lastgaelam, lastgaelam
-            # )
-            advantages_reversed.append(lastgaelam)
-        advantages = torch.stack(advantages_reversed[::-1], dim=1)
+        return exps, metrics

-        returns = advantages + values
-        advantages = masked_whiten(advantages, eos_mask)
-        return advantages, returns
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {
+            "gamma": 1.0,
+            "lam": 1.0,
+        }
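
The inlined body is the standard GAE recursion: the TD residual delta_t = r_t + gamma * V_{t+1} - V_t is accumulated backwards with factor gamma * lam, then returns = advantages + values and the advantages are whitened. A minimal sketch on plain tensors, without the DataProto wrapper or the masked whitening step, is given below; the function name and example values are illustrative, not part of the repository.

import torch


def gae_advantage_and_returns(token_level_rewards, values, gamma=1.0, lam=1.0):
    gen_len = token_level_rewards.shape[-1]
    lastgaelam = torch.zeros(token_level_rewards.shape[0])
    advantages_reversed = []
    for t in reversed(range(gen_len)):
        # Bootstrap with the next value estimate; the last step has no successor.
        nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
        # TD residual: r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values
    return advantages, returns


# Illustrative usage: a batch of 2 responses, 3 tokens each, zero value estimates.
rewards = torch.tensor([[0.0, 0.0, 1.0], [0.0, 0.5, 0.0]])
values = torch.zeros_like(rewards)
adv, ret = gae_advantage_and_returns(rewards, values, gamma=0.99, lam=0.95)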
