Skip to content

Commit 3d8fbfe

Browse files
committed
better skip metric display
1 parent fb67735 commit 3d8fbfe

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

trinity/algorithm/advantage_fn/grpo_advantage.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def calculate_group_advantage(
144144
) # check this value (use exps[0].reward may be better)
145145
group_reward_std = torch.tensor(1.0) # set to 1.0 to avoid division by zero
146146
if self.std_threshold is not None:
147-
metrics["skipped_count"] = 1
147+
metrics["skipped_count_per_group"] = 1
148148
exps.clear() # Clear experiences if only one experience
149149
else:
150150
rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32)
@@ -161,10 +161,16 @@ def calculate_group_advantage(
161161
group_reward_mean = torch.mean(rewards)
162162
group_reward_std = torch.std(rewards)
163163

164-
# If the reward standard deviation is below a threshold, skip the group
165-
if self.std_threshold is not None and group_reward_std <= self.std_threshold:
166-
metrics["skipped_count"] = len(exps)
167-
exps.clear()
164+
# Concisely handle group skipping and reward statistics
165+
if self.std_threshold is not None:
166+
if group_reward_std <= self.std_threshold:
167+
metrics["skipped_count_per_group"] = len(exps)
168+
exps.clear()
169+
else:
170+
metrics["skipped_count_per_group"] = 0
171+
172+
metrics["all_positive_reward_percentage"] = int(group_reward_mean >= 1.0)
173+
metrics["all_negative_reward_percentage"] = int(group_reward_mean <= 0.0)
168174

169175
for exp in exps:
170176
if self.std_cal_level == "batch" and precomputed_std is not None:
@@ -216,7 +222,7 @@ def process(self, exps):
216222
metric_list.append(group_metrics)
217223

218224
# Update the filtered_count metric
219-
filtered_count = sum(metric.pop("skipped_count", 0) for metric in metric_list)
225+
filtered_count = sum(metric.pop("skipped_count_per_group", 0) for metric in metric_list)
220226
metrics = gather_metrics(metric_list, "group_advantages")
221227
metrics["filtered_count"] = filtered_count
222228
if self.duplicate_experiences and self.std_threshold is not None:

0 commit comments

Comments
 (0)