
Commit 483dca2

fix grpo filter overlong (#3844)

* fix
* fix
* Rever
* fix
* rm nanstd

---------

Co-authored-by: hjh <[email protected]>

1 parent d573a17

File tree

1 file changed: +11 -9 lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 11 additions & 9 deletions
```diff
@@ -1010,25 +1010,24 @@ def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func)
             reward_func_name = reward_func.__class__.__name__
 
             reward_func_names.append(reward_func_name)
-        metrics_mask = ~agg_truncated_mask if self.args.overlong_filter else torch.ones(
-            agg_completion_mask.shape[0], dtype=torch.bool)
+
         for i, reward_func_name in enumerate(reward_func_names):
-            mean_rewards = (rewards_per_func[:, i][metrics_mask]).mean().item()
+            mean_rewards = rewards_per_func[:, i].mean().item()
             self._metrics[mode][f'rewards/{reward_func_name}/mean'].append(mean_rewards)
-            std_rewards = (rewards_per_func[:, i][metrics_mask]).std().item()
+            std_rewards = rewards_per_func[:, i].std().item()
             self._metrics[mode][f'rewards/{reward_func_name}/std'].append(std_rewards)
 
         # Log overall reward stats
-        grouped_rewards = rewards[metrics_mask].view(-1, self.num_generations)
+        grouped_rewards = rewards.view(-1, self.num_generations)
         self._metrics[mode]['reward'].append(grouped_rewards.mean().item())
         self._metrics[mode]['reward_std'].append(grouped_rewards.std(dim=1).mean().item())
 
         # Log prompt and completion texts
-        self._textual_logs['prompt'].extend(m for m, mask in zip(gather_object(messages), metrics_mask) if mask)
-        self._textual_logs['completion'].extend(c for c, mask in zip(gather_object(completions), metrics_mask) if mask)
+        self._textual_logs['prompt'].extend(gather_object(messages))
+        self._textual_logs['completion'].extend(gather_object(completions))
 
         for i, name in enumerate(reward_func_names):
-            self._textual_logs['rewards'][name].extend(rewards_per_func[:, i][metrics_mask].tolist())
+            self._textual_logs['rewards'][name].extend(rewards_per_func[:, i].tolist())
 
     @profiling_decorator
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
```
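
After this hunk, the per-function reward mean/std, the grouped reward statistics, and the textual logs cover every completion in the batch; previously, when `overlong_filter` was enabled, `metrics_mask` excluded truncated samples from logging. One plausible motivation, consistent with the `rm nanstd` note in the commit message but not stated in the diff itself, is that `Tensor.std()` over a filtered slice becomes NaN once fewer than two samples survive, since the default unbiased estimator divides by `n - 1`. A minimal illustration of that PyTorch behavior with made-up values, not code from the trainer:

```python
import torch

rewards = torch.tensor([0.1, 0.7, 0.9, 0.3])
keep = torch.tensor([True, False, False, False])  # e.g. only one non-truncated completion

print(rewards[keep].mean())  # tensor(0.1000) -- mean is still defined
print(rewards[keep].std())   # tensor(nan) -- unbiased std needs at least 2 samples
print(rewards.std())         # finite std over the full, unfiltered batch
```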
```diff
@@ -1077,7 +1076,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
             mean_kl = (per_token_kl * completion_mask).sum() / completions_length
             metrics['kl'] = mean_kl
 
-        is_clipped = (coef_1 < (1 - self.epsilon_low)) | (coef_1 > (1 + self.epsilon_high))
+        is_clipped = ((coef_1 < 1 - self.epsilon_low) &
+                      (advantages.unsqueeze(1) < 0)) | ((coef_1 > 1 + self.epsilon_high) &
+                                                        (advantages.unsqueeze(1) > 0))
+
         clip_ratio = (is_clipped * completion_mask).sum() / completions_length
         metrics['clip_ratio'] = clip_ratio
 
```

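The reworked `is_clipped` only counts a token as clipped when the clip actually binds the objective: a ratio below `1 - epsilon_low` matters only for negative advantages, and a ratio above `1 + epsilon_high` only for positive ones. In the opposite cases, `min(ratio * A, clamped_ratio * A)` still selects the unclipped term, so the old indicator overstated the clip ratio. A small self-contained sketch (assumed shapes and values; names mirror the diff but the code is not taken from the trainer) showing that the new indicator marks exactly the tokens where the clamped term is selected:

```python
import torch

# Assumed shapes for illustration: per-token ratio [B, T], per-sequence advantages [B].
B, T = 4, 6
coef_1 = torch.empty(B, T).uniform_(0.5, 1.5)      # policy ratio pi_theta / pi_old
advantages = torch.randn(B)
epsilon_low, epsilon_high = 0.2, 0.2

adv = advantages.unsqueeze(1)                      # [B, 1], broadcasts over tokens
coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)

# Clipped surrogate: the clamped term only takes effect when it is the smaller one.
per_token_obj = torch.min(coef_1 * adv, coef_2 * adv)

# New indicator from the diff: the clip binds only below the band with A < 0,
# or above the band with A > 0.
is_clipped = ((coef_1 < 1 - epsilon_low) & (adv < 0)) | ((coef_1 > 1 + epsilon_high) & (adv > 0))

# Check: is_clipped is True exactly where torch.min picked the clamped term.
assert torch.equal(is_clipped, (coef_2 * adv) < (coef_1 * adv))
print(is_clipped.float().mean())                   # fraction of clipped tokens
```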