Skip to content

Commit 7964099

Browse files
shiweijiezero and weijie authored
Fix group_rewards Un-assigned value (#343)
Co-authored-by: weijie <[email protected]>
1 parent 15c296f commit 7964099

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

trinity/algorithm/advantage_fn/opmd_advantage.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -119,12 +119,13 @@ def calculate_group_advantage(
119119
self, group_id: str, exps: List[Experience]
120120
) -> Tuple[List[Experience], Dict]:
121121
with torch.no_grad():
122+
group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32)
123+
reward_mean = torch.mean(group_rewards)
122124
if len(exps) == 1:
123125
group_baseline = torch.tensor(0.0)
124126
else:
125-
group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32)
126127
if self.opmd_baseline == "mean":
127-
group_baseline = torch.mean(group_rewards)
128+
group_baseline = reward_mean
128129
else:
129130
group_baseline = self.tau * (
130131
torch.logsumexp(group_rewards / self.tau, dim=-1)
@@ -136,7 +137,7 @@ def calculate_group_advantage(
136137
exp.returns = exp.advantages.clone()
137138
metrics = {
138139
"group_baseline": group_baseline.item(),
139-
"reward_mean": torch.mean(group_rewards).item(),
140+
"reward_mean": reward_mean.item(),
140141
}
141142
return exps, metrics
142143

0 commit comments

Comments (0)