@@ -144,7 +144,7 @@ def calculate_group_advantage(
144144 ) # check this value (use exps[0].reward may be better)
145145 group_reward_std = torch .tensor (1.0 ) # set to 1.0 to avoid division by zero
146146 if self .std_threshold is not None :
147- metrics ["skipped_count" ] = 1
147+ metrics ["skipped_count_per_group" ] = 1
148148 exps .clear () # Clear experiences if only one experience
149149 else :
150150 rewards = torch .tensor ([exp .reward for exp in exps ], dtype = torch .float32 )
@@ -161,10 +161,16 @@ def calculate_group_advantage(
161161 group_reward_mean = torch .mean (rewards )
162162 group_reward_std = torch .std (rewards )
163163
164- # If the reward standard deviation is below a threshold, skip the group
165- if self .std_threshold is not None and group_reward_std <= self .std_threshold :
166- metrics ["skipped_count" ] = len (exps )
167- exps .clear ()
164+ # Concisely handle group skipping and reward statistics
165+ if self .std_threshold is not None :
166+ if group_reward_std <= self .std_threshold :
167+ metrics ["skipped_count_per_group" ] = len (exps )
168+ exps .clear ()
169+ else :
170+ metrics ["skipped_count_per_group" ] = 0
171+
172+ metrics ["all_positive_reward_percentage" ] = int (group_reward_mean >= 1.0 )
173+ metrics ["all_negative_reward_percentage" ] = int (group_reward_mean <= 0.0 )
168174
169175 for exp in exps :
170176 if self .std_cal_level == "batch" and precomputed_std is not None :
@@ -216,7 +222,7 @@ def process(self, exps):
216222 metric_list .append (group_metrics )
217223
218224 # Update the filtered_count metric
219- filtered_count = sum (metric .pop ("skipped_count" , 0 ) for metric in metric_list )
225 filtered_count = sum (metric .pop ("skipped_count_per_group" , 0 ) for metric in metric_list )
220226 metrics = gather_metrics (metric_list , "group_advantages" )
221227 metrics ["filtered_count" ] = filtered_count
222228 if self .duplicate_experiences and self .std_threshold is not None :
0 commit comments