Commit cc1ece3

fix dapo dynamic sampling (#3846)
1 parent 483dca2 commit cc1ece3

1 file changed: +0 -3 lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 0 additions & 3 deletions
@@ -885,9 +885,6 @@ def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions):
 grouped_rewards = rewards.view(-1, self.num_generations)
 group_std = grouped_rewards.std(dim=1)
 
-if (group_std > 0).all():
-    break
-
 valid_mask = (group_std > 0).repeat_interleave(self.num_generations)
 all_inputs = gather_object(inputs)
 valid_samples.extend([inp for inp, mask in zip(all_inputs, valid_mask) if mask])
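The filtering step kept by this hunk can be illustrated with a small standalone sketch. It uses plain Python lists instead of tensors, and the function name filter_valid and the sample data are hypothetical, not part of the repository. It assumes rewards are laid out group by group, as rewards.view(-1, self.num_generations) implies.

```python
from statistics import stdev


def filter_valid(inputs, rewards, num_generations):
    """Keep only samples whose reward group has non-zero standard deviation.

    Hypothetical helper for illustration; not the trainer's actual method.
    """
    # Split the flat reward list into consecutive groups,
    # analogous to rewards.view(-1, num_generations).
    groups = [rewards[i:i + num_generations]
              for i in range(0, len(rewards), num_generations)]
    # Sample standard deviation per group (torch's std also applies
    # Bessel's correction by default).
    group_std = [stdev(g) for g in groups]
    # Repeat each group's flag once per generation,
    # analogous to repeat_interleave(num_generations).
    valid_mask = [s > 0 for s in group_std for _ in range(num_generations)]
    return [inp for inp, keep in zip(inputs, valid_mask) if keep]


# Three groups of two generations; the middle group has identical rewards.
inputs = ["a0", "a1", "b0", "b1", "c0", "c1"]
rewards = [1.0, 0.0, 1.0, 1.0, 0.0, 1.0]
kept = filter_valid(inputs, rewards, num_generations=2)
print(kept)

# When every group has mixed rewards, every sample is kept.
all_kept = filter_valid(inputs, [1.0, 0.0, 0.0, 1.0, 0.0, 1.0], num_generations=2)
print(all_kept)
```

Reading the hunk alone, the removed break ran before valid_samples.extend, so a batch in which every group had non-zero standard deviation would appear to leave the loop without being collected. This is an interpretation of the diff, not something the commit message states.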
