Commit 1120aed

Add std_threshold option to StepWiseGRPOAdvantageFn, to filter out zero-grad group samples. (#363)
Co-authored-by: 问昊 <[email protected]>
1 parent bb910e4 commit 1120aed
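
For context, a minimal usage sketch of the new option, mirroring the unit test added below. The registry lookup, constructor arguments, and metric keys are all taken from this commit; the import path and the construction of exps are assumptions for illustration.

    from trinity.algorithm import ADVANTAGE_FN  # assumed import path

    advantage_fn_cls = ADVANTAGE_FN.get("step_wise_grpo")
    advantage_fn = advantage_fn_cls(epsilon=1e-6, std_threshold=0.0001)

    # Task groups whose last-step reward std is <= std_threshold are dropped
    # entirely, since their advantages would be (near-)zero and contribute no
    # gradient. exps is a list of Experience objects, as in the test below.
    processed_exps, metrics = advantage_fn(exps)
    print(metrics["filtered_count"], metrics["skipped_group_ratio"])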

File tree

2 files changed: 109 additions & 4 deletions


tests/algorithm/advantage_fn_test.py

Lines changed: 68 additions & 0 deletions
@@ -326,3 +326,71 @@ def test_batch_level_step_wise_grpo_advantage(self):
         expected_advantages = expected_advantage_value * target_exp.action_mask
         self.assertTrue(torch.allclose(target_exp.advantages, expected_advantages, atol=1e-6))
         self.assertTrue(torch.allclose(target_exp.returns, expected_advantages, atol=1e-6))
+
+    def test_step_wise_grpo_with_std_threshold(self):
+        advantage_fn_cls = ADVANTAGE_FN.get("step_wise_grpo")
+        self.assertIsNotNone(advantage_fn_cls)
+        advantage_fn = advantage_fn_cls(epsilon=1e-6, std_threshold=0.0001)
+        repeat_times = 5
+        step_num = 4
+
+        # Create experiences with mixed reward patterns:
+        # - task 0: all runs have same reward (0.5) -> should be filtered
+        # - task 1: all runs have same reward (1.0) -> should be filtered
+        # - task 2: runs have different rewards (0, 1, 2, 3, 4) -> should NOT be filtered
+        exps = []
+
+        # Task 0: constant reward 0.5
+        for k in range(step_num):
+            for i in range(repeat_times):
+                exps.append(
+                    Experience(
+                        eid=EID(batch=0, task=0, run=i, step=k),
+                        tokens=torch.zeros(5),
+                        prompt_length=2,
+                        reward=0.5,
+                    )
+                )
+
+        # Task 1: constant reward 1.0
+        for k in range(step_num):
+            for i in range(repeat_times):
+                exps.append(
+                    Experience(
+                        eid=EID(batch=0, task=1, run=i, step=k),
+                        tokens=torch.zeros(5),
+                        prompt_length=2,
+                        reward=1.0,
+                    )
+                )
+
+        # Task 2: varying rewards
+        for k in range(step_num):
+            for i in range(repeat_times):
+                exps.append(
+                    Experience(
+                        eid=EID(batch=0, task=2, run=i, step=k),
+                        tokens=torch.zeros(5),
+                        prompt_length=2,
+                        reward=float(i),
+                    )
+                )
+
+        processed_exps, metrics = advantage_fn(exps)
+
+        # Only task 2 should remain (task 0 and task 1 filtered due to zero std)
+        expected_remaining = repeat_times * step_num  # task 2 only
+        expected_filtered = 2 * repeat_times * step_num  # task 0 and task 1
+
+        self.assertEqual(len(processed_exps), expected_remaining)
+        self.assertIn("filtered_count", metrics)
+        self.assertEqual(metrics["filtered_count"], expected_filtered)
+
+        # Verify skipped group ratio: 2 out of 3 tasks were skipped
+        self.assertIn("skipped_group_ratio", metrics)
+        expected_ratio = 2.0 / 3.0  # task 0 and task 1 skipped out of 3 total tasks
+        self.assertAlmostEqual(metrics["skipped_group_ratio"], expected_ratio, places=6)
+
+        # Verify that all remaining experiences are from task 2
+        for exp in processed_exps:
+            self.assertEqual(exp.eid.task, 2)
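
As a quick check on the numbers asserted above (a worked restatement of the test's arithmetic, not new behavior): each task contributes repeat_times * step_num experiences, so

    repeat_times, step_num = 5, 4
    expected_remaining = repeat_times * step_num      # 20: task 2 survives
    expected_filtered = 2 * repeat_times * step_num   # 40: tasks 0 and 1 dropped
    skipped_group_ratio = 2 / 3                       # 2 of 3 task groups skipped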

trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py

Lines changed: 41 additions & 4 deletions
@@ -22,6 +22,7 @@ def __init__(
         epsilon: float = 1e-6,
         enable_step_norm: bool = False,
         std_cal_level: str = "group",  # 'group' (task-level) or 'batch'
+        std_threshold: Optional[float] = None,
         **kwargs,
     ) -> None:
         """Initialize the Step-wise GRPO advantage function.
@@ -33,26 +34,31 @@ def __init__(
                 'group' (default): Std is calculated per task group.
                 'batch': Std is calculated across all last-step rewards in the entire batch.
                 The mean is always calculated per task group.
+            std_threshold (Optional[float]): If provided, task groups with a reward standard deviation
+                equal or below this threshold will be skipped.
         """
         self.epsilon = epsilon
         self.enable_step_norm = enable_step_norm
         self.std_cal_level = std_cal_level
+        self.std_threshold = std_threshold
         if self.std_cal_level not in ["group", "batch"]:
             raise ValueError("std_cal_level must be either 'group' or 'batch'")
 
     def calculate_last_step_advantage(
         self,
         exps: Dict[str, Experience],
         precomputed_std: Optional[torch.Tensor] = None,
-    ) -> Tuple[Dict[str, float], Dict[str, float]]:
+    ) -> Tuple[Dict[str, float], Dict[str, float], bool]:
         """Calculate group advantage for a given group of experiences.
 
         Args:
             exps (Dict[str, Experience]): One experience per run, keyed by run ID.
+            precomputed_std (Optional[torch.Tensor]): Precomputed standard deviation for batch-level calculation.
 
         Returns:
-            Dict[str, float]: A tuple containing the scores for each run.
+            Dict[str, float]: Scores for each run.
             Dict[str, float]: Metrics for logging.
+            bool: Whether this group should be skipped.
         """
         with torch.no_grad():
             if len(exps) == 1:
@@ -62,6 +68,13 @@ def calculate_last_step_advantage(
             rewards = torch.tensor([exp.reward for exp in exps.values()], dtype=torch.float32)
             group_reward_mean = torch.mean(rewards)
             group_reward_std = torch.std(rewards)
+
+            # Determine if this group should be skipped based on std_threshold
+            should_skip = False
+            if self.std_threshold is not None:
+                if len(exps) == 1 or group_reward_std <= self.std_threshold:
+                    should_skip = True
+
             scores = {}
             for rid, exp in exps.items():
                 if self.std_cal_level == "batch" and precomputed_std is not None:
@@ -73,7 +86,7 @@ def calculate_last_step_advantage(
                 "reward_mean": group_reward_mean.item(),
                 "reward_std": group_reward_std.item(),
             }
-        return scores, metrics
+        return scores, metrics, should_skip
 
     def broadcast_advantages(
         self, run_exps: Dict[str, List[Experience]], scores: Dict[str, float]
@@ -102,6 +115,7 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
             return [], {}
         cnt = 0
         metric_list = []
+        filtered_count = 0
         # Step 1: split the experiences into sub-groups by task
         task_exps = group_by(exps, "task")
 
@@ -126,14 +140,27 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
 
         # Step 2: further split each task's experiences into sub-groups by run
         result_exps = []
+        total_task_groups = len(task_exps)
+        skipped_task_groups = 0
+
         for task_exp in task_exps.values():
             run_exps = group_by(task_exp, "run")
 
             # Step3: extract the last experience (last step) from each run and calculate scores
             last_step_exps = {run_id: step_exps[-1] for run_id, step_exps in run_exps.items()}
-            scores, metrics = self.calculate_last_step_advantage(
+            scores, metrics, should_skip = self.calculate_last_step_advantage(
                 last_step_exps, precomputed_std=precomputed_std
             )
+
+            # Skip this task group if std is below threshold
+            if should_skip:
+                # Count all experiences in this task group as filtered
+                task_exp_count = sum(len(step_exps) for step_exps in run_exps.values())
+                filtered_count += task_exp_count
+                skipped_task_groups += 1
+                metric_list.append(metrics)
+                continue
+
             metric_list.append(metrics)
 
             # Step 4: broadcast the advantages to all previous steps
@@ -144,6 +171,14 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
 
         metrics = gather_metrics(metric_list, "group_advantages")
         metrics["experience_count"] = cnt
+        metrics["filtered_count"] = filtered_count
+
+        # Calculate the ratio of skipped task groups
+        if total_task_groups > 0:
+            metrics["skipped_group_ratio"] = skipped_task_groups / total_task_groups
+        else:
+            metrics["skipped_group_ratio"] = 0.0
+
         return result_exps, metrics
 
     def __call__(self, exps, **kwargs):
@@ -160,4 +195,6 @@ def default_args(cls) -> Dict:
         return {
             "epsilon": 1e-6,
             "enable_step_norm": False,
+            "std_threshold": None,
+            "std_cal_level": "group",
         }
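
Why "zero-grad" groups are worth filtering: when every run in a group receives the same reward, the group-normalized score is zero for every run, so the broadcast advantages vanish and the group adds no policy-gradient signal. Below is a minimal sketch of that effect; the (reward - mean) / (std + epsilon) form is an assumption based on standard GRPO, since the exact score computation sits in unchanged lines not shown in this diff.

    import torch

    epsilon = 1e-6
    constant = torch.tensor([0.5] * 5)                 # task 0 in the test: filtered
    varying = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])  # task 2 in the test: kept

    for rewards in (constant, varying):
        scores = (rewards - rewards.mean()) / (rewards.std() + epsilon)
        print(scores)
    # constant rewards -> scores are all zeros (no gradient signal)
    # varying rewards  -> nonzero scores, std above the threshold, group kept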
