
Commit e29bd29

Add batch level std calculation (agentscope-ai#311)
1 parent 894858b commit e29bd29

3 files changed: +185 −8 lines changed

tests/algorithm/advantage_fn_test.py

Lines changed: 104 additions & 0 deletions

@@ -146,6 +146,55 @@ def test_grpo_correct_bias(self):
                 places=6,
             )
 
+    def test_batch_level_std_grpo(self):
+        advantage_fn_cls = ADVANTAGE_FN.get("grpo")
+        self.assertIsNotNone(advantage_fn_cls)
+        advantage_fn = advantage_fn_cls(epsilon=1e-7, std_cal_level="batch")
+
+        rewards_task0 = [1.0, 2.0, 3.0]
+        rewards_task1 = [11.0, 12.0, 13.0]
+
+        exps = [
+            Experience(
+                eid=EID(batch=0, task=0, run=i),
+                tokens=torch.zeros(5),
+                prompt_length=2,
+                reward=rewards_task0[i],
+                action_mask=torch.tensor([0, 0, 1, 1, 1], dtype=torch.float32),
+            )
+            for i in range(len(rewards_task0))
+        ]
+        exps.extend(
+            [
+                Experience(
+                    eid=EID(batch=0, task=1, run=i),
+                    tokens=torch.zeros(5),
+                    prompt_length=2,
+                    reward=rewards_task1[i],
+                    action_mask=torch.tensor([0, 0, 1, 1, 1], dtype=torch.float32),
+                )
+                for i in range(len(rewards_task1))
+            ]
+        )
+
+        all_rewards = torch.tensor(rewards_task0 + rewards_task1, dtype=torch.float32)
+        batch_std = torch.std(all_rewards)
+
+        group0_mean = torch.mean(torch.tensor(rewards_task0, dtype=torch.float32))
+
+        processed_exps, metrics = advantage_fn(exps)
+        self.assertIn("group_advantages/reward_mean/mean", metrics)
+        self.assertIn("group_advantages/reward_std/mean", metrics)
+        self.assertEqual(len(processed_exps), len(rewards_task0) + len(rewards_task1))
+
+        target_exp = next(exp for exp in processed_exps if exp.eid.task == 0 and exp.eid.run == 1)
+        expected_advantage_value = (target_exp.reward - group0_mean) / (
+            batch_std + advantage_fn.epsilon
+        )
+        expected_advantages = expected_advantage_value * target_exp.action_mask
+        self.assertTrue(torch.allclose(target_exp.advantages, expected_advantages, atol=1e-6))
+        self.assertTrue(torch.allclose(target_exp.returns, expected_advantages, atol=1e-6))
+
     def test_duplicate_grpo(self):
         advantage_fn_cls = ADVANTAGE_FN.get("grpo")
         self.assertIsNotNone(advantage_fn_cls)
@@ -222,3 +271,58 @@ def test_step_wise_grpo_advantage(self):
             metrics["group_advantages/reward_std/mean"]
             == torch.std(torch.tensor([i for i in range(repeat_times)], dtype=torch.float32)).item()
         )
+
+    def test_batch_level_step_wise_grpo_advantage(self):
+        advantage_fn_cls = ADVANTAGE_FN.get("step_wise_grpo")
+        self.assertIsNotNone(advantage_fn_cls)
+        advantage_fn = advantage_fn_cls(epsilon=1e-7, std_cal_level="batch")
+
+        task_num = 2
+        repeat_times = 3  # runs
+        step_num = 4
+
+        # Let reward vary by task and run to make the test meaningful
+        # reward = task*10 + run*1 (the same reward for every step of a run)
+        exps = []
+        all_rewards_list = []
+        for j in range(task_num):  # task
+            for i in range(repeat_times):  # run
+                reward_val = float(j * 10 + i * 1)
+                all_rewards_list.append(reward_val)
+                for k in range(step_num):  # step
+                    exps.append(
+                        Experience(
+                            eid=EID(batch=0, task=j, run=i, step=k),
+                            tokens=torch.zeros(5),
+                            prompt_length=2,
+                            reward=reward_val,
+                            action_mask=torch.tensor([0, 0, 1, 1, 1], dtype=torch.float32),
+                        )
+                    )
+
+        all_rewards = torch.tensor(all_rewards_list, dtype=torch.float32)
+        batch_std = torch.std(all_rewards)
+
+        # For a specific group (e.g., task = 0)
+        group_rewards = [
+            float(0 * 10 + 1 * k) for k in range(repeat_times)
+        ]  # [0.0, 1.0, 2.0] for task = 0
+        group_mean = torch.mean(torch.tensor(group_rewards, dtype=torch.float32))
+
+        processed_exps, metrics = advantage_fn(exps)
+        self.assertIn("group_advantages/reward_mean/mean", metrics)
+        self.assertIn("group_advantages/reward_std/mean", metrics)
+        self.assertEqual(len(processed_exps), task_num * repeat_times * step_num)
+
+        # Pick a target experience: task=0, run=1, step=0. Its reward is 1.0
+        target_exp = next(
+            exp
+            for exp in processed_exps
+            if exp.eid.task == 0 and exp.eid.run == 1 and exp.eid.step == 0
+        )
+        expected_advantage_value = (target_exp.reward - group_mean) / (
+            batch_std + advantage_fn.epsilon
+        )
+        expected_advantages = expected_advantage_value * target_exp.action_mask
+        self.assertTrue(torch.allclose(target_exp.advantages, expected_advantages, atol=1e-6))
+        self.assertTrue(torch.allclose(target_exp.returns, expected_advantages, atol=1e-6))
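As a quick sanity check on the numbers these tests assert, the normalization can be reproduced with plain torch outside the test harness. The snippet below is an illustrative sketch, not part of the commit; it reuses only the reward values from test_batch_level_std_grpo and shows how the batch-wide std shrinks the advantage relative to the per-group std.

import torch

# Rewards from test_batch_level_std_grpo: two task groups of three runs each.
epsilon = 1e-7
rewards_task0 = torch.tensor([1.0, 2.0, 3.0])
rewards_task1 = torch.tensor([11.0, 12.0, 13.0])

batch_std = torch.std(torch.cat([rewards_task0, rewards_task1]))  # ≈ 5.5498 over all six rewards
group0_mean = torch.mean(rewards_task0)                           # 2.0
group0_std = torch.std(rewards_task0)                             # 1.0 (what std_cal_level="group" would use)

reward = 3.0  # the run with eid.run == 2 in task 0
group_level_adv = (reward - group0_mean) / (group0_std + epsilon)  # ≈ 1.0
batch_level_adv = (reward - group0_mean) / (batch_std + epsilon)   # ≈ 0.18
print(group_level_adv.item(), batch_level_adv.item())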

trinity/algorithm/advantage_fn/grpo_advantage.py

Lines changed: 36 additions & 4 deletions

@@ -101,6 +101,7 @@ def __init__(
         std_threshold: Optional[float] = None,
         duplicate_experiences: bool = False,
         rank_penalty: Optional[float] = None,
+        std_cal_level: str = "group",  # "group" or "batch"
     ) -> None:
         """Initialize the GRPO advantage function.
 
@@ -112,17 +113,28 @@ def __init__(
                 count. Only used when `std_threshold` is not None (https://hkunlp.github.io/blog/2025/Polaris).
             rank_penalty (Optional[float]): A penalty applied to the rank of rewards to correct for bias
                 (https://arxiv.org/pdf/2506.02355).
+            std_cal_level (str): The scope for calculating the reward standard deviation for normalization.
+                Can be 'group' (default, std is calculated per group) or 'batch' (std is calculated
+                across the entire batch). The mean is always calculated per group.
+                Calculating the mean at the local (group) level and the standard deviation at the global (batch)
+                level enables more robust reward shaping (https://arxiv.org/pdf/2508.08221v1).
         """
         self.epsilon = epsilon
         self.std_threshold = std_threshold
         self.duplicate_experiences = duplicate_experiences
         self.rank_penalty = rank_penalty
+        self.std_cal_level = std_cal_level
+        if self.std_cal_level not in ["group", "batch"]:
+            raise ValueError("std_cal_level must be either 'group' or 'batch'")
 
     def group_experiences(self, exps):
         return group_by(exps, id_type="task")
 
     def calculate_group_advantage(
-        self, group_id: str, exps: List[Experience]
+        self,
+        group_id: str,
+        exps: List[Experience],
+        precomputed_std: Optional[torch.Tensor] = None,
     ) -> Tuple[List[Experience], Dict]:
         metrics = {}
         with torch.no_grad():
@@ -155,7 +167,10 @@ def calculate_group_advantage(
                 exps.clear()
 
             for exp in exps:
-                score = (exp.reward - group_reward_mean) / (group_reward_std + self.epsilon)
+                if self.std_cal_level == "batch" and precomputed_std is not None:
+                    score = (exp.reward - group_reward_mean) / (precomputed_std + self.epsilon)
+                else:
+                    score = (exp.reward - group_reward_mean) / (group_reward_std + self.epsilon)
                 exp.advantages = score * exp.action_mask
                 exp.returns = exp.advantages.clone()
 
@@ -185,8 +200,19 @@ def _duplicate_experiences(self, exp_groups: Dict[str, List[Experience]]) -> List[Experience]:
     def process(self, exps):
         exp_groups = self.group_experiences(exps)
         metric_list = []
+        precomputed_std = None
+        if self.std_cal_level == "batch":
+            all_rewards = torch.tensor(
+                [exp.reward for exp in exps], dtype=torch.float32
+            )  # All rewards in the batch
+            if len(all_rewards) <= 1:
+                precomputed_std = torch.tensor(1.0)
+            else:
+                precomputed_std = torch.std(all_rewards)
         for group_id, group_exps in exp_groups.items():
-            group_exps, group_metrics = self.calculate_group_advantage(group_id, group_exps)
+            group_exps, group_metrics = self.calculate_group_advantage(
+                group_id, group_exps, precomputed_std=precomputed_std
+            )
             metric_list.append(group_metrics)
             try:
                 # TODO: sum skipped count
@@ -201,4 +227,10 @@ def process(self, exps):
 
     @classmethod
     def default_args(cls) -> dict:
-        return {"epsilon": 1e-6}
+        return {
+            "epsilon": 1e-6,
+            "std_threshold": None,
+            "duplicate_experiences": False,
+            "rank_penalty": None,
+            "std_cal_level": "group",
+        }
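Taken together, the changes to process() and calculate_group_advantage() amount to: compute one std over every reward in the batch (falling back to 1.0 when there is at most one reward), then normalize each group with its own mean but that shared std. The function below is a self-contained sketch of that flow, assuming plain per-task reward lists stand in for the Experience groups; it is not the library implementation, which also multiplies each score by the experience's action_mask and emits metrics.

import torch
from typing import Dict, List, Optional


def batch_level_advantages(
    groups: Dict[str, List[float]],  # task_id -> rewards of its runs (stand-in for Experience groups)
    epsilon: float = 1e-6,
    std_cal_level: str = "batch",
) -> Dict[str, List[float]]:
    # Illustrative sketch of the commit's group-mean / batch-std normalization.
    if std_cal_level not in ["group", "batch"]:
        raise ValueError("std_cal_level must be either 'group' or 'batch'")

    precomputed_std: Optional[torch.Tensor] = None
    if std_cal_level == "batch":
        all_rewards = torch.tensor(
            [r for rs in groups.values() for r in rs], dtype=torch.float32
        )
        precomputed_std = torch.tensor(1.0) if len(all_rewards) <= 1 else torch.std(all_rewards)

    advantages = {}
    for task_id, rs in groups.items():
        rewards = torch.tensor(rs, dtype=torch.float32)
        group_mean = torch.mean(rewards)
        group_std = torch.std(rewards) if len(rs) > 1 else torch.tensor(1.0)
        std = precomputed_std if precomputed_std is not None else group_std
        advantages[task_id] = ((rewards - group_mean) / (std + epsilon)).tolist()
    return advantages


# With the rewards from the tests, task0's scores shrink from [-1, 0, 1]
# (group std = 1.0) to roughly [-0.18, 0.0, 0.18] (batch std ≈ 5.55).
print(batch_level_advantages({"task0": [1.0, 2.0, 3.0], "task1": [11.0, 12.0, 13.0]}))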

trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py

Lines changed: 45 additions & 4 deletions

@@ -1,6 +1,6 @@
 """GRPO advantage computation for multi-step scenarios
 """
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 
@@ -21,13 +21,29 @@ def __init__(
         self,
         epsilon: float = 1e-6,
         enable_step_norm: bool = False,
+        std_cal_level: str = "group",  # 'group' (task-level) or 'batch'
         **kwargs,
     ) -> None:
+        """Initialize the Step-wise GRPO advantage function.
+
+        Args:
+            epsilon (float): A small value to avoid division by zero.
+            enable_step_norm (bool): If True, normalize advantages by trajectory length.
+            std_cal_level (str): The scope for calculating reward standard deviation.
+                'group' (default): Std is calculated per task group.
+                'batch': Std is calculated across all last-step rewards in the entire batch.
+                The mean is always calculated per task group.
+        """
         self.epsilon = epsilon
         self.enable_step_norm = enable_step_norm
+        self.std_cal_level = std_cal_level
+        if self.std_cal_level not in ["group", "batch"]:
+            raise ValueError("std_cal_level must be either 'group' or 'batch'")
 
     def calculate_last_step_advantage(
-        self, exps: Dict[str, Experience]
+        self,
+        exps: Dict[str, Experience],
+        precomputed_std: Optional[torch.Tensor] = None,
     ) -> Tuple[Dict[str, float], Dict[str, float]]:
         """Calculate group advantage for a given group of experiences.
 
@@ -48,7 +64,10 @@ def calculate_last_step_advantage(
         group_reward_std = torch.std(rewards)
         scores = {}
         for rid, exp in exps.items():
-            score = (exp.reward - group_reward_mean) / (group_reward_std + self.epsilon)
+            if self.std_cal_level == "batch" and precomputed_std is not None:
+                score = (exp.reward - group_reward_mean) / (precomputed_std + self.epsilon)
+            else:
+                score = (exp.reward - group_reward_mean) / (group_reward_std + self.epsilon)
             scores[rid] = score.item()
         metrics = {
             "reward_mean": group_reward_mean.item(),
@@ -85,14 +104,36 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
         metric_list = []
         # Step 1: split the experiences into sub-groups by task
         task_exps = group_by(exps, "task")
+
+        # --- Pre-computation step for batch-level standard deviation ---
+        precomputed_std = None
+        if self.std_cal_level == "batch":
+            all_laststep_rewards = []
+            for task_exp in task_exps.values():
+                # First, group all experiences by run to find the last step of each run
+                task_run_exps = group_by(task_exp, "run")
+                # Collect rewards from the last step of every run in the entire batch
+                last_step_rewards = [
+                    run_steps[-1].reward for run_steps in task_run_exps.values() if run_steps
+                ]
+                all_laststep_rewards.extend(last_step_rewards)
+
+            if len(all_laststep_rewards) <= 1:
+                precomputed_std = torch.tensor(1.0)
+            else:
+                precomputed_std = torch.std(torch.tensor(all_laststep_rewards, dtype=torch.float32))
+        # --- End of pre-computation ---
+
         # Step 2: further split each task's experiences into sub-groups by run
         result_exps = []
         for task_exp in task_exps.values():
             run_exps = group_by(task_exp, "run")
 
             # Step3: extract the last experience (last step) from each run and calculate scores
             last_step_exps = {run_id: step_exps[-1] for run_id, step_exps in run_exps.items()}
-            scores, metrics = self.calculate_last_step_advantage(last_step_exps)
+            scores, metrics = self.calculate_last_step_advantage(
+                last_step_exps, precomputed_std=precomputed_std
+            )
             metric_list.append(metrics)
 
             # Step 4: broadcast the advantages to all previous steps
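For the step-wise variant, only the pooling step differs: the batch std is taken over the last-step reward of every run across all tasks, and that value is then shared by calculate_last_step_advantage(). Below is a hedged sketch of just that pre-computation, with nested dicts standing in for the task/run/step grouping that group_by produces; names and data are illustrative.

import torch
from typing import Dict, List


def last_step_batch_std(
    task_run_step_rewards: Dict[str, Dict[str, List[float]]],  # task -> run -> per-step rewards
) -> torch.Tensor:
    # Collect the last-step reward of every run in the batch, then take their std,
    # falling back to 1.0 when there is at most one run (mirrors the diff above).
    last_step_rewards = [
        steps[-1]  # reward of the final step of this run
        for runs in task_run_step_rewards.values()
        for steps in runs.values()
        if steps
    ]
    if len(last_step_rewards) <= 1:
        return torch.tensor(1.0)
    return torch.std(torch.tensor(last_step_rewards, dtype=torch.float32))


# Two tasks, three runs each; rewards follow the test above (task*10 + run),
# repeated for every step of a run.
batch = {
    f"task{j}": {f"run{i}": [float(j * 10 + i)] * 4 for i in range(3)}
    for j in range(2)
}
print(last_step_batch_std(batch))  # std of [0, 1, 2, 10, 11, 12] ≈ 5.55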
