@@ -22,6 +22,7 @@ def __init__(
         epsilon: float = 1e-6,
         enable_step_norm: bool = False,
         std_cal_level: str = "group",  # 'group' (task-level) or 'batch'
+        std_threshold: Optional[float] = None,
         **kwargs,
     ) -> None:
         """Initialize the Step-wise GRPO advantage function.
@@ -33,26 +34,31 @@ def __init__(
                 'group' (default): Std is calculated per task group.
                 'batch': Std is calculated across all last-step rewards in the entire batch.
                 The mean is always calculated per task group.
+            std_threshold (Optional[float]): If provided, task groups with a reward standard deviation
+                equal to or below this threshold will be skipped.
         """
         self.epsilon = epsilon
         self.enable_step_norm = enable_step_norm
         self.std_cal_level = std_cal_level
+        self.std_threshold = std_threshold
         if self.std_cal_level not in ["group", "batch"]:
             raise ValueError("std_cal_level must be either 'group' or 'batch'")
 
     def calculate_last_step_advantage(
         self,
         exps: Dict[str, Experience],
         precomputed_std: Optional[torch.Tensor] = None,
-    ) -> Tuple[Dict[str, float], Dict[str, float]]:
+    ) -> Tuple[Dict[str, float], Dict[str, float], bool]:
         """Calculate group advantage for a given group of experiences.
 
         Args:
             exps (Dict[str, Experience]): One experience per run, keyed by run ID.
+            precomputed_std (Optional[torch.Tensor]): Precomputed standard deviation for batch-level calculation.
 
         Returns:
-            Dict[str, float]: A tuple containing the scores for each run.
+            Dict[str, float]: Scores for each run.
             Dict[str, float]: Metrics for logging.
+            bool: Whether this group should be skipped.
         """
         with torch.no_grad():
             if len(exps) == 1:
@@ -62,6 +68,13 @@ def calculate_last_step_advantage(
             rewards = torch.tensor([exp.reward for exp in exps.values()], dtype=torch.float32)
             group_reward_mean = torch.mean(rewards)
             group_reward_std = torch.std(rewards)
+
+            # Determine if this group should be skipped based on std_threshold
+            should_skip = False
+            if self.std_threshold is not None:
+                if len(exps) == 1 or group_reward_std <= self.std_threshold:
+                    should_skip = True
+
             scores = {}
             for rid, exp in exps.items():
                 if self.std_cal_level == "batch" and precomputed_std is not None:
@@ -73,7 +86,7 @@ def calculate_last_step_advantage(
             "reward_mean": group_reward_mean.item(),
             "reward_std": group_reward_std.item(),
         }
-        return scores, metrics
+        return scores, metrics, should_skip
 
     def broadcast_advantages(
         self, run_exps: Dict[str, List[Experience]], scores: Dict[str, float]
@@ -102,6 +115,7 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
             return [], {}
         cnt = 0
         metric_list = []
+        filtered_count = 0
         # Step 1: split the experiences into sub-groups by task
         task_exps = group_by(exps, "task")
 
@@ -126,14 +140,27 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
 
         # Step 2: further split each task's experiences into sub-groups by run
         result_exps = []
+        total_task_groups = len(task_exps)
+        skipped_task_groups = 0
+
         for task_exp in task_exps.values():
             run_exps = group_by(task_exp, "run")
 
             # Step3: extract the last experience (last step) from each run and calculate scores
             last_step_exps = {run_id: step_exps[-1] for run_id, step_exps in run_exps.items()}
-            scores, metrics = self.calculate_last_step_advantage(
+            scores, metrics, should_skip = self.calculate_last_step_advantage(
                 last_step_exps, precomputed_std=precomputed_std
             )
+
+            # Skip this task group if the reward std is at or below std_threshold
+            if should_skip:
+                # Count all experiences in this task group as filtered
+                task_exp_count = sum(len(step_exps) for step_exps in run_exps.values())
+                filtered_count += task_exp_count
+                skipped_task_groups += 1
+                metric_list.append(metrics)
+                continue
+
             metric_list.append(metrics)
 
             # Step 4: broadcast the advantages to all previous steps
@@ -144,6 +171,14 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
 
         metrics = gather_metrics(metric_list, "group_advantages")
         metrics["experience_count"] = cnt
+        metrics["filtered_count"] = filtered_count
+
+        # Calculate the ratio of skipped task groups
+        if total_task_groups > 0:
+            metrics["skipped_group_ratio"] = skipped_task_groups / total_task_groups
+        else:
+            metrics["skipped_group_ratio"] = 0.0
+
         return result_exps, metrics
 
     def __call__(self, exps, **kwargs):
@@ -160,4 +195,6 @@ def default_args(cls) -> Dict:
         return {
             "epsilon": 1e-6,
             "enable_step_norm": False,
+            "std_threshold": None,
+            "std_cal_level": "group",
         }
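
As a quick illustration of the new behavior (a minimal standalone sketch, not code from this PR; the helper name should_skip_group is invented for the example), the filtering rule added in calculate_last_step_advantage reduces to:

# Sketch of the std_threshold skip rule: a task group is dropped when it has a
# single run, or when the std of its last-step rewards is at or below the threshold.
import torch

def should_skip_group(last_step_rewards, std_threshold=None):
    if std_threshold is None:
        return False
    if len(last_step_rewards) == 1:
        return True
    rewards = torch.tensor(last_step_rewards, dtype=torch.float32)
    return torch.std(rewards).item() <= std_threshold

# Identical rewards give zero std, hence (near-)zero advantages: the group is skipped.
print(should_skip_group([1.0, 1.0, 1.0, 1.0], std_threshold=1e-4))  # True
print(should_skip_group([0.0, 1.0, 0.0, 1.0], std_threshold=1e-4))  # False

The apparent rationale is that groups with (near-)identical rewards yield near-zero group-relative advantages and thus little learning signal, so dropping them saves compute.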